diff --git a/tools/bazel.rc b/.bazelrc similarity index 88% rename from tools/bazel.rc rename to .bazelrc index 1fdf51f53e29c7111cf89c016400b710051cf9c6..17285afdb381018d0054e771475327b1f7ed9866 100644 --- a/tools/bazel.rc +++ b/.bazelrc @@ -25,12 +25,14 @@ build --define framework_shared_object=true # If you would like to use a local MKL instead of downloading, please set the # environment variable "TF_MKL_ROOT" every time before build. build:mkl --define=build_with_mkl=true --define=enable_mkl=true +build:mkl --define=tensorflow_mkldnn_contraction_kernel=0 build:mkl -c opt # This config option is used to enable MKL-DNN open source library only, # without depending on MKL binary version. build:mkl_open_source_only --define=build_with_mkl_dnn_only=true build:mkl_open_source_only --define=build_with_mkl=true --define=enable_mkl=true +build:mkl_open_source_only --define=tensorflow_mkldnn_contraction_kernel=0 build:download_clang --crosstool_top=@local_config_download_clang//:toolchain build:download_clang --define=using_clang=true @@ -76,10 +78,9 @@ build:nonccl --define=no_nccl_support=true build --define=use_fast_cpp_protos=true build --define=allow_oversize_protos=true -build --define=grpc_no_ares=true build --spawn_strategy=standalone -build --genrule_strategy=standalone +build --strategy=Genrule=standalone build -c opt # Other build flags. @@ -89,7 +90,21 @@ build --define=grpc_no_ares=true build:dynamic_kernels --define=dynamic_loaded_kernels=true build:dynamic_kernels --copt=-DAUTOLOAD_DYNAMIC_KERNELS +# Build TF with C++ 17 features. 
+build:c++17 --cxxopt=-std=c++1z +build:c++17 --cxxopt=-stdlib=libc++ +build:c++1z --cxxopt=-std=c++1z +build:c++1z --cxxopt=-stdlib=libc++ + # Default paths for TF_SYSTEM_LIBS build --define=PREFIX=/usr build --define=LIBDIR=$(PREFIX)/lib build --define=INCLUDEDIR=$(PREFIX)/include + +# Default options should come above this line + +# Options from ./configure +try-import %workspace%/.tf_configure.bazelrc + +# Put user-specific options in .bazelrc.user +try-import %workspace%/.bazelrc.user diff --git a/.gitignore b/.gitignore index 90324058600bee46af56e49028977971848a80de..e1d352c238a1b2d4febe0f5d4a30cfa0c942f7e7 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,7 @@ .DS_Store .ipynb_checkpoints node_modules -/.bazelrc +/.bazelrc.user /.tf_configure.bazelrc /bazel-* /bazel_pip diff --git a/CODEOWNERS b/CODEOWNERS index 54a61a4d72c40d297d90d53e223f64f813d9167d..cb3fa2312405ce44d5dfc30ea4164740f436e07e 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -1,7 +1,7 @@ # Where component owners are known, add them here. 
/tenosrflow/core/debug @caisq -/tensorflow/core/nccl/ @azaks @csigg +/tensorflow/core/nccl/ @azaks2 @chsigg /tensorflow/core/platform/windows/ @mrry /tensorflow/core/platform/s3 @yongtang /tensorflow/go @asimshankar @@ -51,13 +51,13 @@ /tensorflow/contrib/pi_examples/ @maciekcc /tensorflow/contrib/quantization/ @petewarden /tensorflow/contrib/rnn/ @ebrevdo @scottzhu -/tensorflow/contrib/saved_model/ @nfiedel @sukritiramesh @allenl +/tensorflow/contrib/saved_model/ @nfiedel @sukritiramesh @allenlavoie /tensorflow/contrib/seq2seq/ @ebrevdo @lmthang /tensorflow/contrib/session_bundle/ @nfiedel @sukritiramesh /tensorflow/contrib/slim/ @sguada @thenbasilmanran /tensorflow/contrib/stateless/ @girving @alextp /tensorflow/contrib/tensor_forest/ @gilberthendry @thomascolthurst @yupbank -/tensorflow/contrib/tensorrt/ @aaroey +/tensorflow/contrib/tensorrt/ @aaroey @smit-hinsu @azaks2 # NEED OWNER: /tensorflow/contrib/testing/ /tensorflow/contrib/timeseries/ @allenlavoie /tensorflow/contrib/tpu/ @frankchn @saeta @jhseu @sourabhbajaj diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 4a296f265f7b9521c46d350cec26ff199f43eb6c..b978f89f9e1d79dd4f7481711a59c2b94e8bf01b 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -150,41 +150,45 @@ may exist in your changes. There are two ways to run TensorFlow unit tests. -1. Using tools and libraries installed directly on your system. +1. Using tools and libraries installed directly on your system. - Refer to the - [CPU-only developer Dockerfile](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/docker/Dockerfile.devel) and - [GPU developer Dockerfile](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/docker/Dockerfile.devel-gpu) - for the required packages. 
Alternatively, use the said - [Docker images](https://hub.docker.com/r/tensorflow/tensorflow/tags/), e.g., - `tensorflow/tensorflow:nightly-devel` and `tensorflow/tensorflow:nightly-devel-gpu` - for development to avoid installing the packages directly on your system. + Refer to the + [CPU-only developer Dockerfile](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/docker/Dockerfile.devel) + and + [GPU developer Dockerfile](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/docker/Dockerfile.devel-gpu) + for the required packages. Alternatively, use the said + [Docker images](https://hub.docker.com/r/tensorflow/tensorflow/tags/), e.g., + `tensorflow/tensorflow:nightly-devel` and + `tensorflow/tensorflow:nightly-devel-gpu` for development to avoid + installing the packages directly on your system (in which case remember to + change directory from `/root` to `/tensorflow` once you get into the running + container so `bazel` can find the `tensorflow` workspace). - Once you have the packages installed, you can run a specific unit test in - bazel by doing as follows: + Once you have the packages installed, you can run a specific unit test in + bazel by doing as follows: - If the tests are to be run on GPU, add CUDA paths to LD_LIBRARY_PATH and add - the `cuda` option flag + If the tests are to be run on GPU, add CUDA paths to LD_LIBRARY_PATH and add + the `cuda` option flag - ```bash - export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH" + ```bash + export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH" - export flags="--config=opt --config=cuda -k" - ``` + export flags="--config=opt --config=cuda -k" + ``` - For example, to run all tests under tensorflow/python, do: + For example, to run all tests under tensorflow/python, do: - ```bash - bazel test ${flags} //tensorflow/python/... 
- ``` + ```bash + bazel test ${flags} //tensorflow/python/... + ``` -2. Using [Docker](https://www.docker.com) and TensorFlow's CI scripts. +2. Using [Docker](https://www.docker.com) and TensorFlow's CI scripts. - ```bash - # Install Docker first, then this will build and run cpu tests - tensorflow/tools/ci_build/ci_build.sh CPU bazel test //tensorflow/... - ``` - - See - [TensorFlow Builds](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tools/ci_build) for details. + ```bash + # Install Docker first, then this will build and run cpu tests + tensorflow/tools/ci_build/ci_build.sh CPU bazel test //tensorflow/... + ``` + See + [TensorFlow Builds](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tools/ci_build) + for details. diff --git a/README.md b/README.md index 044174947a094d43a51f7140dd40ec0f17801d40..96a8ecf4f693d5634da63f4ecc6f4e9c35751f5b 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,8 @@ organization for the purposes of conducting machine learning and deep neural networks research. The system is general enough to be applicable in a wide variety of other domains, as well. -TensorFlow provides stable Python API and C APIs as well as without API backwards compatibility guarantee like C++, Go, Java, JavaScript and Swift. +TensorFlow provides stable Python and C APIs as well as non-guaranteed backwards +compatible APIs for C++, Go, Java, JavaScript and Swift. Keep up to date with release announcements and security updates by subscribing to @@ -57,21 +58,24 @@ Simply run `pip install tf-nightly` or `pip install tf-nightly-gpu` in a clean environment to install the nightly TensorFlow build. We support CPU and GPU packages on Linux, Mac, and Windows. - #### *Try your first TensorFlow program* + ```shell $ python ``` + ```python >>> import tensorflow as tf >>> tf.enable_eager_execution() ->>> tf.add(1, 2) +>>> tf.add(1, 2).numpy() 3 >>> hello = tf.constant('Hello, TensorFlow!') >>> hello.numpy() 'Hello, TensorFlow!' 
``` -Learn more examples about how to do specific tasks in TensorFlow at the [tutorials page of tensorflow.org](https://www.tensorflow.org/tutorials/). + +Learn more examples about how to do specific tasks in TensorFlow at the +[tutorials page of tensorflow.org](https://www.tensorflow.org/tutorials/). ## Contribution guidelines @@ -113,11 +117,12 @@ The TensorFlow project strives to abide by generally accepted best practices in Build Type | Status | Artifacts ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------- **IBM s390x** | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/badge/icon)](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | TBA -**IBM ppc64le CPU** | [![Build Status](http://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/badge/icon)](http://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/) | TBA -**IBM ppc64le GPU** Nightly | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/) -**IBM ppc64le GPU** Stable Release | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) | [Release](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) +**Linux ppc64le CPU** Nightly | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/) | 
[Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Nightly_Artifact/) +**Linux ppc64le CPU** Stable Release | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) | [Release](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) +**Linux ppc64le GPU** Nightly | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/) +**Linux ppc64le GPU** Stable Release | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) | [Release](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) **Linux CPU with Intel® MKL-DNN** Nightly | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/) | [Nightly](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/) -**Linux CPU with Intel® MKL-DNN** Python 2.7
**Linux CPU with Intel® MKL-DNN** Python 3.4
**Linux CPU with Intel® MKL-DNN** Python 3.5
**Linux CPU with Intel® MKL-DNN** Python 3.6 | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/lastStableBuild) | [1.11.0 py2.7](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.11.0-cp27-cp27mu-linux_x86_64.whl)
[1.11.0 py3.4](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.11.0-cp34-cp34m-linux_x86_64.whl)
[1.11.0 py3.5](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.11.0-cp35-cp35m-linux_x86_64.whl)
[1.11.0 py3.6](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.11.0-cp36-cp36m-linux_x86_64.whl) +**Linux CPU with Intel® MKL-DNN** Python 2.7
**Linux CPU with Intel® MKL-DNN** Python 3.4
**Linux CPU with Intel® MKL-DNN** Python 3.5
**Linux CPU with Intel® MKL-DNN** Python 3.6 | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/lastStableBuild) | [1.12.0 py2.7](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.12.0-cp27-cp27mu-linux_x86_64.whl)
[1.12.0 py3.4](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.12.0-cp34-cp34m-linux_x86_64.whl)
[1.12.0 py3.5](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.12.0-cp35-cp35m-linux_x86_64.whl)
[1.12.0 py3.6](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.12.0-cp36-cp36m-linux_x86_64.whl) ## For more information diff --git a/RELEASE.md b/RELEASE.md index b13b071bd6cf4d3a260c8e248a67d23e1a688498..0a56e6909870e398c9d6349576cd2f8e6734f072 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -7,6 +7,8 @@ Serving. * Keras models now support evaluating with a `tf.data.Dataset`. * TensorFlow binaries are built with XLA support linked in by default. +* Ignite Dataset added to contrib/ignite that allows working with Apache + Ignite. ## Bug Fixes and Other Changes @@ -280,50 +282,76 @@ Ag Ramesh, Alex Wiltschko, Alexander Pantyukhin, Amogh Mannekote, An Jiaoyang, A ## Bug Fixes and Other Changes -* `tfe.Network` is deprecated. Please inherit from `tf.keras.Model`. -* Layered variable names have changed in the following conditions: - * Using `tf.keras.layers` with custom variable scopes. - * Using `tf.layers` in a subclassed `tf.keras.Model` class. See - [here](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/layers) for more details -* `tf.data`: - * `Dataset.from_generator()` now accepts an `args` list, in order to create nested generators. - * `Dataset.list_files()` now produces determinstic results when `shuffle=False` or a `seed` is passed. - * `tf.contrib.data.sample_from_datasets()` and `tf.contrib.data.choose_from_datasets()` make it easier to sample or deterministically choose elements from multiple datasets. - * `tf.contrib.data.make_csv_dataset()` now supports line breaks in quoted strings, and two infrequently used arguments removed. - * (C++) `DatasetBase::DebugString()` is now `const`. - * (C++) `DatasetBase::MakeIterator()` has been renamed to `DatasetBase::MakeIteratorInternal()`. - * (C++) `IteratorBase::Initialize()` method was added to support raising errors during iterator construction. -* Eager Execution: - * Added the ability to pause recording operations for gradient computation via `tf.GradientTape.stop_recording`. 
- * Updated documentation, introductory notebooks. -* `tf.keras`: - * Move Keras code out of _impl folder and remove API files. - * `tf.keras.Model.save_weights` now saves in TensorFlow format by default. - * Enable dataset iterators to be passed to `tf.keras.Model` training/eval methods. -* TensorFlow Debugger (tfdbg) CLI: fix an issue in which the TensorBoard Debugger Plugin could not handle total source file size exceeding gRPC message size limit (4 MB). -* `tf.contrib`: - * `tf.contrib.framework.zero_initializer` supports ResourceVariable. - * Adding "constrained_optimization" to tensorflow/contrib. -* Other: - * Add GCS Configuration Ops. - * Changing signature of `MakeIterator` to enable propagating error status. - * KL divergence for two Dirichlet distributions. - * More consistent GcsFileSystem behavior for certain reads past EOF. - * Update benchmark for tf.scan to match ranges across eager and graph modes. - * Fixed bug in `tf.reduce_prod gradient` for complex dtypes. - * Allow the use of '.' in variables (e.g. "hparams.parse('a.b=1.0')"), which would previously raise an error. This will correspond to an attribute name with an embedded '.' symbol (e.g. 'a.b'), which can only be accessed indirectly (e.g. through getattr and setattr). To set this up the user will first need to explicitly add the variable to the hparam object (e.g. "hparams.add_hparam(name='a.b', value=0.0)"). - * Benchmark for tf.scan in graph and eager modes. - * Added complex128 support to FFT, FFT2D, FFT3D, IFFT, IFFT2D, and IFFT3D. - * Making ids unique in `nn.embedding_lookup_sparse`. This helps to reduce RPC calls for looking up the embeddings when there are repeated ids in the batch. - * Support indicator column in boosted trees. - * Prevent `tf.gradients()` from backpropagating through integer tensors. - * LinearOperator[1D,2D,3D]Circulant added to `tensorflow.linalg`. - * Conv3D, Conv3DBackpropInput, Conv3DBackpropFilter now supports arbitrary. 
- * Added `tf.train.Checkpoint` for reading/writing object-based checkpoints. - * Added LinearOperatorKronecker, a dense-free implementation of the Kronecker Product. - * Allow LinearOperator to broadcast. - * SavedModelBuilder will now deduplicate asset names that point to files with the same basename and the same contents. Note that this may result in new asset files included in SavedModels in cases where assets with the same name but different contents were previously overwriting each other. - +* `tfe.Network` is deprecated. Please inherit from `tf.keras.Model`. +* Layered variable names have changed in the following conditions: + * Using `tf.keras.layers` with custom variable scopes. + * Using `tf.layers` in a subclassed `tf.keras.Model` class. See + [here](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/layers) + for more details +* `tf.data`: + * `Dataset.from_generator()` now accepts an `args` list, in order to + create nested generators. + * `Dataset.list_files()` now produces deterministic results when + `shuffle=False` or a `seed` is passed. + * `tf.contrib.data.sample_from_datasets()` and + `tf.contrib.data.choose_from_datasets()` make it easier to sample or + deterministically choose elements from multiple datasets. + * `tf.contrib.data.make_csv_dataset()` now supports line breaks in quoted + strings, and two infrequently used arguments removed. + * (C++) `DatasetBase::DebugString()` is now `const`. + * (C++) `DatasetBase::MakeIterator()` has been renamed to + `DatasetBase::MakeIteratorInternal()`. + * (C++) `IteratorBase::Initialize()` method was added to support raising + errors during iterator construction. +* Eager Execution: + * Added the ability to pause recording operations for gradient computation + via `tf.GradientTape.stop_recording`. + * Updated documentation, introductory notebooks. +* `tf.keras`: + * Move Keras code out of _impl folder and remove API files. 
+ * `tf.keras.Model.save_weights` now saves in TensorFlow format by default. + * Enable dataset iterators to be passed to `tf.keras.Model` training/eval + methods. +* TensorFlow Debugger (tfdbg) CLI: fix an issue in which the TensorBoard + Debugger Plugin could not handle total source file size exceeding gRPC + message size limit (4 MB). +* `tf.contrib`: + * `tf.contrib.framework.zero_initializer` supports ResourceVariable. + * Adding "constrained_optimization" to tensorflow/contrib. +* Other: + * Add GCS Configuration Ops. + * Changing signature of `MakeIterator` to enable propagating error status. + * KL divergence for two Dirichlet distributions. + * More consistent GcsFileSystem behavior for certain reads past EOF. + * Update benchmark for tf.scan to match ranges across eager and graph + modes. + * Fixed bug in `tf.reduce_prod gradient` for complex dtypes. + * Allow the use of '.' in variables (e.g. "hparams.parse('a.b=1.0')"), + which would previously raise an error. This will correspond to an + attribute name with an embedded '.' symbol (e.g. 'a.b'), which can only + be accessed indirectly (e.g. through getattr and setattr). To set this + up the user will first need to explicitly add the variable to the hparam + object (e.g. "hparams.add_hparam(name='a.b', value=0.0)"). + * Benchmark for tf.scan in graph and eager modes. + * Added complex128 support to FFT, FFT2D, FFT3D, IFFT, IFFT2D, and IFFT3D. + * Making ids unique in `nn.embedding_lookup_sparse`. This helps to reduce + RPC calls for looking up the embeddings when there are repeated ids in + the batch. + * Support indicator column in boosted trees. + * Prevent `tf.gradients()` from backpropagating through integer tensors. + * LinearOperator[1D,2D,3D]Circulant added to `tensorflow.linalg`. + * Conv3D, Conv3DBackpropInput, Conv3DBackpropFilter now supports + arbitrary. + * Added `tf.train.Checkpoint` for reading/writing object-based + checkpoints. 
+ * Added LinearOperatorKronecker, a dense-free implementation of the + Kronecker Product. + * Allow LinearOperator to broadcast. + * SavedModelBuilder will now deduplicate asset names that point to files + with the same basename and the same contents. Note that this may result + in new asset files included in SavedModels in cases where assets with + the same name but different contents were previously overwriting each + other. ## Thanks to our Contributors @@ -821,7 +849,7 @@ answered questions, and were part of inspiring discussions. * Remove `tf.contrib.data.Iterator.from_dataset()` method. Use `Dataset.make_initializable_iterator()` instead. * Remove seldom used and unnecessary `tf.contrib.data.Iterator.dispose_op()`. -* Reorder some TFGAN loss functions in a non-backwards compatible way. +* Reorder some TF-GAN loss functions in a non-backwards compatible way. ## Known Issues * In Python 3, `Dataset.from_generator()` does not support Unicode strings. diff --git a/WORKSPACE b/WORKSPACE index 7cc08e0164a202581ad7ebbe107a9e19410e70e4..957b8d8528dc9b5e2ea134921b28601aa6fed2d1 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -1,14 +1,14 @@ workspace(name = "org_tensorflow") -load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive", "http_file") http_archive( name = "io_bazel_rules_closure", - sha256 = "a38539c5b5c358548e75b44141b4ab637bba7c4dc02b46b1f62a96d6433f56ae", - strip_prefix = "rules_closure-dbb96841cc0a5fb2664c37822803b06dab20c7d1", + sha256 = "43c9b882fa921923bcba764453f4058d102bece35a37c9f6383c713004aacff1", + strip_prefix = "rules_closure-9889e2348259a5aad7e805547c1a0cf311cfcd91", urls = [ - "https://mirror.bazel.build/github.com/bazelbuild/rules_closure/archive/dbb96841cc0a5fb2664c37822803b06dab20c7d1.tar.gz", - "https://github.com/bazelbuild/rules_closure/archive/dbb96841cc0a5fb2664c37822803b06dab20c7d1.tar.gz", # 2018-04-13 + 
"https://mirror.bazel.build/github.com/bazelbuild/rules_closure/archive/9889e2348259a5aad7e805547c1a0cf311cfcd91.tar.gz", + "https://github.com/bazelbuild/rules_closure/archive/9889e2348259a5aad7e805547c1a0cf311cfcd91.tar.gz", # 2018-12-21 ], ) @@ -16,38 +16,64 @@ load("@io_bazel_rules_closure//closure:defs.bzl", "closure_repositories") closure_repositories() -http_archive( - name = "base_images_docker", - sha256 = "e2b1b7254270bb7605e814a9dbf6d1e4ae04a11136ff1714fbfdabe3f87f7cf9", - strip_prefix = "base-images-docker-12801524f867e657fbb5d1a74f31618aff181ac6", - urls = ["https://github.com/GoogleCloudPlatform/base-images-docker/archive/12801524f867e657fbb5d1a74f31618aff181ac6.tar.gz"], -) +load("//third_party/toolchains/preconfig/generate:archives.bzl", + "bazel_toolchains_archive") -http_archive( - name = "bazel_toolchains", - sha256 = "15b5858b1b5541ec44df31b94c3b8672815b31d71215a98398761ea9f4c4eedb", - strip_prefix = "bazel-toolchains-6200b238c9c2d137c0d9a7262c80cc71d98e692b", - urls = [ - "https://github.com/bazelbuild/bazel-toolchains/archive/6200b238c9c2d137c0d9a7262c80cc71d98e692b.tar.gz", - ], +bazel_toolchains_archive() + +load( + "@bazel_toolchains//repositories:repositories.bzl", + bazel_toolchains_repositories = "repositories", ) -http_archive( - name = "io_bazel_rules_docker", - sha256 = "29d109605e0d6f9c892584f07275b8c9260803bf0c6fcb7de2623b2bedc910bd", - strip_prefix = "rules_docker-0.5.1", - urls = ["https://github.com/bazelbuild/rules_docker/archive/v0.5.1.tar.gz"], +bazel_toolchains_repositories() + +load( + "@io_bazel_rules_docker//container:container.bzl", + container_repositories = "repositories", ) -load("//third_party/toolchains/preconfig/generate:workspace.bzl", "remote_config_workspace") +container_repositories() + +load("//third_party/toolchains/preconfig/generate:workspace.bzl", + "remote_config_workspace") remote_config_workspace() +# Apple and Swift rules. 
+http_archive( + name = "build_bazel_rules_apple", + sha256 = "4fe4ee824200b48821730f89ff260984332dc3551db587c24691235d1d96a8a7", + strip_prefix = "rules_apple-0.10.0", + urls = ["https://github.com/bazelbuild/rules_apple/archive/0.10.0.tar.gz"], +) +http_archive( + name = "build_bazel_rules_swift", + sha256 = "6544ff5615febec0342de1127144d2f3e43ea80fb7f9b1ade65e6a184e39e618", + strip_prefix = "rules_swift-0.5.0", + urls = ["https://github.com/bazelbuild/rules_swift/archive/0.5.0.tar.gz"], +) +http_archive( + name = "bazel_skylib", + sha256 = "eb5c57e4c12e68c0c20bc774bfbc60a568e800d025557bc4ea022c6479acc867", + strip_prefix = "bazel-skylib-0.6.0", + urls = ["https://github.com/bazelbuild/bazel-skylib/archive/0.6.0.tar.gz"], +) +http_file( + name = "xctestrunner", + executable = 1, + urls = ["https://github.com/google/xctestrunner/releases/download/0.2.5/ios_test_runner.par"], +) +load("@build_bazel_rules_apple//apple:repositories.bzl", "apple_rules_dependencies") +apple_rules_dependencies(ignore_version_differences = True) +load("@build_bazel_rules_swift//swift:repositories.bzl", "swift_rules_dependencies") +swift_rules_dependencies() + # We must check the bazel version before trying to parse any other BUILD # files, in case the parsing of those build files depends on the bazel # version we require here. 
load("//tensorflow:version_check.bzl", "check_bazel_version_at_least") -check_bazel_version_at_least("0.15.0") +check_bazel_version_at_least("0.19.0") load("//tensorflow:workspace.bzl", "tf_workspace") diff --git a/configure.py b/configure.py index 57a03bd17fac1a3a9942bdacf4661d021a62bbaa..8dcd31822000820df12c7e96f5c57c68ed605f41 100644 --- a/configure.py +++ b/configure.py @@ -33,7 +33,7 @@ except ImportError: from distutils.spawn import find_executable as which # pylint: enable=g-import-not-at-top -_DEFAULT_CUDA_VERSION = '9.0' +_DEFAULT_CUDA_VERSION = '10.0' _DEFAULT_CUDNN_VERSION = '7' _DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,7.0' _DEFAULT_CUDA_PATH = '/usr/local/cuda' @@ -255,18 +255,6 @@ def setup_python(environ_cp): def reset_tf_configure_bazelrc(): """Reset file that contains customized config settings.""" open(_TF_BAZELRC, 'w').close() - bazelrc_path = os.path.join(_TF_WORKSPACE_ROOT, '.bazelrc') - - data = [] - if os.path.exists(bazelrc_path): - with open(bazelrc_path, 'r') as f: - data = f.read().splitlines() - with open(bazelrc_path, 'w') as f: - for l in data: - if _TF_BAZELRC_FILENAME in l: - continue - f.write('%s\n' % l) - f.write('import %%workspace%%/%s\n' % _TF_BAZELRC_FILENAME) def cleanup_makefile(): """Delete any leftover BUILD files from the Makefile build. @@ -488,11 +476,14 @@ def check_bazel_version(min_version, max_version): if curr_version_int < min_version_int: print('Please upgrade your bazel installation to version %s or higher to ' 'build TensorFlow!' % min_version) - sys.exit(0) - if curr_version_int > max_version_int: + sys.exit(1) + if (curr_version_int > max_version_int and + 'TF_IGNORE_MAX_BAZEL_VERSION' not in os.environ): print('Please downgrade your bazel installation to version %s or lower to ' - 'build TensorFlow!' % min_version) - sys.exit(0) + 'build TensorFlow! To downgrade: download the installer for the old ' + 'version (from https://github.com/bazelbuild/bazel/releases) then ' + 'run the installer.' 
% max_version) + sys.exit(1) return curr_version @@ -1491,7 +1482,7 @@ def set_other_mpi_vars(environ_cp): else: raise ValueError( 'Cannot find the MPI library file in %s/lib or %s/lib64 or %s/lib32' % - mpi_home, mpi_home, mpi_home) + (mpi_home, mpi_home, mpi_home)) def set_system_libs_flag(environ_cp): @@ -1565,11 +1556,9 @@ def main(): # environment variables. environ_cp = dict(os.environ) - check_bazel_version('0.15.0', '0.19.2') + check_bazel_version('0.19.0', '0.21.0') reset_tf_configure_bazelrc() - # Explicitly import tools/bazel.rc, this is needed for Bazel 0.19.0 or later - write_to_bazelrc('import %workspace%/tools/bazel.rc') cleanup_makefile() setup_python(environ_cp) @@ -1705,7 +1694,7 @@ def main(): config_info_line('noaws', 'Disable AWS S3 filesystem support.') config_info_line('nogcp', 'Disable GCP support.') config_info_line('nohdfs', 'Disable HDFS support.') - config_info_line('noignite', 'Disable Apacha Ignite support.') + config_info_line('noignite', 'Disable Apache Ignite support.') config_info_line('nokafka', 'Disable Apache Kafka support.') config_info_line('nonccl', 'Disable NVIDIA NCCL support.') diff --git a/tensorflow/BUILD b/tensorflow/BUILD index fd4b94202aad24a82abef8abd16431f61a8326f0..6bc8403d126a58c1eb6499ab7f224e12c6bc5aa4 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -40,12 +40,16 @@ load( # @unused TENSORFLOW_API_INIT_FILES_V2 = ( - TENSORFLOW_API_INIT_FILES + get_compat_files(TENSORFLOW_API_INIT_FILES_V1, 1) + TENSORFLOW_API_INIT_FILES + + get_compat_files(TENSORFLOW_API_INIT_FILES, 2) + + get_compat_files(TENSORFLOW_API_INIT_FILES_V1, 1) ) # @unused -TENSORFLOW_API_INIT_FILES_V1_WITH_COMPAT = ( - TENSORFLOW_API_INIT_FILES_V1 + get_compat_files(TENSORFLOW_API_INIT_FILES_V1, 1) +TENSORFLOW_API_INIT_FILES_V1 = ( + TENSORFLOW_API_INIT_FILES_V1 + + get_compat_files(TENSORFLOW_API_INIT_FILES, 2) + + get_compat_files(TENSORFLOW_API_INIT_FILES_V1, 1) ) # Config setting used when building for products @@ -202,6 +206,12 @@ 
config_setting( visibility = ["//visibility:public"], ) +config_setting( + name = "arm", + values = {"cpu": "arm"}, + visibility = ["//visibility:public"], +) + config_setting( name = "freebsd", values = {"cpu": "freebsd"}, @@ -267,6 +277,15 @@ config_setting( visibility = ["//visibility:public"], ) +# By default, XLA GPU is compiled into tensorflow when building with +# --config=cuda even when `with_xla_support` is false. The config setting +# here allows us to override the behavior if needed. +config_setting( + name = "no_xla_deps_in_cuda", + define_values = {"no_xla_deps_in_cuda": "true"}, + visibility = ["//visibility:public"], +) + config_setting( name = "with_gdr_support", define_values = {"with_gdr_support": "true"}, @@ -328,6 +347,13 @@ config_setting( }, ) +config_setting( + name = "using_rocm_hipcc", + define_values = { + "using_rocm_hipcc": "true", + }, +) + config_setting( name = "with_mpi_support", values = {"define": "with_mpi_support=true"}, @@ -355,17 +381,18 @@ config_setting( define_values = {"tf_api_version": "2"}, ) +# This flag is defined for select statements that match both +# on 'windows' and 'api_version_2'. In this case, bazel requires +# having a flag which is a superset of these two. 
+config_setting( + name = "windows_and_api_version_2", + define_values = {"tf_api_version": "2"}, + values = {"cpu": "x64_windows"}, +) + package_group( name = "internal", - packages = [ - "-//third_party/tensorflow/python/estimator", - "//learning/meta_rank/...", - "//tensorflow/...", - "//tensorflow_estimator/contrib/...", - "//tensorflow_fold/llgtm/...", - "//tensorflow_text/...", - "//third_party/py/tensor2tensor/...", - ], + packages = ["//tensorflow/..."], ) load( @@ -574,13 +601,20 @@ gen_api_init_files( name = "tf_python_api_gen_v1", srcs = [ "api_template_v1.__init__.py", + "compat_template.__init__.py", "compat_template_v1.__init__.py", ], api_version = 1, - compat_api_versions = [1], - compat_init_templates = ["compat_template_v1.__init__.py"], + compat_api_versions = [ + 1, + 2, + ], + compat_init_templates = [ + "compat_template_v1.__init__.py", + "compat_template.__init__.py", + ], output_dir = "_api/v1/", - output_files = TENSORFLOW_API_INIT_FILES_V1_WITH_COMPAT, + output_files = TENSORFLOW_API_INIT_FILES_V1, output_package = "tensorflow._api.v1", root_file_name = "v1.py", root_init_template = "api_template_v1.__init__.py", @@ -590,11 +624,18 @@ gen_api_init_files( name = "tf_python_api_gen_v2", srcs = [ "api_template.__init__.py", + "compat_template.__init__.py", "compat_template_v1.__init__.py", ], api_version = 2, - compat_api_versions = [1], - compat_init_templates = ["compat_template_v1.__init__.py"], + compat_api_versions = [ + 1, + 2, + ], + compat_init_templates = [ + "compat_template_v1.__init__.py", + "compat_template.__init__.py", + ], output_dir = "_api/v2/", output_files = TENSORFLOW_API_INIT_FILES_V2, output_package = "tensorflow._api.v2", @@ -606,9 +647,11 @@ py_library( name = "tensorflow_py", srcs_version = "PY2AND3", visibility = ["//visibility:public"], - deps = [ + deps = select({ + "api_version_2": [], + "//conditions:default": ["//tensorflow/contrib:contrib_py"], + }) + [ ":tensorflow_py_no_contrib", - 
"//tensorflow/contrib:contrib_py", "//tensorflow/python/estimator:estimator_py", ], ) @@ -618,7 +661,11 @@ py_library( srcs = select({ "api_version_2": [":tf_python_api_gen_v2"], "//conditions:default": [":tf_python_api_gen_v1"], - }) + [":root_init_gen"], + }) + [":root_init_gen"] + [ + "//tensorflow/python/keras/api:keras_python_api_gen", + "//tensorflow/python/keras/api:keras_python_api_gen_compat_v1", + "//tensorflow/python/keras/api:keras_python_api_gen_compat_v2", + ], srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = ["//tensorflow/python:no_contrib"], diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py index f13623b0d57d3b59bb9455a46a9fab29fee25784..a93799bfe84b0f9c4743e1ad0effd6e69ad7f3f2 100644 --- a/tensorflow/api_template.__init__.py +++ b/tensorflow/api_template.__init__.py @@ -18,25 +18,77 @@ from __future__ import absolute_import as _absolute_import from __future__ import division as _division from __future__ import print_function as _print_function +import distutils as _distutils +import inspect as _inspect import os as _os +import site as _site +import sys as _sys + +# API IMPORTS PLACEHOLDER # pylint: disable=g-bad-import-order from tensorflow.python.tools import component_api_helper as _component_api_helper _component_api_helper.package_hook( parent_package_str=__name__, - child_package_str=('tensorflow_estimator.python.estimator.api.estimator')) - -# API IMPORTS PLACEHOLDER + child_package_str=( + 'tensorflow_estimator.python.estimator.api._v2.estimator')) +_current_module = _sys.modules[__name__] +if not hasattr(_current_module, 'estimator'): + _component_api_helper.package_hook( + parent_package_str=__name__, + child_package_str=( + 'tensorflow_estimator.python.estimator.api.estimator')) +_component_api_helper.package_hook( + parent_package_str=__name__, + child_package_str=('tensorflow.python.keras.api._v2.keras')) # Make sure directory containing top level submodules is in # the 
__path__ so that "from tensorflow.foo import bar" works. # We're using bitwise, but there's nothing special about that. _tf_api_dir = _os.path.dirname(_os.path.dirname(bitwise.__file__)) # pylint: disable=undefined-variable -if _tf_api_dir not in __path__: +if not hasattr(_current_module, '__path__'): + __path__ = [_tf_api_dir] +elif _tf_api_dir not in __path__: __path__.append(_tf_api_dir) -# Calls to enable and disable features. -enable_eager_execution() # pylint: disable=undefined-variable +# Enable TF2 behaviors +from tensorflow.python.compat import v2_compat as _compat # pylint: disable=g-import-not-at-top +_compat.enable_v2_behavior() + + +# Load all plugin libraries from site-packages/tensorflow-plugins if we are +# running under pip. +# TODO(gunan): Enable setting an environment variable to define arbitrary plugin +# directories. +# TODO(gunan): Find a better location for this code snippet. +from tensorflow.python.framework import load_library as _ll +from tensorflow.python.lib.io import file_io as _fi + +# Get sitepackages directories for the python installation. +_site_packages_dirs = [] +_site_packages_dirs += [_site.USER_SITE] +_site_packages_dirs += [_p for _p in _sys.path if 'site-packages' in _p] +if 'getsitepackages' in dir(_site): + _site_packages_dirs += _site.getsitepackages() + +if 'sysconfig' in dir(_distutils): + _site_packages_dirs += [_distutils.sysconfig.get_python_lib()] + +_site_packages_dirs = list(set(_site_packages_dirs)) + +# Find the location of this exact file. +_current_file_location = _inspect.getfile(_inspect.currentframe()) + +def _running_from_pip_package(): + return any( + _current_file_location.startswith(dir_) for dir_ in _site_packages_dirs) + +if _running_from_pip_package(): + for s in _site_packages_dirs: + # TODO(gunan): Add sanity checks to loaded modules here. 
+ plugin_dir = _os.path.join(s, 'tensorflow-plugins') + if _fi.file_exists(plugin_dir): + _ll.load_library(plugin_dir) # These symbols appear because we import the python package which # in turn imports from tensorflow.core and tensorflow.python. They @@ -58,4 +110,6 @@ try: del compiler except NameError: pass + + # pylint: enable=undefined-variable diff --git a/tensorflow/api_template_v1.__init__.py b/tensorflow/api_template_v1.__init__.py index 65bdb6cb1b5e6fb0656a12b932d767aeacfccd29..eeca8f0d566a6401cb64e4fe3f0ee3c5aeb4ece2 100644 --- a/tensorflow/api_template_v1.__init__.py +++ b/tensorflow/api_template_v1.__init__.py @@ -18,20 +18,42 @@ from __future__ import absolute_import as _absolute_import from __future__ import division as _division from __future__ import print_function as _print_function +import distutils as _distutils +import inspect as _inspect import os as _os +import site as _site +import sys as _sys # pylint: disable=g-bad-import-order from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import +# API IMPORTS PLACEHOLDER + from tensorflow.python.tools import component_api_helper as _component_api_helper _component_api_helper.package_hook( parent_package_str=__name__, - child_package_str=('tensorflow_estimator.python.estimator.api.estimator')) - -# API IMPORTS PLACEHOLDER + child_package_str=( + 'tensorflow_estimator.python.estimator.api._v1.estimator')) +_current_module = _sys.modules[__name__] +if not hasattr(_current_module, 'estimator'): + _component_api_helper.package_hook( + parent_package_str=__name__, + child_package_str=( + 'tensorflow_estimator.python.estimator.api.estimator')) +_component_api_helper.package_hook( + parent_package_str=__name__, + child_package_str=('tensorflow.python.keras.api._v1.keras')) from tensorflow.python.util.lazy_loader import LazyLoader # pylint: disable=g-import-not-at-top -contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib') +_CONTRIB_WARNING = """ +WARNING: The TensorFlow 
contrib module will not be included in TensorFlow 2.0. +For more information, please see: + * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md + * https://github.com/tensorflow/addons +If you depend on functionality not listed there, please file an issue. +""" +contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib', + _CONTRIB_WARNING) del LazyLoader # The templated code that replaces the placeholder above sometimes # sets the __all__ variable. If it does, we have to be sure to add @@ -40,14 +62,53 @@ if '__all__' in vars(): vars()['__all__'].append('contrib') from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top +# The 'app' module will be imported as part of the placeholder section above. app.flags = flags # pylint: disable=undefined-variable +# Also use 'app' module (choice is arbitrary) to derive the API directory below. +_API_MODULE = app # pylint: disable=undefined-variable + # Make sure directory containing top level submodules is in # the __path__ so that "from tensorflow.foo import bar" works. -_tf_api_dir = _os.path.dirname(_os.path.dirname(app.__file__)) # pylint: disable=undefined-variable -if _tf_api_dir not in __path__: +_tf_api_dir = _os.path.dirname(_os.path.dirname(_API_MODULE.__file__)) # pylint: disable=undefined-variable +if not hasattr(_current_module, '__path__'): + __path__ = [_tf_api_dir] +elif _tf_api_dir not in __path__: __path__.append(_tf_api_dir) +# Load all plugin libraries from site-packages/tensorflow-plugins if we are +# running under pip. +# TODO(gunan): Enable setting an environment variable to define arbitrary plugin +# directories. +# TODO(gunan): Find a better location for this code snippet. +from tensorflow.python.framework import load_library as _ll +from tensorflow.python.lib.io import file_io as _fi + +# Get sitepackages directories for the python installation. 
+_site_packages_dirs = [] +_site_packages_dirs += [_site.USER_SITE] +_site_packages_dirs += [_p for _p in _sys.path if 'site-packages' in _p] +if 'getsitepackages' in dir(_site): + _site_packages_dirs += _site.getsitepackages() + +if 'sysconfig' in dir(_distutils): + _site_packages_dirs += [_distutils.sysconfig.get_python_lib()] + +_site_packages_dirs = list(set(_site_packages_dirs)) + +# Find the location of this exact file. +_current_file_location = _inspect.getfile(_inspect.currentframe()) + +def _running_from_pip_package(): + return any( + _current_file_location.startswith(dir_) for dir_ in _site_packages_dirs) + +if _running_from_pip_package(): + for s in _site_packages_dirs: + # TODO(gunan): Add sanity checks to loaded modules here. + plugin_dir = _os.path.join(s, 'tensorflow-plugins') + if _fi.file_exists(plugin_dir): + _ll.load_library(plugin_dir) # These symbols appear because we import the python package which # in turn imports from tensorflow.core and tensorflow.python. They diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index f653e581bf3beda9fdbf8fb7905a4f9fe170e7fb..ef52a28460062b57317b4027ab83479e5e075b5f 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -83,7 +83,7 @@ tf_cuda_library( ], "//conditions:default": [ ":c_api_internal", - "//tensorflow/cc/saved_model:loader", + "//tensorflow/cc/saved_model:loader_lite", "//tensorflow/cc:gradients", "//tensorflow/cc:ops", "//tensorflow/cc:grad_ops", @@ -123,13 +123,13 @@ tf_cuda_library( "//tensorflow/c/eager:c_api", "//tensorflow/c/eager:c_api_internal", "//tensorflow/compiler/jit:flags", - "//tensorflow/contrib/tpu:all_ops", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_platform", "//tensorflow/core:protos_all_cc", "//tensorflow/core/common_runtime/eager:attr_builder", + "@com_google_absl//absl/strings", ], ) @@ -175,6 +175,32 @@ tf_cuda_library( ], ) +tf_cuda_library( + name = "env", + srcs = [ + "env.cc", + ], + hdrs = [ + 
"env.h", + ], + copts = tf_copts(), + visibility = ["//visibility:public"], + deps = select({ + "//tensorflow:android": [ + ":c_api", + ":tf_status_helper", + "//tensorflow/core:android_tensorflow_lib_lite", + "//tensorflow/core:lib", + ], + "//conditions:default": [ + ":c_api", + ":tf_status_helper", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ], + }) + [":c_api_internal"], +) + tf_cuda_library( name = "kernels", srcs = [ @@ -188,10 +214,14 @@ tf_cuda_library( deps = select({ "//tensorflow:android": [ ":c_api", + ":c_api_internal", + ":tf_status_helper", "//tensorflow/core:android_tensorflow_lib_lite", ], "//conditions:default": [ ":c_api", + ":c_api_internal", + ":tf_status_helper", "//tensorflow/core:framework", ], }), @@ -219,6 +249,18 @@ tf_cuda_library( ], ) +tf_cc_test( + name = "c_test", + srcs = ["c_test.c"], + extra_copts = ["-std=c11"], + deps = [ + ":c_api", + ":c_api_experimental", + ":env", + ":kernels", + ], +) + tf_cuda_cc_test( name = "c_api_test", size = "small", @@ -247,13 +289,23 @@ tf_cuda_cc_test( "//tensorflow/cc/saved_model:signature_constants", "//tensorflow/cc/saved_model:tag_constants", "//tensorflow/compiler/jit", + "//tensorflow/core:array_ops_op_lib", + "//tensorflow/core:bitwise_ops_op_lib", + "//tensorflow/core:control_flow_ops_op_lib", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:direct_session", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", + "//tensorflow/core:functional_ops_op_lib", "//tensorflow/core:lib", + "//tensorflow/core:math_ops_op_lib", + "//tensorflow/core:nn_ops_op_lib", + "//tensorflow/core:no_op_op_lib", "//tensorflow/core:proto_text", "//tensorflow/core:protos_all_cc", + "//tensorflow/core:sendrecv_ops_op_lib", + "//tensorflow/core:spectral_ops_op_lib", + "//tensorflow/core:state_ops_op_lib", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core/kernels:array", @@ -277,6 +329,7 @@ tf_cc_test( deps = [ ":c_api", ":c_api_experimental", + 
":c_api_internal", ":c_test_util", "//tensorflow/c/eager:c_api", "//tensorflow/c/eager:c_api_test_util", @@ -330,6 +383,27 @@ tf_kernel_library( alwayslink = 1, ) +tf_cuda_cc_test( + name = "env_test", + size = "small", + srcs = ["env_test.cc"], + linkopts = select({ + "//tensorflow:darwin": ["-headerpad_max_install_names"], + "//conditions:default": [], + }), + tags = ["noasan"], + # We must ensure that the dependencies can be dynamically linked since + # the shared library must be able to use core:framework. + # linkstatic = tf_kernel_tests_linkstatic(), + deps = [ + ":c_api", + ":env", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + tf_cuda_cc_test( name = "kernels_test", size = "small", diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index f13e8777dff164bcd8eedf46310ae846abd0c804..94d9f4a6fa2f14cb3343bdd51b7e4d61944444d0 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -136,16 +136,22 @@ const char* TF_Message(const TF_Status* s) { namespace { class TF_ManagedBuffer : public TensorBuffer { public: - void* data_; - size_t len_; - void (*deallocator_)(void* data, size_t len, void* arg); - void* deallocator_arg_; + TF_ManagedBuffer(void* data, size_t len, + void (*deallocator)(void* data, size_t len, void* arg), + void* deallocator_arg) + : TensorBuffer(data), + len_(len), + deallocator_(deallocator), + deallocator_arg_(deallocator_arg) {} + + const size_t len_; + void (*const deallocator_)(void* data, size_t len, void* arg); + void* const deallocator_arg_; ~TF_ManagedBuffer() override { - (*deallocator_)(data_, len_, deallocator_arg_); + (*deallocator_)(data(), len_, deallocator_arg_); } - void* data() const override { return data_; } size_t size() const override { return len_; } TensorBuffer* root_buffer() override { return this; } void FillAllocationDescription(AllocationDescription* proto) const override { @@ -199,8 +205,7 @@ TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* 
dims, int num_dims, dimvec[i] = static_cast(dims[i]); } - TF_ManagedBuffer* buf = new TF_ManagedBuffer; - buf->len_ = len; + TF_ManagedBuffer* buf = nullptr; if (dtype != TF_STRING && dtype != TF_RESOURCE && tensorflow::DataTypeCanUseMemcpy(static_cast(dtype)) && reinterpret_cast(data) % std::max(1, EIGEN_MAX_ALIGN_BYTES) != @@ -212,17 +217,15 @@ TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims, // // Other types have the same representation, so copy only if it is safe to // do so. - buf->data_ = allocate_tensor("TF_NewTensor", len); - std::memcpy(buf->data_, data, len); - buf->deallocator_ = deallocate_buffer; - buf->deallocator_arg_ = nullptr; + buf = new TF_ManagedBuffer(allocate_tensor("TF_NewTensor", len), len, + deallocate_buffer, nullptr); + std::memcpy(buf->data(), data, len); // Free the original buffer. deallocator(data, len, deallocator_arg); } else { - buf->data_ = data; - buf->deallocator_ = deallocator; - buf->deallocator_arg_ = deallocator_arg; + buf = new TF_ManagedBuffer(data, len, deallocator, deallocator_arg); } + TF_Tensor* ret = new TF_Tensor{dtype, TensorShape(dimvec), buf}; size_t elem_size = TF_DataTypeSize(dtype); if (elem_size > 0 && len < (elem_size * ret->shape.num_elements())) { @@ -254,6 +257,74 @@ int64_t TF_Dim(const TF_Tensor* t, int dim_index) { size_t TF_TensorByteSize(const TF_Tensor* t) { return t->buffer->size(); } void* TF_TensorData(const TF_Tensor* t) { return t->buffer->data(); } +int64_t TF_TensorElementCount(const TF_Tensor* t) { + int64_t result = 1; + int rank = TF_NumDims(t); + for (int dim = 0; dim < rank; ++dim) { + result *= TF_Dim(t, dim); + } + return result; +} + +// Returns the number of elements that would be present in a tensor with the +// given shape. 
+static int64_t ShapeNumElements(const int64_t* dims, int num_dims) { + int64_t result = 1; + for (int dim = 0; dim < num_dims; ++dim) { + result *= dims[dim]; + } + return result; +} + +static void UnrefIfNonNull(::tensorflow::TensorBuffer* buf) { + if (buf != nullptr) { + buf->Unref(); + } +} + +static void RefIfNonNull(::tensorflow::TensorBuffer* buf) { + if (buf != nullptr) { + buf->Ref(); + } +} + +void TF_TensorBitcastFrom(const TF_Tensor* from, TF_DataType type, + TF_Tensor* to, const int64_t* new_dims, + int num_new_dims, TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + size_t in_size = TF_DataTypeSize(TF_TensorType(from)); + if (in_size == 0) { + TF_SetStatus(status, TF_INVALID_ARGUMENT, + "input tensor has a zero-sized data type"); + return; + } + size_t out_size = TF_DataTypeSize(type); + if (out_size == 0) { + TF_SetStatus(status, TF_INVALID_ARGUMENT, + "output tensor has a zero-sized data type"); + return; + } + + if (ShapeNumElements(new_dims, num_new_dims) * out_size != + TF_TensorElementCount(from) * in_size) { + TF_SetStatus(status, TF_INVALID_ARGUMENT, + "input tensor is not compatible with output shape"); + return; + } + + tensorflow::TensorShapeProto p; + for (int i = 0; i < num_new_dims; ++i) { + p.add_dim()->set_size(new_dims[i]); + } + to->shape = tensorflow::TensorShape(p); + to->dtype = type; + if (to->buffer != from->buffer) { + UnrefIfNonNull(to->buffer); + to->buffer = from->buffer; + RefIfNonNull(to->buffer); + } +} + // -------------------------------------------------------------------------- size_t TF_StringEncode(const char* src, size_t src_len, char* dst, size_t dst_len, TF_Status* status) { @@ -477,14 +548,15 @@ static TF_Tensor* EmptyTensor(TF_DataType dtype, const TensorShape& shape) { CHECK_EQ(nelems, 0); static_assert(sizeof(int64_t) == sizeof(tensorflow::int64), "64-bit int types should match in size"); - return TF_NewTensor(dtype, reinterpret_cast(dims.data()), - shape.dims(), reinterpret_cast(&empty), 0, - 
[](void*, size_t, void*) {}, nullptr); + return TF_NewTensor( + dtype, reinterpret_cast(dims.data()), shape.dims(), + reinterpret_cast(&empty), 0, [](void*, size_t, void*) {}, nullptr); } // Non-static for testing. TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); if (!src.IsInitialized()) { status->status = FailedPrecondition( "attempt to use a tensor with an uninitialized value"); @@ -1592,18 +1664,20 @@ TF_AttrMetadata TF_OperationGetAttrMetadata(TF_Operation* oper, break; \ } - LIST_CASE(s, TF_ATTR_STRING, metadata.total_size = 0; - for (int i = 0; i < attr->list().s_size(); - ++i) { metadata.total_size += attr->list().s(i).size(); }); + LIST_CASE( + s, TF_ATTR_STRING, metadata.total_size = 0; + for (int i = 0; i < attr->list().s_size(); + ++i) { metadata.total_size += attr->list().s(i).size(); }); LIST_CASE(i, TF_ATTR_INT); LIST_CASE(f, TF_ATTR_FLOAT); LIST_CASE(b, TF_ATTR_BOOL); LIST_CASE(type, TF_ATTR_TYPE); - LIST_CASE(shape, TF_ATTR_SHAPE, metadata.total_size = 0; - for (int i = 0; i < attr->list().shape_size(); ++i) { - const auto& s = attr->list().shape(i); - metadata.total_size += s.unknown_rank() ? 0 : s.dim_size(); - }); + LIST_CASE( + shape, TF_ATTR_SHAPE, metadata.total_size = 0; + for (int i = 0; i < attr->list().shape_size(); ++i) { + const auto& s = attr->list().shape(i); + metadata.total_size += s.unknown_rank() ? 
0 : s.dim_size(); + }); LIST_CASE(tensor, TF_ATTR_TENSOR); LIST_CASE(tensor, TF_ATTR_FUNC); #undef LIST_CASE @@ -2875,6 +2949,9 @@ const char* TF_ServerTarget(TF_Server* server) { #endif } -void TF_DeleteServer(TF_Server* server) { delete server; } - +void TF_DeleteServer(TF_Server* server) { +#ifndef __ANDROID__ + delete server; +#endif +} } // end extern "C" diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index 3d56268110edbe96616201d15a69cc8c84d3115a..8031928dac4de2391f0aec46e69d61a137606e4d 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -91,7 +91,7 @@ extern "C" { // -------------------------------------------------------------------------- // TF_Version returns a string describing version information of the // TensorFlow library. TensorFlow using semantic versioning. -TF_CAPI_EXPORT extern const char* TF_Version(); +TF_CAPI_EXPORT extern const char* TF_Version(void); // -------------------------------------------------------------------------- // TF_DataType holds the type for a scalar value. E.g., one slot in a tensor. @@ -157,7 +157,7 @@ typedef enum TF_Code { typedef struct TF_Status TF_Status; // Return a new status object. -TF_CAPI_EXPORT extern TF_Status* TF_NewStatus(); +TF_CAPI_EXPORT extern TF_Status* TF_NewStatus(void); // Delete a previously created status object. TF_CAPI_EXPORT extern void TF_DeleteStatus(TF_Status*); @@ -196,7 +196,7 @@ TF_CAPI_EXPORT extern TF_Buffer* TF_NewBufferFromString(const void* proto, size_t proto_len); // Useful for passing *out* a protobuf. -TF_CAPI_EXPORT extern TF_Buffer* TF_NewBuffer(); +TF_CAPI_EXPORT extern TF_Buffer* TF_NewBuffer(void); TF_CAPI_EXPORT extern void TF_DeleteBuffer(TF_Buffer*); @@ -272,6 +272,39 @@ TF_CAPI_EXPORT extern size_t TF_TensorByteSize(const TF_Tensor*); // Return a pointer to the underlying data buffer. TF_CAPI_EXPORT extern void* TF_TensorData(const TF_Tensor*); +// Returns the number of elements in the tensor. 
+TF_CAPI_EXPORT extern int64_t TF_TensorElementCount(const TF_Tensor* tensor); + +// Copy the internal data representation of `from` to `to`. `new_dims` and +// `num_new_dims` specify the new shape of the `to` tensor, `type` specifies its +// data type. On success, *status is set to TF_OK and the two tensors share the +// same data buffer. +// +// This call requires that the `from` tensor and the given type and shape (dims +// and num_dims) are "compatible" (i.e. they occupy the same number of bytes). +// Specifically, given from_type_size = TF_DataTypeSize(TF_TensorType(from)): +// +// ShapeElementCount(dims, num_dims) * TF_DataTypeSize(type) +// +// must equal +// +// TF_TensorElementCount(from) * from_type_size +// +// where TF_ShapeElementCount would be the number of elements in a tensor with +// the given shape. +// +// In addition, this function requires: +// * TF_DataTypeSize(TF_TensorType(from)) != 0 +// * TF_DataTypeSize(type) != 0 +// +// If any of the requirements are not met, *status is set to +// TF_INVALID_ARGUMENT. +TF_CAPI_EXPORT extern void TF_TensorBitcastFrom(const TF_Tensor* from, + TF_DataType type, TF_Tensor* to, + const int64_t* new_dims, + int num_new_dims, + TF_Status* status); + // -------------------------------------------------------------------------- // Encode the string `src` (`src_len` bytes long) into `dst` in the format // required by TF_STRING tensors. Does not write to memory more than `dst_len` @@ -305,7 +338,7 @@ TF_CAPI_EXPORT extern size_t TF_StringEncodedSize(size_t len); typedef struct TF_SessionOptions TF_SessionOptions; // Return a new options object. -TF_CAPI_EXPORT extern TF_SessionOptions* TF_NewSessionOptions(); +TF_CAPI_EXPORT extern TF_SessionOptions* TF_NewSessionOptions(void); // Set the target in TF_SessionOptions.options. // target can be empty, a single entry, or a comma separated list of entries. 
@@ -338,7 +371,7 @@ TF_CAPI_EXPORT extern void TF_DeleteSessionOptions(TF_SessionOptions*); typedef struct TF_Graph TF_Graph; // Return a new graph object. -TF_CAPI_EXPORT extern TF_Graph* TF_NewGraph(); +TF_CAPI_EXPORT extern TF_Graph* TF_NewGraph(void); // Destroy an options object. Graph will be deleted once no more // TFSession's are referencing it. @@ -890,7 +923,8 @@ TF_CAPI_EXPORT extern void TF_GraphVersions(TF_Graph* graph, // TF_GraphImportGraphDef. typedef struct TF_ImportGraphDefOptions TF_ImportGraphDefOptions; -TF_CAPI_EXPORT extern TF_ImportGraphDefOptions* TF_NewImportGraphDefOptions(); +TF_CAPI_EXPORT extern TF_ImportGraphDefOptions* TF_NewImportGraphDefOptions( + void); TF_CAPI_EXPORT extern void TF_DeleteImportGraphDefOptions( TF_ImportGraphDefOptions* opts); @@ -1611,7 +1645,7 @@ TF_CAPI_EXPORT extern void TF_DeleteLibraryHandle(TF_Library* lib_handle); // // The data in the buffer will be the serialized OpList proto for ops registered // in this address space. -TF_CAPI_EXPORT extern TF_Buffer* TF_GetAllOpList(); +TF_CAPI_EXPORT extern TF_Buffer* TF_GetAllOpList(void); // TF_ApiDefMap encapsulates a collection of API definitions for an operation. // diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc index 3693cc85996365360253c8a94c29272a16e11e9a..a8325ce494c4f57fcd7e64b2d233ee4e6666bc4e 100644 --- a/tensorflow/c/c_api_experimental.cc +++ b/tensorflow/c/c_api_experimental.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/c/c_api_experimental.h" +#include "absl/strings/substitute.h" #include "tensorflow/c/c_api.h" #include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/eager/c_api.h" @@ -66,7 +67,8 @@ void TF_EnableXLACompilation(TF_SessionOptions* options, unsigned char enable) { } TF_Buffer* TF_CreateConfig(unsigned char enable_xla_compilation, - unsigned char gpu_memory_allow_growth) { + unsigned char gpu_memory_allow_growth, + unsigned int num_cpu_devices) { tensorflow::ConfigProto config; auto* optimizer_options = config.mutable_graph_options()->mutable_optimizer_options(); @@ -87,6 +89,8 @@ TF_Buffer* TF_CreateConfig(unsigned char enable_xla_compilation, auto* gpu_options = config.mutable_gpu_options(); gpu_options->set_allow_growth(gpu_memory_allow_growth); + (*config.mutable_device_count())["CPU"] = num_cpu_devices; + // TODO(b/113217601): This is needed for EagerContext::runner_ to use a // threadpool, so that we avoid the possibility of running the runner_ in the // threadpool of GPU event mgr, as that can trigger more callbacks to be @@ -125,6 +129,14 @@ const char* TF_GraphDebugString(TF_Graph* graph, size_t* len) { return ret; } +char* TF_FunctionDebugString(TF_Function* func, size_t* len) { + const auto& debug_str = func->fdef.DebugString(); + *len = debug_str.size(); + char* ret = static_cast(malloc(*len + 1)); + memcpy(ret, debug_str.c_str(), *len + 1); + return ret; +} + // On success, returns a set of TF_Function instances from `text_proto` of // GraphDef type. These functions must be deleted by calling TF_DeleteFunction. 
// @@ -6530,7 +6542,7 @@ library { } } node_def { - name: "ParallelInterleaveDataset/cycle_length" + name: "ExperimentalParallelInterleaveDataset/cycle_length" op: "Const" attr { key: "dtype" @@ -6551,7 +6563,7 @@ library { } } node_def { - name: "ParallelInterleaveDataset/block_length" + name: "ExperimentalParallelInterleaveDataset/block_length" op: "Const" attr { key: "dtype" @@ -6572,7 +6584,7 @@ library { } } node_def { - name: "ParallelInterleaveDataset/sloppy" + name: "ExperimentalParallelInterleaveDataset/sloppy" op: "Const" attr { key: "dtype" @@ -6593,7 +6605,7 @@ library { } } node_def { - name: "ParallelInterleaveDataset/buffer_output_elements" + name: "ExperimentalParallelInterleaveDataset/buffer_output_elements" op: "Const" attr { key: "dtype" @@ -6614,7 +6626,7 @@ library { } } node_def { - name: "ParallelInterleaveDataset/prefetch_input_elements" + name: "ExperimentalParallelInterleaveDataset/prefetch_input_elements" op: "Const" attr { key: "dtype" @@ -6635,14 +6647,14 @@ library { } } node_def { - name: "ParallelInterleaveDataset" - op: "ParallelInterleaveDataset" + name: "ExperimentalParallelInterleaveDataset" + op: "ExperimentalParallelInterleaveDataset" input: "RepeatDataset:handle:0" - input: "ParallelInterleaveDataset/cycle_length:output:0" - input: "ParallelInterleaveDataset/block_length:output:0" - input: "ParallelInterleaveDataset/sloppy:output:0" - input: "ParallelInterleaveDataset/buffer_output_elements:output:0" - input: "ParallelInterleaveDataset/prefetch_input_elements:output:0" + input: "ExperimentalParallelInterleaveDataset/cycle_length:output:0" + input: "ExperimentalParallelInterleaveDataset/block_length:output:0" + input: "ExperimentalParallelInterleaveDataset/sloppy:output:0" + input: "ExperimentalParallelInterleaveDataset/buffer_output_elements:output:0" + input: "ExperimentalParallelInterleaveDataset/prefetch_input_elements:output:0" attr { key: "Targuments" value { @@ -6742,7 +6754,7 @@ library { node_def { name: 
"ShuffleDataset_2" op: "ShuffleDataset" - input: "ParallelInterleaveDataset:handle:0" + input: "ExperimentalParallelInterleaveDataset:handle:0" input: "ShuffleDataset_2/buffer_size_1:output:0" input: "ShuffleDataset_2/seed_2:output:0" input: "ShuffleDataset_2/seed2_2:output:0" @@ -8535,8 +8547,9 @@ TFE_Context* TFE_CreateContextFromSession(TF_Session* session, // Reduce GPU memory allocation, and set appropriate config options for TFE // context. - auto* config = - TF_CreateConfig(/*xla*/ false, /* gpu_memory_allow_growth */ true); + auto* config = TF_CreateConfig( + /*xla*/ false, /* gpu_memory_allow_growth */ true, /* num_cpu_devices */ + 10); TFE_ContextOptionsSetConfig(opts, config->data, config->length, status); if (!status->status.ok()) { CHECK(!config); @@ -8733,6 +8746,12 @@ static void CheckOk(TF_Status* status) { void TFE_TensorHandlePrintDebugString(TFE_TensorHandle* handle) { auto* status = TF_NewStatus(); + if (!TFE_TensorHandleIsConcrete(handle)) { + VLOG(1) << "Symbolic tensor: " << handle; + TF_DeleteStatus(status); + return; + } + TF_Tensor* t = TFE_TensorHandleResolve(handle, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); @@ -8744,6 +8763,11 @@ void TFE_TensorHandlePrintDebugString(TFE_TensorHandle* handle) { TF_DeleteStatus(status); } +void TFE_OpPrintDebugString(TFE_Op* op) { + VLOG(1) << "TFE_OpPrintDebugString() over " << op; + LOG(INFO) << op->operation.DebugString(); +} + struct TFE_ExecuteOpNotification { TFE_ExecuteOpNotification() : status(TF_NewStatus(), TF_DeleteStatus) {} tensorflow::Notification n; @@ -8886,3 +8910,217 @@ TFE_TensorHandle* TFE_NewTensorHandleFromScalar(TF_DataType dtype_arg, std::memcpy(tensorflow::TensorCApi::Buffer(tensor)->data(), data, len); return new TFE_TensorHandle(tensor, nullptr, nullptr); } + +namespace { +tensorflow::Status EnableCollectiveOps(const tensorflow::ServerDef& server_def, + TFE_Context* ctx) { + // We don't use the TF_RETURN_IF_ERROR macro directly since that destroys the + 
// server object (which currently CHECK-fails) and we miss the error, instead, + // we log the error, and then return to allow the user to see the error + // message. +#define LOG_AND_RETURN_IF_ERROR(...) \ + do { \ + const ::tensorflow::Status _status = (__VA_ARGS__); \ + if (TF_PREDICT_FALSE(!_status.ok())) { \ + LOG(ERROR) << _status.error_message(); \ + return _status; \ + } \ + } while (0); + + std::unique_ptr server; + LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &server)); + + tensorflow::GrpcServer* grpc_server = + dynamic_cast(server.get()); + if (grpc_server == nullptr) { + LOG_AND_RETURN_IF_ERROR(tensorflow::errors::Internal( + "Currently, TFE_NewContext only supports tensorflow::GrpcServer.")); + } + + LOG_AND_RETURN_IF_ERROR(grpc_server->Start()); + + LOG_AND_RETURN_IF_ERROR(ctx->context.StoreCollectiveOpsServer( + std::move(server), grpc_server->worker_env()->device_mgr, + grpc_server->worker_env()->collective_executor_mgr)); + + return tensorflow::Status::OK(); +#undef LOG_AND_RETURN_IF_ERROR +} +} // namespace + +// Set server_def on the context, possibly updating it. 
+TF_CAPI_EXPORT extern void TFE_EnableCollectiveOps(TFE_Context* ctx, + const void* proto, + size_t proto_len, + TF_Status* status) { + tensorflow::ServerDef server_def; + if (!server_def.ParseFromArray(proto, proto_len)) { + status->status = tensorflow::errors::InvalidArgument( + "Invalid tensorflow.ServerDef protocol buffer"); + return; + } + status->status = EnableCollectiveOps(server_def, ctx); +} + +std::string tensorflow::getTF_OutputDebugString(TF_Output node) { + return absl::Substitute("TF_Output($0, $1)", node.oper, node.index); +} + +using tensorflow::getTF_OutputDebugString; + +TFE_TensorHandle* TFE_NewTensorHandleFromTFOutput(TF_Output t, + TF_DataType dtype) { + auto ret = new TFE_TensorHandle(t, dtype); + VLOG(1) << "Storing TFOutput " << getTF_OutputDebugString(t) + << " into tensor handle " << ret << " with internal handle " + << ret->handle; + return ret; +} + +unsigned char TFE_TensorHandleIsConcrete(TFE_TensorHandle* handle) { + assert(handle->handle != nullptr); + return handle->handle->getSymbolicTensor() == nullptr; +} + +TF_Output TFE_GetTFOutputFromTensorHandle(TFE_TensorHandle* handle, + TF_Status* status) { + if (TFE_TensorHandleIsConcrete(handle)) { + status->status = + tensorflow::errors::Internal("Not a symbolic tensor: ", handle); + return TF_Output{nullptr, -1}; + } + + auto* sym_tensor = handle->handle->getSymbolicTensor(); + CHECK(sym_tensor != nullptr); + auto ret = TF_Output{sym_tensor->oper, sym_tensor->index}; + VLOG(1) << "Retrieving " << getTF_OutputDebugString(ret) + << " from tensor handle " << handle; + CHECK_GE(sym_tensor->index, 0); + return ret; +} + +TFE_TraceContext* TFE_NewTraceContext(TF_Graph* graph) { + return new TFE_TraceContext(graph); +} + +void TFE_DeleteTraceContext(TFE_TraceContext* trace_ctx) { delete trace_ctx; } + +// If `handle` is already symbolic, return it. Otherwise map it to a new +// symbolic tensor (a PlaceHolder op) and return that. 
+static TF_Output getOrCreateSymbolicTensor(TFE_TraceContext* trace_ctx, + tensorflow::TensorHandle* handle, + TF_Status* status) { + VLOG(1) << "Getting symbolic tensor for input tensor handle " << handle + << ": " << handle->DebugString(); + + auto* sym_tensor = handle->getSymbolicTensor(); + if (sym_tensor != nullptr) { + auto ret = TF_Output{sym_tensor->oper, sym_tensor->index}; + VLOG(1) << "This handle is a symbolic tensor " << sym_tensor << ": " + << getTF_OutputDebugString(ret); + return ret; + } + + auto find_it = trace_ctx->input_tensor_map.find(handle); + if (find_it != trace_ctx->input_tensor_map.end()) { + VLOG(1) << "There exists a map entry from this concrete tensor to: " + << getTF_OutputDebugString(find_it->second); + return find_it->second; + } + + auto node_name = tensorflow::strings::StrCat("additional_input_", + trace_ctx->node_counter++); + VLOG(1) << "Adding a place holder node named " << node_name; + auto* desc = + TF_NewOperation(trace_ctx->graph, "Placeholder", node_name.c_str()); + TF_SetAttrType(desc, "dtype", + static_cast(handle->dtype) /*TF_FLOAT*/); + auto* result = TF_FinishOperation(desc, status); + if (!status->status.ok()) { + return TF_Output{nullptr, -1}; + } + + auto ret = TF_Output{result, 0}; + VLOG(1) << "Creating a new map entry to map to: " + << getTF_OutputDebugString(ret); + trace_ctx->input_tensor_map[handle] = ret; + // `handle` could be destroyed before it's read from `input_tensor_map` (say + // during a subsequent TFE_FinalizeInputTensorsFromTraceContext() call), so we + // increment its ref count to extend its life span to that of `trace_ctx`. 
+ handle->Ref(); + VLOG(1) << "Ref count for handle " << handle + << " is 1?: " << handle->RefCountIsOne(); + return ret; +} + +TF_Operation* TFE_AddEagerOpToGraph(TFE_Op* op, TFE_TraceContext* trace_ctx, + TFE_TensorHandle** retvals, + int* num_retvals, TF_Status* status) { + VLOG(1) << "Calling TFE_AddEagerOpToGraph() with op " << op << ": " + << op->operation.DebugString(); + + const auto& op_type = op->operation.Name(); + auto op_name = + tensorflow::strings::StrCat(op_type, "_", trace_ctx->node_counter++); + auto* desc = + TF_NewOperation(trace_ctx->graph, op_type.c_str(), op_name.c_str()); + for (auto* input : op->operation.Inputs()) { + auto symbolic_input = getOrCreateSymbolicTensor(trace_ctx, input, status); + if (!status->status.ok()) return nullptr; + TF_AddInput(desc, symbolic_input); + } + + VLOG(1) << "Adding attrs."; + tensorflow::AttrValueMap attrs; + op->operation.Attrs().FillAttrValueMap(&attrs); + for (const auto& attr : attrs) { + desc->node_builder.Attr(attr.first, attr.second); + } + + auto* graph_op = TF_FinishOperation(desc, status); + if (!status->status.ok()) return nullptr; + + VLOG(1) << "Op finalized; setting return tensors."; + *num_retvals = TF_OperationNumOutputs(graph_op); + VLOG(1) << "This op has " << *num_retvals << " outputs."; + for (int i = 0; i < *num_retvals; ++i) { + auto output = TF_Output{graph_op, i}; + auto dtype = TF_OperationOutputType(output); + retvals[i] = TFE_NewTensorHandleFromTFOutput(output, dtype); + } + return graph_op; +} + +int TFE_FinalizeInputTensorsFromTraceContext(TFE_TraceContext* trace_ctx) { + if (trace_ctx->input_tensors == nullptr) { + trace_ctx->input_tensors = + new std::vector>(); + trace_ctx->input_tensors->reserve(trace_ctx->input_tensor_map.size()); + + for (auto input : trace_ctx->input_tensor_map) { + trace_ctx->input_tensors->emplace_back(input.first, input.second); + } + } + return trace_ctx->input_tensor_map.size(); +} + +TF_Output TFE_GetInputGraphNodeFromTraceContext(TFE_TraceContext* 
trace_ctx, + unsigned int idx) { + CHECK(trace_ctx->input_tensors != nullptr); + CHECK(trace_ctx->input_tensors->size() > idx); + return trace_ctx->input_tensors->at(idx).second; +} + +TFE_TensorHandle* TFE_ConsumeInputConcreteTensorFromTraceContext( + TFE_TraceContext* trace_ctx, unsigned int idx) { + CHECK(trace_ctx->input_tensors != nullptr); + CHECK(trace_ctx->input_tensors->size() > idx); + auto* handle = trace_ctx->input_tensors->at(idx).first; + VLOG(1) << "Ref count for internal handle " << handle + << " is 1?: " << handle->RefCountIsOne(); + handle->Ref(); + auto* ret = new TFE_TensorHandle(handle); + VLOG(1) << "Returning a new tensor handle " << ret << ": " + << handle->DebugString(); + return ret; +} diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h index 80c8bfe594c4c89606efd01bec7f50e7a86b5bda..8d1a8b82fbaf9901b6d9aecf6d092ae298c8dba3 100644 --- a/tensorflow/c/c_api_experimental.h +++ b/tensorflow/c/c_api_experimental.h @@ -67,9 +67,10 @@ TF_CAPI_EXPORT extern void TF_EnableXLACompilation(TF_SessionOptions* options, // a) ConfigProto.optimizer_options.global_jit_level is set to to ON_1 if // `enable_xla_compilation` is non-zero, and OFF otherwise. // b) ConfigProto.gpu_options.allow_growth is set to `gpu_memory_allow_growth`. +// c) ConfigProto.device_count is set to `num_cpu_devices`. 
TF_CAPI_EXPORT extern TF_Buffer* TF_CreateConfig( - unsigned char enable_xla_compilation, - unsigned char gpu_memory_allow_growth); + unsigned char enable_xla_compilation, unsigned char gpu_memory_allow_growth, + unsigned int num_cpu_devices); // Create a serialized tensorflow.RunOptions proto, where RunOptions.trace_level // is set to FULL_TRACE if `enable_full_trace` is non-zero, and NO_TRACE @@ -83,6 +84,15 @@ TF_CAPI_EXPORT extern TF_Buffer* TF_CreateRunOptions( TF_CAPI_EXPORT extern const char* TF_GraphDebugString(TF_Graph* graph, size_t* len); +// Returns the function content in a human-readable format, with length set in +// `len`. The format is subject to change in the future. +// The returned string is heap-allocated, and caller should call free() on it. +// +// Do not return const char*, because some foreign language binding +// (e.g. swift) cannot then call free() on the returned pointer. +TF_CAPI_EXPORT extern char* TF_FunctionDebugString(TF_Function* func, + size_t* len); + // Creates a stack of data set + iterator nodes, currently hard-coded to return // a sequence of 3 float values <42.0, 43.0, 44.0> over 3 calls. On success, // returns the IteratorGetNext node, which caller can run or feed into an node. @@ -180,6 +190,8 @@ TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueVariantTensor( TF_CAPI_EXPORT extern void TFE_TensorHandlePrintDebugString( TFE_TensorHandle* handle); +TF_CAPI_EXPORT extern void TFE_OpPrintDebugString(TFE_Op* op); + typedef struct TFE_ExecuteOpNotification TFE_ExecuteOpNotification; // Allows invoking a kernel asynchronously, and explicitly returns a @@ -239,13 +251,69 @@ TF_CAPI_EXPORT void TF_InitMain(const char* usage, int* argc, char*** argv); // Platform-specific implementation to return an unused port. (This should used // in tests only.) 
-TF_CAPI_EXPORT int TF_PickUnusedPortOrDie(); +TF_CAPI_EXPORT int TF_PickUnusedPortOrDie(void); // Fast path method that makes constructing a single scalar tensor require less // overhead and copies. TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandleFromScalar( TF_DataType dtype, void* scalar, size_t len); +// Specify the server_def that enables collective ops. +// This is different to the above function in that it doesn't create remote +// contexts, and remotely executing ops is not possible. It just enables +// communication for collective ops. +TF_CAPI_EXPORT extern void TFE_EnableCollectiveOps(TFE_Context* ctx, + const void* proto, + size_t proto_len, + TF_Status* status); + +// Create a symbolic tensor from the input graph node. +TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandleFromTFOutput( + TF_Output t, TF_DataType data_type); + +// Returns 0 if the input tensor handle represents a symbolic tensor (i.e., a +// graph node). Otherwise returns non-0. +TF_CAPI_EXPORT extern unsigned char TFE_TensorHandleIsConcrete( + TFE_TensorHandle* handle); + +// If `handle` is a symbolic tensor, return the corresponding graph node +// represented by TF_Output. Otherwise, return an error status. +TF_CAPI_EXPORT extern TF_Output TFE_GetTFOutputFromTensorHandle( + TFE_TensorHandle* handle, TF_Status* status); + +typedef struct TFE_TraceContext TFE_TraceContext; + +// A trace context contains a trace graph, to which TFE_AddEagerOpToGraph() +// calls add graph nodes as a way to symbolically execute the eager ops. +// +// It also contains a hash map from concrete input tensors to symbolic +// tensors. That map will be used to create input tensors to the trace graph. +TF_CAPI_EXPORT extern TFE_TraceContext* TFE_NewTraceContext(TF_Graph* graph); + +TF_CAPI_EXPORT extern void TFE_DeleteTraceContext(TFE_TraceContext* trace_ctx); + +// Symbolically executes `op`, by adding a corresponding node to the graph +// associated with `trace_ctx`. 
This graph node outputs a set of symbolic +// tensors in `retvals` and `num_retvals`. Returns the corresponding graph +// operation on success, otherwise returns nullptr. +TF_CAPI_EXPORT extern TF_Operation* TFE_AddEagerOpToGraph( + TFE_Op* op, TFE_TraceContext* trace_ctx, TFE_TensorHandle** retvals, + int* num_retvals, TF_Status* status); + +// Finalizes the trace graph and its inputs, and returns the number of inputs. +// After this call, the next two APIs can be called to iterate over the input +// tensors. +TF_CAPI_EXPORT extern int TFE_FinalizeInputTensorsFromTraceContext( + TFE_TraceContext* trace_ctx); + +TF_CAPI_EXPORT extern TF_Output TFE_GetInputGraphNodeFromTraceContext( + TFE_TraceContext* trace_ctx, unsigned int idx); + +// Each input tensor should be consumed at most once. +TF_CAPI_EXPORT extern TFE_TensorHandle* +TFE_ConsumeInputConcreteTensorFromTraceContext(TFE_TraceContext* trace_ctx, + unsigned int idx); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/c_api_experimental_test.cc b/tensorflow/c/c_api_experimental_test.cc index daa7701b7fe7e8ce757b6504329cf6434ad39778..354ee5f49f373edbc10e7706aa8776f3cc2a17cd 100644 --- a/tensorflow/c/c_api_experimental_test.cc +++ b/tensorflow/c/c_api_experimental_test.cc @@ -14,6 +14,7 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/c/c_api_experimental.h" +#include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/c_test_util.h" #include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/c_api_test_util.h" @@ -296,5 +297,154 @@ TEST(CAPI_EXPERIMENTAL, TFE_ExecuteOpInNewThreadTest_Blocking) { TF_DeleteStatus(status); } +TEST(CAPI_EXPERIMENTAL, SymbolicTensor) { + TF_Status* status = TF_NewStatus(); + auto node = TF_Output{nullptr, 1}; + auto* sym_handle = TFE_NewTensorHandleFromTFOutput(node, TF_FLOAT); + TFE_TensorHandlePrintDebugString(sym_handle); + CHECK_EQ(TFE_TensorHandleDataType(sym_handle), TF_FLOAT); + ASSERT_FALSE(TFE_TensorHandleIsConcrete(sym_handle)); + auto same_node = TFE_GetTFOutputFromTensorHandle(sym_handle, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + ASSERT_EQ(same_node.oper, node.oper); + ASSERT_EQ(same_node.index, node.index); + TFE_DeleteTensorHandle(sym_handle); + + TFE_TensorHandle* m = TestMatrixTensorHandle(); + ASSERT_TRUE(TFE_TensorHandleIsConcrete(m)); + (void)TFE_GetTFOutputFromTensorHandle(m, status); + CHECK_EQ(TF_INTERNAL, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteTensorHandle(m); + + TF_DeleteStatus(status); +} + +class AddEagerOpToGraphTest : public ::testing::Test { + protected: + AddEagerOpToGraphTest() + : status_(TF_NewStatus()), + eager_ctx_(nullptr), + graph_(TF_NewGraph()), + trace_ctx_(TFE_NewTraceContext(graph_)) { + TFE_ContextOptions* opts = TFE_NewContextOptions(); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + eager_ctx_ = TFE_NewContext(opts, status_); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + TFE_DeleteContextOptions(opts); + } + + ~AddEagerOpToGraphTest() override { + TFE_DeleteTraceContext(trace_ctx_); + TF_DeleteGraph(graph_); + TFE_DeleteContext(eager_ctx_); + TF_DeleteStatus(status_); + } + + template + void AddEagerOpToGraphAndCheck(TFE_Op* op, 
Callable checker) { + TFE_TensorHandle* retvals[5]; + int num_retvals = 5; + // Symbolically execute this op, which adds a graph node to `trace_ctx_`. + TF_Operation* graph_op = + TFE_AddEagerOpToGraph(op, trace_ctx_, retvals, &num_retvals, status_); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + CHECK_NOTNULL(graph_op); + // Check the expectations. + checker(graph_op); + for (int i = 0; i < num_retvals; ++i) { + TFE_DeleteTensorHandle(retvals[i]); + } + } + + TF_Status* status_; + TFE_Context* eager_ctx_; + TF_Graph* graph_; + TFE_TraceContext* trace_ctx_; +}; + +TEST_F(AddEagerOpToGraphTest, DebugPrintAndSymbolicExecution) { + TFE_TensorHandle* m = TestMatrixTensorHandle(); + TFE_Op* op = MatMulOp(eager_ctx_, m, m); + + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + TFE_OpPrintDebugString(op); + + TFE_TensorHandle* retvals[5]; + int num_retvals = 5; + // Symbolically execute this op, which adds a graph node to `trace_ctx`. + TFE_AddEagerOpToGraph(op, trace_ctx_, retvals, &num_retvals, status_); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + + int num_inputs = TFE_FinalizeInputTensorsFromTraceContext(trace_ctx_); + CHECK_EQ(num_inputs, 1); + auto input_sym_tensor = TFE_GetInputGraphNodeFromTraceContext(trace_ctx_, + /*idx*/ 0); + + LOG(INFO) << tensorflow::getTF_OutputDebugString(input_sym_tensor); + auto handle = TFE_ConsumeInputConcreteTensorFromTraceContext(trace_ctx_, + /*idx*/ 0); + TFE_TensorHandlePrintDebugString(handle); + TFE_DeleteTensorHandle(handle); + + CHECK_EQ(num_retvals, 1); + CHECK_EQ(TFE_TensorHandleDataType(retvals[0]), TF_FLOAT); + + TFE_DeleteTensorHandle(retvals[0]); + TFE_DeleteTensorHandle(m); + TFE_DeleteOp(op); +} + +TEST_F(AddEagerOpToGraphTest, ValueAttributesArePreserved) { + // Create MinOp + TFE_TensorHandle* axis = TestAxisTensorHandle(); + TFE_Op* op = MinOp(eager_ctx_, axis, axis); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + + // Check the attributes set by 
the call to MinOp above. + AddEagerOpToGraphAndCheck(op, [this, &axis](TF_Operation* graph_op) { + unsigned char value; + TF_OperationGetAttrBool(graph_op, "keep_dims", &value, status_); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + CHECK_EQ(value, 1); + TF_DataType dtype; + TF_OperationGetAttrType(graph_op, "Tidx", &dtype, status_); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + CHECK_EQ(dtype, TF_INT32); + TF_OperationGetAttrType(graph_op, "T", &dtype, status_); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + CHECK_EQ(dtype, TFE_TensorHandleDataType(axis)); + }); + TFE_DeleteTensorHandle(axis); + TFE_DeleteOp(op); +} + +TEST_F(AddEagerOpToGraphTest, ListAttributesArePreserved) { + // Create a "Squeeze" operator with list attributes. + TFE_TensorHandle* axis = TestAxisTensorHandle(); + TFE_Op* squeeze = TFE_NewOp(eager_ctx_, "Squeeze", status_); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + TFE_OpAddInput(squeeze, axis, status_); + TFE_OpSetAttrType(squeeze, "T", TF_INT32); + std::vector boundaries = {1, 2, 3, 4}; + TFE_OpSetAttrIntList(squeeze, "squeeze_dims", boundaries.data(), + boundaries.size()); + // Check attributes are preserved. 
+ AddEagerOpToGraphAndCheck( + squeeze, [this, &boundaries](TF_Operation* squeeze_graph_op) { + TF_DataType dtype; + TF_OperationGetAttrType(squeeze_graph_op, "T", &dtype, status_); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + CHECK_EQ(dtype, TF_INT32); + std::unique_ptr list(new int64_t[boundaries.size()]); + TF_OperationGetAttrIntList(squeeze_graph_op, "squeeze_dims", list.get(), + boundaries.size(), status_); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + EXPECT_TRUE(std::equal(list.get(), list.get() + boundaries.size(), + boundaries.begin())); + }); + TFE_DeleteTensorHandle(axis); + TFE_DeleteOp(squeeze); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/c/c_api_internal.h b/tensorflow/c/c_api_internal.h index 5ba26d3c585350aa510f9970cbfc246a9a108543..73283d775639b297857b2a50007dc7c28b1f39a3 100644 --- a/tensorflow/c/c_api_internal.h +++ b/tensorflow/c/c_api_internal.h @@ -228,6 +228,8 @@ void RecordMutation(TF_Graph* graph, const TF_Operation& op, bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) LOCKS_EXCLUDED(session->graph->mu, session->mu); +std::string getTF_OutputDebugString(TF_Output node); + } // end namespace tensorflow #endif // TENSORFLOW_C_C_API_INTERNAL_H_ diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index d5934a10395ae094f65d3bc8b6cd7b94dbd32410..2be03bf0de6277fc63c353ad6dc63bec096a6993 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -163,6 +163,7 @@ TEST(CAPI, AllocateTensor) { EXPECT_EQ(dims[0], TF_Dim(t, 0)); EXPECT_EQ(dims[1], TF_Dim(t, 1)); EXPECT_EQ(num_bytes, TF_TensorByteSize(t)); + EXPECT_EQ(6, TF_TensorElementCount(t)); TF_DeleteTensor(t); } @@ -1467,6 +1468,41 @@ TEST(CAPI, DeletingNullPointerIsSafe) { TF_DeleteStatus(status); } +TEST(CAPI, TestBitcastFrom_Reshape) { + int64_t dims[] = {2, 3}; + TF_Tensor* a = + TF_AllocateTensor(TF_UINT64, dims, 2, 6 * TF_DataTypeSize(TF_UINT64)); + TF_Tensor* b = + 
TF_AllocateTensor(TF_UINT64, nullptr, 0, TF_DataTypeSize(TF_UINT64)); + EXPECT_NE(a, nullptr); + EXPECT_NE(b, nullptr); + + EXPECT_EQ(6, TF_TensorElementCount(a)); + EXPECT_EQ(1, TF_TensorElementCount(b)); + EXPECT_EQ(6 * TF_DataTypeSize(TF_UINT64), TF_TensorByteSize(a)); + EXPECT_EQ(TF_DataTypeSize(TF_UINT64), TF_TensorByteSize(b)); + + int64_t new_dims[] = {3, 2}; + TF_Status* status = TF_NewStatus(); + TF_TensorBitcastFrom(a, TF_UINT64, b, new_dims, 2, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)); + TF_DeleteStatus(status); + + EXPECT_EQ(6, TF_TensorElementCount(a)); + EXPECT_EQ(6, TF_TensorElementCount(b)); + EXPECT_EQ(6 * TF_DataTypeSize(TF_UINT64), TF_TensorByteSize(a)); + EXPECT_EQ(6 * TF_DataTypeSize(TF_UINT64), TF_TensorByteSize(b)); + + // Check that a write to one tensor shows up in the other. + *(static_cast(TF_TensorData(a))) = 4; + EXPECT_EQ(4, *(static_cast(TF_TensorData(b)))); + *(static_cast(TF_TensorData(b))) = 6; + EXPECT_EQ(6, *(static_cast(TF_TensorData(a)))); + + TF_DeleteTensor(a); + TF_DeleteTensor(b); +} + REGISTER_OP("TestOpWithNoGradient") .Input("x: T") .Output("y: T") diff --git a/tensorflow/c/c_test.c b/tensorflow/c/c_test.c new file mode 100644 index 0000000000000000000000000000000000000000..b86d8eb8e300e02a3871ecd5f424a82c521b18fc --- /dev/null +++ b/tensorflow/c/c_test.c @@ -0,0 +1,88 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include +#include +#include +#include + +#include "tensorflow/c/c_api.h" +#include "tensorflow/c/c_api_experimental.h" +#include "tensorflow/c/env.h" +#include "tensorflow/c/kernels.h" + +// A compute function. This will never actually get called in this test, it's +// just nice to know that it compiles. +void compute(void* kernel, TF_OpKernelContext* ctx) { + TF_Tensor* input; + TF_Status* s = TF_NewStatus(); + TF_GetInput(ctx, 0, &input, s); + TF_DeleteTensor(input); + + TF_DataType type; + TF_OpKernelContext_GetAttrType(ctx, "foobar", &type, s); + + TF_DeleteStatus(s); + +} + +// Exercises tensorflow's C API. +int main(int argc, char** argv) { + TF_InitMain(argv[0], &argc, &argv); + + struct TF_StringStream* s = TF_GetLocalTempDirectories(); + const char* path; + + if (!TF_StringStreamNext(s, &path)) { + fprintf(stderr, "TF_GetLocalTempDirectories returned no results\n"); + return 1; + } + + char file_name[100]; + struct timeval t; + if (gettimeofday(&t, NULL)) { + perror("gettimeofday failed"); + return 1; + } + snprintf(file_name, sizeof(file_name), "test-%d-%ld.txt", getpid(), t.tv_sec); + + size_t length = 2 + strlen(path) + strlen(file_name); + char* full_path = malloc(length); + snprintf(full_path, length, "%s/%s", path, file_name); + + TF_WritableFileHandle* h; + TF_Status* status = TF_NewStatus(); + TF_NewWritableFile(full_path, &h, status); + if (TF_GetCode(status) != TF_OK) { + fprintf(stderr, "TF_NewWritableFile failed: %s\n", TF_Message(status)); + return 1; + } + fprintf(stderr, "wrote %s\n", full_path); + free(full_path); + TF_CloseWritableFile(h, status); + if (TF_GetCode(status) != TF_OK) { + fprintf(stderr, "TF_CloseWritableFile failed: %s\n", TF_Message(status)); + } + TF_StringStreamDone(s); + + TF_KernelBuilder* b = + TF_NewKernelBuilder("SomeOp", "SomeDevice", NULL, &compute, NULL); + TF_RegisterKernelBuilder("someKernel", b, status); + 
+ TF_DeleteStatus(status); + return 0; +} diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index c34a84fcfee9b6ba9a7be86ae16e2856a2d343c7..257be6379c09841d1427813a0aa25b10a205016d 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -3,11 +3,19 @@ licenses(["notice"]) # Apache 2.0 load( "//tensorflow:tensorflow.bzl", - "tf_cuda_cc_test", - "tf_cc_test", "tf_copts", - "tfe_xla_copts", + "tf_cuda_cc_test", "tf_cuda_library", + "tfe_xla_copts", +) +load( + "//tensorflow/core:platform/default/build_config.bzl", + "tf_additional_device_tracer_test_flags", + "tf_kernel_tests_linkstatic", +) +load( + "//tensorflow/core:platform/default/build_config_root.bzl", + "tf_cuda_tests_tags", ) tf_cuda_library( @@ -62,6 +70,7 @@ tf_cuda_library( "//tensorflow/core/distributed_runtime:remote_device", "//tensorflow/core/distributed_runtime:server_lib", "//tensorflow/core/distributed_runtime:worker_env", + "//tensorflow/core/profiler/lib:profiler_session", "//tensorflow/core:gpu_runtime", ], ) @@ -101,6 +110,7 @@ tf_cuda_library( "//tensorflow/core/distributed_runtime/rpc:grpc_worker_service", "//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr", "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client", + "//tensorflow/core/profiler/lib:profiler_session", ], ) @@ -148,6 +158,88 @@ tf_cuda_cc_test( ], ) +tf_cuda_library( + name = "c_api_experimental", + srcs = [ + "c_api_experimental.cc", + ], + hdrs = ["c_api_experimental.h"], + copts = tf_copts() + tfe_xla_copts(), + visibility = ["//visibility:public"], + deps = select({ + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib_lite", + ], + "//conditions:default": [ + ":c_api", + ":c_api_internal", + "//tensorflow/c:c_api", + "//tensorflow/c:c_api_internal", + "//tensorflow/core:core_cpu", + "//tensorflow/core/common_runtime/eager:attr_builder", + "//tensorflow/core/common_runtime/eager:context", + "//tensorflow/core/common_runtime/eager:eager_executor", + 
"//tensorflow/core/common_runtime/eager:execute", + "//tensorflow/core/common_runtime/eager:kernel_and_device", + "//tensorflow/core/common_runtime/eager:tensor_handle", + "//tensorflow/core/common_runtime/eager:copy_to_device_node", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + ], + }) + select({ + "//tensorflow:with_xla_support": [ + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/jit", + "//tensorflow/compiler/jit:xla_device", + ], + "//conditions:default": [], + }) + [ + "@com_google_absl//absl/memory", + "//tensorflow/core/common_runtime/eager:eager_operation", + "//tensorflow/core/distributed_runtime/eager:eager_client", + "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client", + "//tensorflow/core/distributed_runtime/rpc:grpc_channel", + "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", + "//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache", + "//tensorflow/core/distributed_runtime/rpc:grpc_worker_service", + "//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr", + "//tensorflow/core/distributed_runtime:remote_device", + "//tensorflow/core/distributed_runtime:server_lib", + "//tensorflow/core/distributed_runtime:worker_env", + "//tensorflow/core/profiler/rpc:profiler_server", + "//tensorflow/core:gpu_runtime", + ], +) + +tf_cuda_cc_test( + name = "c_api_experimental_test", + size = "small", + srcs = [ + "c_api_experimental_test.cc", + ], + args = + ["--heap_check=local"] + tf_additional_device_tracer_test_flags(), + linkstatic = tf_kernel_tests_linkstatic(), + tags = tf_cuda_tests_tags() + ["nomac"], + deps = [ + ":c_api_experimental", + ":c_api_test_util", + "//tensorflow/c:c_test_util", + "//tensorflow/cc/profiler", + "//tensorflow/contrib/tpu/profiler:trace_events_proto_cc", + "//tensorflow/core:lib", + 
"//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/profiler:protos_all_cc", + "@com_google_absl//absl/strings", + ], +) + cc_library( name = "tape", hdrs = ["tape.h"], diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 027d752f420238da867cb9d8c116640e1730caaa..af13f487af91594fedd4d5f77592682a6f98c34f 100755 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -356,6 +356,8 @@ TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) { void TFE_DeleteTensorHandle(TFE_TensorHandle* h) { if (h == nullptr) return; + VLOG(1) << "Deleting tensor handle " << h << " with internal handle " + << h->handle; if (h->handle) { h->handle->Unref(); } @@ -443,15 +445,15 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) { return nullptr; } // TODO(agarwal): move this implementation inside TFE_TensorHandle. - tensorflow::Device* d = nullptr; - tensorflow::Device* op_device = nullptr; const tensorflow::Tensor* t = nullptr; - status->status = h->handle->TensorAndDevice(&t, &d, &op_device); - if (!status->status.ok()) return nullptr; tensorflow::TensorHandle* h_cpu = nullptr; - if (!IsCPU(d)) { - status->status = h->handle->CopyToDevice( - h->handle->Context(), h->handle->Context()->HostCPU(), &h_cpu); + tensorflow::Device* d = nullptr; + tensorflow::Device* op_device = nullptr; + + if (h->handle->IsRemote()) { + status->status = EagerCopyToDevice( + h->handle, h->handle->Context(), + h->handle->Context()->HostCPU()->name().c_str(), &h_cpu); if (!status->status.ok()) { return nullptr; } @@ -460,6 +462,22 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) { h_cpu->Unref(); return nullptr; } + } else { + status->status = h->handle->TensorAndDevice(&t, &d, &op_device); + if (!status->status.ok()) return nullptr; + + if (!IsCPU(d)) { + status->status = h->handle->CopyToDevice( + h->handle->Context(), 
h->handle->Context()->HostCPU(), &h_cpu); + if (!status->status.ok()) { + return nullptr; + } + status->status = h_cpu->TensorAndDevice(&t, &d, &op_device); + if (!status->status.ok()) { + h_cpu->Unref(); + return nullptr; + } + } } TF_Tensor* retval = tensorflow::TF_TensorFromTensor(*t, status); if (h_cpu != nullptr) { @@ -696,6 +714,7 @@ void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name, void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, TF_Status* status) { + VLOG(1) << "Calling TFE_Execute() on op " << op; tensorflow::gtl::InlinedVector handle_retvals( *num_retvals); status->status = @@ -738,6 +757,10 @@ void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function, status->status = ctx->context.AddFunctionDef(function->fdef); } +unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) { + return ctx->context.FindFunctionDef(name) != nullptr; +} + void TFE_ContextEnableRunMetadata(TFE_Context* ctx) { ctx->context.SetShouldStoreMetadata(true); } @@ -774,7 +797,7 @@ void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf, if (!status->status.ok()) return; tensorflow::mutex_lock ml(*ctx->context.MetadataMu()); status->status = MessageToBuffer(*ctx->context.RunMetadataProto(), buf); - ctx->context.RunMetadataProto()->Clear(); + ctx->context.ClearRunMetadata(); } namespace { diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index 8d6c8d958d5961fce817156a14eb2b2940c1f2f0..044dfb7415b027b707af05a197fdb41fe1f6d2e5 100755 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -48,7 +48,7 @@ extern "C" { typedef struct TFE_ContextOptions TFE_ContextOptions; // Return a new options object. -TF_CAPI_EXPORT extern TFE_ContextOptions* TFE_NewContextOptions(); +TF_CAPI_EXPORT extern TFE_ContextOptions* TFE_NewContextOptions(void); // Set the config in TF_ContextOptions.options. // config should be a serialized tensorflow.ConfigProto proto. 
@@ -170,23 +170,11 @@ TF_CAPI_EXPORT extern int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index, TF_Status* status); -// Returns the device of the operation that produced `h`. -// If `h` was produced by a copy, returns the destination device of -// the copy. Note that returned device name is not always the device -// holding the tensor handle's memory. If you want the latter, use -// TFE_TensorHandleBackingDeviceName. -// This function will block till the operation that produces `h` has completed. -// -// Device on which the kernel of the operation that produced `h` ran. -// -// If `h` was produced by a copy, returns the destination device of -// the copy. -// -// Note that returned device name is not always the device that owns the memory -// that backs the tensor handle. For the latter see -// TFE_TensorHandleBackingDeviceName. -// -// This function will block till the operation that produces `h` has completed. +// Returns the device of the operation that produced `h`. If `h` was produced by +// a copy, returns the destination device of the copy. Note that the returned +// device name is not always the device holding the tensor handle's memory. If +// you want the latter, use TFE_TensorHandleBackingDeviceName. This function +// will block till the operation that produces `h` has completed. TF_CAPI_EXPORT extern const char* TFE_TensorHandleDeviceName( TFE_TensorHandle* h, TF_Status* status); @@ -405,6 +393,10 @@ TF_CAPI_EXPORT extern void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function, TF_Status* status); +// Checks whether a function is registered under `name`. +TF_CAPI_EXPORT unsigned char TFE_ContextHasFunction(TFE_Context* ctx, + const char* name); + // Enables tracing of RunMetadata on the ops executed from this context. 
TF_CAPI_EXPORT extern void TFE_ContextEnableRunMetadata(TFE_Context* ctx); diff --git a/tensorflow/c/eager/c_api_debug.cc b/tensorflow/c/eager/c_api_debug.cc index 52b0824552855860dfb138f3ac9a5d3afa7dc965..ffcd5ace0b98597363abe63201bf6c328a03212f 100644 --- a/tensorflow/c/eager/c_api_debug.cc +++ b/tensorflow/c/eager/c_api_debug.cc @@ -83,7 +83,7 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo( } } - if (xla::ShapeUtil::IsTuple(padded_shape)) { + if (padded_shape.IsTuple()) { if (xla::ShapeUtil::TupleElementCount(padded_shape) != 2) { // Currently, the only case of XlaTensor containing a tuple shape is to // represent 64 bit ints, doubles, and complex numbers (we don't support @@ -99,7 +99,7 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo( xla::Shape shape0 = xla::ShapeUtil::GetTupleElementShape(padded_shape, 0); const xla::Shape& shape1 = xla::ShapeUtil::GetTupleElementShape(padded_shape, 1); - if (xla::ShapeUtil::IsTuple(shape0) || xla::ShapeUtil::IsTuple(shape1)) { + if (shape0.IsTuple() || shape1.IsTuple()) { status->status = tensorflow::errors::InvalidArgument( "XlaTensors should not contain nested tuples. Shape: ", padded_shape.DebugString()); diff --git a/tensorflow/c/eager/c_api_experimental.cc b/tensorflow/c/eager/c_api_experimental.cc new file mode 100644 index 0000000000000000000000000000000000000000..dab17505643e791e6294a64247898ae23769a055 --- /dev/null +++ b/tensorflow/c/eager/c_api_experimental.cc @@ -0,0 +1,52 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/eager/c_api_experimental.h" + +#include "tensorflow/c/c_api.h" +#include "tensorflow/c/eager/c_api_internal.h" +#include "tensorflow/core/profiler/rpc/profiler_server.h" + +using tensorflow::string; + +void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) { + op->operation.ConsumeInput(h->handle); +} + +TFE_Profiler* TFE_NewProfiler(TFE_Context* ctx) { + return new TFE_Profiler(ctx); +} + +void TFE_DeleteProfiler(TFE_Profiler* profiler) { delete profiler; } + +void TFE_ProfilerSerializeToString(TFE_Context* ctx, TFE_Profiler* profiler, + TF_Buffer* buf, TF_Status* status) { + TFE_ContextAsyncWait(ctx, status); + if (!status->status.ok()) return; + string content; + status->status = profiler->profiler->SerializeToString(&content); + void* data = tensorflow::port::Malloc(content.length()); + content.copy(static_cast(data), content.length(), 0); + buf->data = data; + buf->length = content.length(); + buf->data_deallocator = [](void* data, size_t length) { + tensorflow::port::Free(data); + }; +} + +void TFE_StartProfilerServer(TFE_Context* ctx, int port) { + auto server_thread = tensorflow::StartProfilerServer(&ctx->context, port); + ctx->context.AddChildThread(std::move(server_thread)); +} diff --git a/tensorflow/c/eager/c_api_experimental.h b/tensorflow/c/eager/c_api_experimental.h new file mode 100644 index 0000000000000000000000000000000000000000..8c85d0e51695fde09cf0e2bb3930f9173e6cfb54 --- /dev/null +++ 
b/tensorflow/c/eager/c_api_experimental.h @@ -0,0 +1,58 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EAGER_C_API_EXPERIMENTAL_H_ +#define TENSORFLOW_C_EAGER_C_API_EXPERIMENTAL_H_ + +#include "tensorflow/c/c_api.h" +#include "tensorflow/c/eager/c_api.h" + +#ifdef __cplusplus +extern "C" { +#endif + +TF_CAPI_EXPORT extern void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h, + TF_Status* status); + +// A profiler which starts profiling when the object is created and stops +// when the object is destroyed. It will profile all operations run under the +// given TFE_Context. Multiple instances of it can be created, but at most one +// of them will profile for each TFE_Context. +// Thread-safety: TFE_Profiler is thread-safe. +typedef struct TFE_Profiler TFE_Profiler; + +TF_CAPI_EXPORT extern TFE_Profiler* TFE_NewProfiler(TFE_Context* ctx); +TF_CAPI_EXPORT extern void TFE_DeleteProfiler(TFE_Profiler* profiler); + +// The output string is a binary string of tensorflow.tpu.Trace. Users can write +// the string to a file for offline analysis by tensorboard. +TF_CAPI_EXPORT extern void TFE_ProfilerSerializeToString(TFE_Context* ctx, + TFE_Profiler* profiler, + TF_Buffer* buf, + TF_Status* status); + +// Start a profiler grpc server which listens on the specified port. It will start +// the server on its own thread.
It can be shut down by destructing TFE_Context. +// Creating multiple profiler servers is allowed. The service is defined in +// tensorflow/contrib/tpu/profiler/tpu_profiler.proto. Please use +// tensorflow/contrib/tpu/profiler/capture_tpu_profile to capture a traceable +// file following +// https://cloud.google.com/tpu/docs/cloud-tpu-tools#capture_trace. +TF_CAPI_EXPORT extern void TFE_StartProfilerServer(TFE_Context* ctx, int port); + +#ifdef __cplusplus +} /* end extern "C" */ +#endif + +#endif // TENSORFLOW_C_EAGER_C_API_EXPERIMENTAL_H_ diff --git a/tensorflow/c/eager/c_api_experimental_test.cc b/tensorflow/c/eager/c_api_experimental_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..af55fee66e8708e39626da3b10b6dd2f73af92bb --- /dev/null +++ b/tensorflow/c/eager/c_api_experimental_test.cc @@ -0,0 +1,104 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+==============================================================================*/ + +#include "tensorflow/c/eager/c_api_experimental.h" + +#include +#include "tensorflow/c/eager/c_api_test_util.h" +#include "tensorflow/cc/profiler/profiler.h" +#include "tensorflow/contrib/tpu/profiler/trace_events.pb.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" + +using tensorflow::string; + +namespace tensorflow { +namespace { + +static bool HasSubstr(absl::string_view base, absl::string_view substr) { + bool ok = str_util::StrContains(base, substr); + EXPECT_TRUE(ok) << base << ", expected substring " << substr; + return ok; +} + +void ExecuteWithProfiling(bool async) { + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); + TFE_Context* ctx = TFE_NewContext(opts, status); + TFE_Profiler* profiler = TFE_NewProfiler(ctx); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_TensorHandle* m = TestMatrixTensorHandle(); + TFE_Op* matmul = MatMulOp(ctx, m, m); + TFE_TensorHandle* retvals[1] = {nullptr}; + int num_retvals = 1; + + // Run op on GPU if it is present. 
+ string gpu_device_name; + if (GetDeviceName(ctx, &gpu_device_name, "GPU")) { + TFE_OpSetDevice(matmul, gpu_device_name.c_str(), status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + const char* device_name = TFE_OpGetDevice(matmul, status); + ASSERT_TRUE(strstr(device_name, "GPU:0") != nullptr); + } + + TFE_Execute(matmul, &retvals[0], &num_retvals, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteOp(matmul); + TFE_DeleteTensorHandle(m); + + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + ASSERT_EQ(1, num_retvals); + TF_Buffer* profiler_result = TF_NewBuffer(); + TFE_ProfilerSerializeToString(ctx, profiler, profiler_result, status); + TFE_DeleteProfiler(profiler); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + tensorflow::tpu::Trace profile_proto; + EXPECT_TRUE(profile_proto.ParseFromString( + {reinterpret_cast(profiler_result->data), + profiler_result->length})); + string profile_proto_str = profile_proto.DebugString(); + if (!gpu_device_name.empty()) { + EXPECT_TRUE(HasSubstr(profile_proto_str, "GPU:0")); + // device name with "stream:all" is collected by Device Tracer. 
+ EXPECT_TRUE(HasSubstr(profile_proto_str, "stream:all")); + } + EXPECT_TRUE(HasSubstr(profile_proto_str, "CPU:0")); + TF_DeleteBuffer(profiler_result); + + TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); + TFE_DeleteTensorHandle(retvals[0]); + TFE_DeleteContext(ctx); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + float product[4] = {0}; + EXPECT_EQ(sizeof(product), TF_TensorByteSize(t)); + memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t)); + TF_DeleteTensor(t); + EXPECT_EQ(7, product[0]); + EXPECT_EQ(10, product[1]); + EXPECT_EQ(15, product[2]); + EXPECT_EQ(22, product[3]); + TF_DeleteStatus(status); +} +TEST(CAPI, ExecuteWithTracing) { ExecuteWithProfiling(false); } +TEST(CAPI, ExecuteWithTracingAsync) { ExecuteWithProfiling(true); } + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h index 67bc1bcd24605f8363d6a7c8d5d6a0836a42fc82..b70c0f1c112c675641a023d6c7bf4fa847ee4610 100644 --- a/tensorflow/c/eager/c_api_internal.h +++ b/tensorflow/c/eager/c_api_internal.h @@ -52,6 +52,7 @@ limitations under the License. #include "tensorflow/core/lib/gtl/stl_util.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/profiler/lib/profiler_session.h" #include "tensorflow/core/public/version.h" struct TFE_ContextOptions { @@ -82,6 +83,12 @@ struct TFE_TensorHandle { TFE_TensorHandle(tensorflow::TensorHandle* handle) : handle(handle) {} tensorflow::TensorHandle* handle; + + // Create a symbolic tensor. 
+ TFE_TensorHandle(TF_Output t, TF_DataType dtype) + : handle(new tensorflow::TensorHandle( + tensorflow::OutputGraphNode{t.oper, t.index}, + static_cast(dtype))) {} }; struct TFE_TensorDebugInfo { @@ -100,6 +107,13 @@ struct TFE_Op { tensorflow::EagerOperation operation; }; +struct TFE_Profiler { + TFE_Profiler(TFE_Context* ctx) + : profiler(tensorflow::ProfilerSession::Create(&ctx->context)) {} + + std::unique_ptr profiler; +}; + namespace tensorflow { // Set an AttrValue on the op. Doesn't handle the list types. void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op, @@ -107,4 +121,24 @@ void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op, const char* attr_name, TF_Status* status); } // namespace tensorflow +struct TFE_TraceContext { + TF_Graph* const graph; + + unsigned int node_counter = 0; + // Each tensor handle will have its ref count incremented when it's added as a + // map key, and decremented when this object is destroyed. + std::map input_tensor_map; + std::vector>* input_tensors = + nullptr; + + TFE_TraceContext(TF_Graph* graph) : graph(graph) {} + + ~TFE_TraceContext() { + delete input_tensors; + for (auto input : input_tensor_map) { + input.first->Unref(); + } + } +}; + #endif // TENSORFLOW_C_EAGER_C_API_INTERNAL_H_ diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 6b39b79ee82f9c7baaf856e573a42b7da65691e5..3d1ca4fb4b561a03ea9d879b1876fb1fd08a3139 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -175,13 +175,8 @@ void TestRemoteExecute(bool async) { TFE_Execute(matmul, &retvals[0], &num_retvals, status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - auto* retval_task0 = TFE_TensorHandleCopyToDevice( - retvals[0], ctx, "/job:localhost/replica:0/task:0/device:CPU:0", status); - ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - - TF_Tensor* t = TFE_TensorHandleResolve(retval_task0, status); + TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], 
status); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_DeleteTensorHandle(retval_task0); float product[4] = {0}; EXPECT_EQ(sizeof(product), TF_TensorByteSize(t)); memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t)); diff --git a/tensorflow/c/env.cc b/tensorflow/c/env.cc new file mode 100644 index 0000000000000000000000000000000000000000..1c35ff9001d0ee1ab0fbae9e1bcc07116fab1065 --- /dev/null +++ b/tensorflow/c/env.cc @@ -0,0 +1,183 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/c/env.h" + +#include "tensorflow/c/c_api_internal.h" +#include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/types.h" + +struct TF_StringStream { + std::vector<::tensorflow::string>* list; + size_t position; +}; + +void TF_CreateDir(const char* dirname, TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Set_TF_Status_from_Status( + status, ::tensorflow::Env::Default()->CreateDir(dirname)); +} + +void TF_DeleteDir(const char* dirname, TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Set_TF_Status_from_Status( + status, ::tensorflow::Env::Default()->DeleteDir(dirname)); +} + +void TF_DeleteRecursively(const char* dirname, int64_t* undeleted_file_count, + int64_t* undeleted_dir_count, TF_Status* status) { + ::tensorflow::int64 f = 0, d = 0; // Zeroed so the copies below are defined even if the callee returns before writing them. + + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Set_TF_Status_from_Status( + status, ::tensorflow::Env::Default()->DeleteRecursively(dirname, &f, &d)); + *undeleted_file_count = f; + *undeleted_dir_count = d; +} + +void TF_FileStat(const char* filename, TF_FileStatistics* stats, + TF_Status* status) { + ::tensorflow::FileStatistics cc_stats; + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Status s = + ::tensorflow::Env::Default()->Stat(filename, &cc_stats); + ::tensorflow::Set_TF_Status_from_Status(status, s); + if (s.ok()) { + stats->length = cc_stats.length; + stats->mtime_nsec = cc_stats.mtime_nsec; + stats->is_directory = cc_stats.is_directory; + } +} + +void TF_NewWritableFile(const char* filename, TF_WritableFileHandle** handle, + TF_Status* status) { + std::unique_ptr<::tensorflow::WritableFile> f; + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Status s = + ::tensorflow::Env::Default()->NewWritableFile(filename, &f); + ::tensorflow::Set_TF_Status_from_Status(status, s); + + if (s.ok()) { + *handle = 
reinterpret_cast<TF_WritableFileHandle*>(f.release()); // Ownership passes to the caller; *handle is untouched on failure. + } +} + +void TF_CloseWritableFile(TF_WritableFileHandle* handle, TF_Status* status) { + auto* cc_file = reinterpret_cast<::tensorflow::WritableFile*>(handle); + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Set_TF_Status_from_Status(status, cc_file->Close()); + delete cc_file; // Freed regardless of the Close() result. +} + +void TF_SyncWritableFile(TF_WritableFileHandle* handle, TF_Status* status) { + auto* cc_file = reinterpret_cast<::tensorflow::WritableFile*>(handle); + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Set_TF_Status_from_Status(status, cc_file->Sync()); +} + +void TF_FlushWritableFile(TF_WritableFileHandle* handle, TF_Status* status) { + auto* cc_file = reinterpret_cast<::tensorflow::WritableFile*>(handle); + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Set_TF_Status_from_Status(status, cc_file->Flush()); +} + +void TF_AppendWritableFile(TF_WritableFileHandle* handle, const char* data, + size_t length, TF_Status* status) { + auto* cc_file = reinterpret_cast<::tensorflow::WritableFile*>(handle); + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Set_TF_Status_from_Status( + status, cc_file->Append(::tensorflow::StringPiece{data, length})); +} + +void TF_DeleteFile(const char* filename, TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Set_TF_Status_from_Status( + status, ::tensorflow::Env::Default()->DeleteFile(filename)); +} + +bool TF_StringStreamNext(TF_StringStream* list, const char** result) { + if (list->position >= list->list->size()) { + *result = nullptr; + return false; + } + + *result = list->list->at(list->position++).c_str(); + return true; +} + +void TF_StringStreamDone(TF_StringStream* list) { + delete list->list; + delete list; +} +TF_StringStream* TF_GetChildren(const char* dirname, TF_Status* status) { + auto* children = new std::vector<::tensorflow::string>; + + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Set_TF_Status_from_Status( + status, ::tensorflow::Env::Default()->GetChildren(dirname, children)); 
+ + auto* list = new TF_StringStream; + list->list = children; + list->position = 0; + return list; +} + +TF_StringStream* TF_GetLocalTempDirectories() { + auto* tmpdirs = new std::vector<::tensorflow::string>; + + ::tensorflow::Env::Default()->GetLocalTempDirectories(tmpdirs); + + auto* list = new TF_StringStream; + list->list = tmpdirs; + list->position = 0; + return list; +} + +TF_CAPI_EXPORT extern uint64_t TF_NowNanos(void) { + return ::tensorflow::Env::Default()->NowNanos(); +} + +// Returns the number of microseconds since the Unix epoch. +TF_CAPI_EXPORT extern uint64_t TF_NowMicros(void) { + return ::tensorflow::Env::Default()->NowMicros(); +} + +// Returns the number of seconds since the Unix epoch. +TF_CAPI_EXPORT extern uint64_t TF_NowSeconds(void) { + return ::tensorflow::Env::Default()->NowSeconds(); +} + +void TF_DefaultThreadOptions(TF_ThreadOptions* options) { + options->stack_size = 0; + options->guard_size = 0; + options->numa_node = -1; +} + +TF_Thread* TF_StartThread(const TF_ThreadOptions* options, + const char* thread_name, void (*work_func)(void*), + void* param) { + ::tensorflow::ThreadOptions cc_options; + cc_options.stack_size = options->stack_size; + cc_options.guard_size = options->guard_size; + cc_options.numa_node = options->numa_node; + return reinterpret_cast<TF_Thread*>(::tensorflow::Env::Default()->StartThread( + cc_options, thread_name, [=]() { (*work_func)(param); })); +} + +void TF_JoinThread(TF_Thread* thread) { + // ::tensorflow::Thread joins on destruction + delete reinterpret_cast<::tensorflow::Thread*>(thread); +} diff --git a/tensorflow/c/env.h b/tensorflow/c/env.h new file mode 100644 index 0000000000000000000000000000000000000000..73078fcbbc5ae4c042f4a992655072a838e42915 --- /dev/null +++ b/tensorflow/c/env.h @@ -0,0 +1,195 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <stdbool.h> +#include <stddef.h> +#include <stdint.h> + +#ifndef TENSORFLOW_C_ENV_H_ +#define TENSORFLOW_C_ENV_H_ + +#include "tensorflow/c/c_api.h" + +// -------------------------------------------------------------------------- +// C API for tensorflow::Env. + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct TF_WritableFileHandle TF_WritableFileHandle; +typedef struct TF_StringStream TF_StringStream; +typedef struct TF_Thread TF_Thread; + +typedef struct TF_FileStatistics { + // The length of the file in bytes. + int64_t length; + // The last modified time in nanoseconds. + int64_t mtime_nsec; + // Whether the name refers to a directory. + bool is_directory; +} TF_FileStatistics; + +typedef struct TF_ThreadOptions { + // Thread stack size to use (in bytes), zero implies that the system default + // will be used. + size_t stack_size; + + // Guard area size to use near thread stacks (in bytes), zero implies + // that the system default will be used. + size_t guard_size; + + // The NUMA node to use, -1 implies that there should be no NUMA affinity for + // this thread. + int numa_node; +} TF_ThreadOptions; + +// Creates the specified directory. Typical status codes are: +// * TF_OK - successfully created the directory +// * TF_ALREADY_EXISTS - directory already exists +// * TF_PERMISSION_DENIED - dirname is not writable +TF_CAPI_EXPORT extern void TF_CreateDir(const char* dirname, TF_Status* status); + +// Deletes the specified directory.
Typical status codes are: +// * TF_OK - successfully deleted the directory +// * TF_FAILED_PRECONDITION - the directory is not empty +TF_CAPI_EXPORT extern void TF_DeleteDir(const char* dirname, TF_Status* status); + +// Deletes the specified directory and all subdirectories and files underneath +// it. This is accomplished by traversing the directory tree rooted at dirname +// and deleting entries as they are encountered. +// +// If dirname itself is not readable or does not exist, *undeleted_dir_count is +// set to 1, *undeleted_file_count is set to 0 and an appropriate status (e.g. +// TF_NOT_FOUND) is returned. +// +// If dirname and all its descendants were successfully deleted, TF_OK is +// returned and both error counters are set to zero. +// +// Otherwise, while traversing the tree, undeleted_file_count and +// undeleted_dir_count are updated if an entry of the corresponding type could +// not be deleted. The returned error status represents the reason that any one +// of these entries could not be deleted. +// +// Typical status codes: +// * TF_OK - dirname exists and we were able to delete everything underneath +// * TF_NOT_FOUND - dirname doesn't exist +// * TF_PERMISSION_DENIED - dirname or some descendant is not writable +// * TF_UNIMPLEMENTED - some underlying functions (like Delete) are not +// implemented +TF_CAPI_EXPORT extern void TF_DeleteRecursively(const char* dirname, + int64_t* undeleted_file_count, + int64_t* undeleted_dir_count, + TF_Status* status); + +// Obtains statistics for the given path. If status is TF_OK, *stats is +// updated, otherwise it is not touched. +TF_CAPI_EXPORT extern void TF_FileStat(const char* filename, + TF_FileStatistics* stats, + TF_Status* status); + +// Creates or truncates the given filename and returns a handle to be used for +// appending data to the file. If status is TF_OK, *handle is updated and the +// caller is responsible for freeing it (see TF_CloseWritableFile). 
+TF_CAPI_EXPORT extern void TF_NewWritableFile(const char* filename, + TF_WritableFileHandle** handle, + TF_Status* status); + +// Closes the given handle and frees its memory. If there was a problem closing +// the file, it is indicated by status. Memory is freed in any case. +TF_CAPI_EXPORT extern void TF_CloseWritableFile(TF_WritableFileHandle* handle, + TF_Status* status); + +// Syncs content of the handle to the filesystem. Blocks waiting for the +// filesystem to indicate that the content has been persisted. +TF_CAPI_EXPORT extern void TF_SyncWritableFile(TF_WritableFileHandle* handle, + TF_Status* status); + +// Flush local buffers to the filesystem. If the process terminates after a +// successful flush, the contents may still be persisted, since the underlying +// filesystem may eventually flush the contents. If the OS or machine crashes +// after a successful flush, the contents may or may not be persisted, depending +// on the implementation. +TF_CAPI_EXPORT extern void TF_FlushWritableFile(TF_WritableFileHandle* handle, + TF_Status* status); + +// Appends the given bytes to the file. Any failure to do so is indicated in +// status. +TF_CAPI_EXPORT extern void TF_AppendWritableFile(TF_WritableFileHandle* handle, + const char* data, + size_t length, + TF_Status* status); + +// Deletes the named file and indicates whether successful in *status. +TF_CAPI_EXPORT extern void TF_DeleteFile(const char* filename, + TF_Status* status); + +// Retrieves the next item from the given TF_StringStream and places a pointer +// to it in *result. If no more items are in the list, *result is set to NULL +// and false is returned. +// +// Ownership of the items retrieved with this function remains with the library. +// Item points are invalidated after a call to TF_StringStreamDone. +TF_CAPI_EXPORT extern bool TF_StringStreamNext(TF_StringStream* list, + const char** result); + +// Frees the resources associated with given string list. 
All pointers returned +// by TF_StringStreamNext are invalid after this call. +TF_CAPI_EXPORT extern void TF_StringStreamDone(TF_StringStream* list); + +// Retrieves the list of children of the given directory. You can iterate +// through the list with TF_StringStreamNext. The caller is responsible for +// freeing the list (see TF_StringStreamDone). +TF_CAPI_EXPORT extern TF_StringStream* TF_GetChildren(const char* filename, + TF_Status* status); + +// Retrieves a list of directory names on the local machine that may be used for +// temporary storage. You can iterate through the list with TF_StringStreamNext. +// The caller is responsible for freeing the list (see TF_StringStreamDone). +TF_CAPI_EXPORT extern TF_StringStream* TF_GetLocalTempDirectories(void); + +// Returns the number of nanoseconds since the Unix epoch. +TF_CAPI_EXPORT extern uint64_t TF_NowNanos(void); + +// Returns the number of microseconds since the Unix epoch. +TF_CAPI_EXPORT extern uint64_t TF_NowMicros(void); + +// Returns the number of seconds since the Unix epoch. +TF_CAPI_EXPORT extern uint64_t TF_NowSeconds(void); + +// Populates a TF_ThreadOptions struct with system-default values. +TF_CAPI_EXPORT extern void TF_DefaultThreadOptions(TF_ThreadOptions* options); + +// Returns a new thread that is running work_func and is identified +// (for debugging/performance-analysis) by thread_name. +// +// The given param (which may be null) is passed to work_func when the thread +// starts. In this way, data may be passed from the thread back to the caller. +// +// Caller takes ownership of the result and must call TF_JoinThread on it +// eventually. +TF_CAPI_EXPORT extern TF_Thread* TF_StartThread(const TF_ThreadOptions* options, + const char* thread_name, + void (*work_func)(void*), + void* param); + +// Waits for the given thread to finish execution, then deletes it. 
+TF_CAPI_EXPORT extern void TF_JoinThread(TF_Thread* thread); + +#ifdef __cplusplus +} +#endif + +#endif // TENSORFLOW_C_ENV_H_ diff --git a/tensorflow/c/env_test.cc b/tensorflow/c/env_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..687ad024137352662759ec1f43df87e89faca353 --- /dev/null +++ b/tensorflow/c/env_test.cc @@ -0,0 +1,127 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/env.h" + +#include "tensorflow/c/c_api.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +#define ASSERT_TF_OK(x) ASSERT_EQ(TF_OK, TF_GetCode(x)) + +TEST(TestEnv, TestDirHandling) { + TF_StringStream* tempdirs = TF_GetLocalTempDirectories(); + const char* tempdir; + bool found = false; + while (TF_StringStreamNext(tempdirs, &tempdir)) { + found = true; + + TF_Status* s = TF_NewStatus(); + + ::tensorflow::string dirpath = + ::tensorflow::io::JoinPath(tempdir, "somedir"); + TF_CreateDir(dirpath.c_str(), s); + ASSERT_TF_OK(s) << "TF_CreateDir failed for " << dirpath << ": " + << TF_Message(s); + + ::tensorflow::string filepath = + ::tensorflow::io::JoinPath(dirpath, "somefile.txt"); + TF_WritableFileHandle* handle; + TF_NewWritableFile(filepath.c_str(), &handle, s); + ASSERT_TF_OK(s) << 
"NewWritableFile failed for " << filepath << ": " + << TF_Message(s); + + const char* data = "Hello, world!\n"; + TF_AppendWritableFile(handle, data, strlen(data), s); + ASSERT_TF_OK(s) << "TF_AppendWritableFile failed to append data to file at " + << filepath << ": " << TF_Message(s); + + TF_CloseWritableFile(handle, s); + ASSERT_TF_OK(s) << "TF_CloseWritableFile failed to close handle to " + << filepath << ": " << TF_Message(s); + + TF_StringStream* children = TF_GetChildren(dirpath.c_str(), s); + ASSERT_TF_OK(s) << "TF_GetChildren failed for " << dirpath; + const char* childpath; + ASSERT_TRUE(TF_StringStreamNext(children, &childpath)); + ASSERT_EQ(::tensorflow::string(childpath), "somefile.txt"); + // There should only be one file in this directory. + ASSERT_FALSE(TF_StringStreamNext(children, &childpath)); + ASSERT_EQ(childpath, nullptr); + TF_StringStreamDone(children); + + TF_FileStatistics stats; + TF_FileStat(filepath.c_str(), &stats, s); + ASSERT_EQ(stats.length, strlen(data)); + ASSERT_FALSE(stats.is_directory); + ASSERT_GT(stats.mtime_nsec, 0); + + // Trying to delete a non-empty directory should fail. + TF_DeleteDir(dirpath.c_str(), s); + ASSERT_NE(TF_OK, TF_GetCode(s)) + << "TF_DeleteDir unexpectedly succeeded with a non-empty directory " + << dirpath; + + TF_DeleteFile(filepath.c_str(), s); + ASSERT_TF_OK(s) << "TF_DeleteFile failed for " << filepath << ": " + << TF_Message(s); + + // Now deleting the directory should work. 
+ TF_DeleteDir(dirpath.c_str(), s); + ASSERT_TF_OK(s) << "TF_DeleteDir failed for " << dirpath << ": " + << TF_Message(s); + + TF_DeleteStatus(s); + break; + } + + ASSERT_TRUE(found) << "expected at least one temp dir"; + + TF_StringStreamDone(tempdirs); +} + +TEST(TestEnv, TestTimeFunctions) { + ASSERT_GE(TF_NowSeconds(), 946684800); // Midnight Jan 1, 2000 + ASSERT_GE(TF_NowMicros(), 946684800 * 1e6); + ASSERT_GE(TF_NowNanos(), 946684800 * 1e9); +} + +namespace { + +struct SomeThreadData { + ::tensorflow::mutex mu; + bool did_work = false; +}; + +void SomeThreadFunc(void* data) { + auto* real_data = static_cast(data); + ::tensorflow::mutex_lock l(real_data->mu); + real_data->did_work = true; +} + +} // namespace + +TEST(TestEnv, TestThreads) { + TF_ThreadOptions options; + TF_DefaultThreadOptions(&options); + SomeThreadData data; + TF_Thread* thread = + TF_StartThread(&options, "SomeThreadName", &SomeThreadFunc, &data); + TF_JoinThread(thread); + ::tensorflow::mutex_lock l(data.mu); + ASSERT_TRUE(data.did_work); +} diff --git a/tensorflow/c/kernels.cc b/tensorflow/c/kernels.cc index ca69345264607ac689fb556b4f5c9bc08ea5eb88..9505bf9dda32b9a338b574f1d31ec555a5628c6a 100644 --- a/tensorflow/c/kernels.cc +++ b/tensorflow/c/kernels.cc @@ -15,7 +15,9 @@ limitations under the License. 
#include +#include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/kernels.h" +#include "tensorflow/c/tf_status_helper.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" @@ -46,9 +48,10 @@ TF_KernelBuilder* TF_NewKernelBuilder( } void TF_DeleteKernelBuilder(TF_KernelBuilder* builder) { - DCHECK_NE(builder, nullptr); - delete builder->cc_builder; - delete builder; + if (builder != nullptr) { + delete builder->cc_builder; + delete builder; + } } namespace tensorflow { @@ -116,3 +119,84 @@ void TF_RegisterKernelBuilder(const char* name, TF_KernelBuilder* builder, TF_SetStatus(status, TF_OK, ""); } + +int TF_NumInputs(TF_OpKernelContext* ctx) { + auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); + return cc_ctx->num_inputs(); +} + +int TF_NumOutputs(TF_OpKernelContext* ctx) { + auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); + return cc_ctx->num_outputs(); +} + +void TF_GetInput(TF_OpKernelContext* ctx, int i, TF_Tensor** tensor, + TF_Status* status) { + auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); + if (i < 0 || i >= cc_ctx->num_inputs()) { + TF_SetStatus(status, TF_OUT_OF_RANGE, "input index out of range"); + return; + } + const ::tensorflow::Tensor& cc_tensor(cc_ctx->input(i)); + TF_Tensor* result = ::tensorflow::TF_TensorFromTensor(cc_tensor, status); + if (TF_GetCode(status) == TF_OK) { + *tensor = result; + } +} + +void TF_SetOutput(TF_OpKernelContext* ctx, int i, const TF_Tensor* tensor, + TF_Status* status) { + auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); + if (i < 0 || i >= cc_ctx->num_inputs()) { + TF_SetStatus(status, TF_OUT_OF_RANGE, "input index out of range"); + return; + } + ::tensorflow::Tensor cc_tensor; + ::tensorflow::Status s = ::tensorflow::TF_TensorToTensor(tensor, &cc_tensor); + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Set_TF_Status_from_Status(status, s); + if (s.ok()) { + 
cc_ctx->set_output(i, cc_tensor); + } +} + +void TF_OpKernelConstruction_Failure(TF_OpKernelConstruction* ctx, + TF_Status* status) { + auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx); + ::tensorflow::Status s(::tensorflow::StatusFromTF_Status(status)); + cc_ctx->CtxFailure(s); +} + +void TF_OpKernelContext_Failure(TF_OpKernelContext* ctx, TF_Status* status) { + auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); + ::tensorflow::Status s(::tensorflow::StatusFromTF_Status(status)); + cc_ctx->CtxFailure(s); +} + +#define DEFINE_TF_GETATTR_(struct_name, func, c_type, cc_type) \ + void struct_name##_GetAttr##func(struct_name* ctx, const char* attr_name, \ + c_type* val, TF_Status* status) { \ + TF_SetStatus(status, TF_OK, ""); \ + cc_type v; \ + auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx); \ + ::tensorflow::Status s = cc_ctx->GetAttr(attr_name, &v); \ + ::tensorflow::Set_TF_Status_from_Status(status, s); \ + if (s.ok()) { \ + *val = static_cast(v); \ + } \ + } + +#define DEFINE_TF_GETATTR(func, c_type, cc_type) \ + DEFINE_TF_GETATTR_(TF_OpKernelConstruction, func, c_type, cc_type) \ + DEFINE_TF_GETATTR_(TF_OpKernelContext, func, c_type, cc_type) + +DEFINE_TF_GETATTR(Type, TF_DataType, tensorflow::DataType) + +TF_DataType TF_ExpectedOutputDataType(TF_OpKernelContext* ctx, int i) { + auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); + return static_cast(cc_ctx->expected_output_dtype(i)); +} + +int64_t TF_StepId(TF_OpKernelContext* ctx) { + return reinterpret_cast<::tensorflow::OpKernelContext*>(ctx)->step_id(); +} diff --git a/tensorflow/c/kernels.h b/tensorflow/c/kernels.h index 2518789a3c141755d0b3373d53642c487331f68b..b015d0103969355e8566242bfcc007f697c6ae18 100644 --- a/tensorflow/c/kernels.h +++ b/tensorflow/c/kernels.h @@ -35,9 +35,9 @@ extern "C" { // `TF_RegisterKernelBuilder`, which will allow TF to construct user-provided // kernels when necessary. 
-struct TF_KernelBuilder; -struct TF_OpKernelConstruction; -struct TF_OpKernelContext; +typedef struct TF_KernelBuilder TF_KernelBuilder; +typedef struct TF_OpKernelConstruction TF_OpKernelConstruction; +typedef struct TF_OpKernelContext TF_OpKernelContext; // Allocates a new kernel builder and returns a pointer to it. // @@ -85,6 +85,67 @@ TF_CAPI_EXPORT extern void TF_RegisterKernelBuilder(const char* kernel_name, // builder is not registered with TensorFlow via TF_RegisterKernelBuilder. TF_CAPI_EXPORT extern void TF_DeleteKernelBuilder(TF_KernelBuilder* builder); +// -------------------------------------------------------------------------- +// OpKernelContext routines + +// TF_NumInputs returns the number of inputs available in ctx. +TF_CAPI_EXPORT extern int TF_NumInputs(TF_OpKernelContext* ctx); + +// TF_NumOutputs returns the number of outputs to be placed in *ctx by the +// kernel. +TF_CAPI_EXPORT extern int TF_NumOutputs(TF_OpKernelContext* ctx); + +// Retrieves the ith input from ctx. If TF_GetCode(status) is TF_OK, *tensor is +// populated and its ownership is passed to the caller. In any other case, +// *tensor is not modified. +// +// If i < 0 or i >= TF_NumInputs(ctx), *status is set to TF_OUT_OF_RANGE. +TF_CAPI_EXPORT extern void TF_GetInput(TF_OpKernelContext* ctx, int i, + TF_Tensor** tensor, TF_Status* status); + +// Sets the ith output of ctx to tensor. If TF_GetCode(status) is anything but +// TF_OK, ctx is left unmodified. +// +// If i < 0 or i >= TF_NumOutputs(ctx), *status is set to TF_OUT_OF_RANGE. +TF_CAPI_EXPORT extern void TF_SetOutput(TF_OpKernelContext* ctx, int i, + const TF_Tensor* tensor, + TF_Status* status); + +// Notifies the given OpKernelConstruction that kernel construction has failed. +TF_CAPI_EXPORT extern void TF_OpKernelConstruction_Failure( + TF_OpKernelConstruction* ctx, TF_Status* status); + +// Notifies the given OpKernelContext that the kernel's compute function has +// failed. 
+TF_CAPI_EXPORT extern void TF_OpKernelContext_Failure(TF_OpKernelContext* ctx, + TF_Status* status); + +// Returns the expected output data type of the ith output. If i < 0 or +// i >= TF_NumOutputs(ctx), the program aborts. +TF_CAPI_EXPORT extern TF_DataType TF_ExpectedOutputDataType( + TF_OpKernelContext* ctx, int i); + +// Returns the step ID of the given context. +TF_CAPI_EXPORT extern int64_t TF_StepId(TF_OpKernelContext* ctx); + +// Interprets the named kernel construction attribute as a TF_DataType and +// places it into *val. *status is set to TF_OK. +// +// If the attribute could not be found or could not be interpreted as +// TF_DataType, *status is populated with an error. +TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrType( + TF_OpKernelConstruction* ctx, const char* attr_name, TF_DataType* val, + TF_Status* status); + +// Interprets the named kernel context attribute as a TF_DataType and places it +// into *val. *status is set to TF_OK. +// +// If the attribute could not be found or could not be interpreted as +// TF_DataType, *status is populated with an error. +TF_CAPI_EXPORT extern void TF_OpKernelContext_GetAttrType( + TF_OpKernelContext* ctx, const char* attr_name, TF_DataType* val, + TF_Status* status); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/kernels_test.cc b/tensorflow/c/kernels_test.cc index e706c7c1d96ee1781d8efc0f28c5e0cbcbc80861..0d2954717e7a83c102a35815809a554e3a917e07 100644 --- a/tensorflow/c/kernels_test.cc +++ b/tensorflow/c/kernels_test.cc @@ -15,6 +15,8 @@ limitations under the License. 
#include "tensorflow/c/kernels.h" +#include "tensorflow/c/c_api.h" +#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/kernel_def.pb.h" #include "tensorflow/core/framework/node_def.pb_text.h" #include "tensorflow/core/framework/op.h" @@ -31,7 +33,6 @@ struct MyCustomKernel { static bool delete_called = false; static void* MyCreateFunc(TF_OpKernelConstruction* ctx) { - LOG(INFO) << "Wow, actually got into creation"; struct MyCustomKernel* s = new struct MyCustomKernel; s->created = true; s->compute_called = false; @@ -41,6 +42,19 @@ static void* MyCreateFunc(TF_OpKernelConstruction* ctx) { static void MyComputeFunc(void* kernel, TF_OpKernelContext* ctx) { struct MyCustomKernel* s = static_cast(kernel); s->compute_called = true; + if (ctx != nullptr) { + TF_Status* status = TF_NewStatus(); + + EXPECT_EQ(43, TF_StepId(ctx)); + + // Exercise attribute reads. + TF_DataType type; + TF_OpKernelContext_GetAttrType(ctx, "SomeDataTypeAttr", &type, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)); + EXPECT_EQ(TF_FLOAT, type); + + TF_DeleteStatus(status); + } } static void MyDeleteFunc(void* kernel) { @@ -51,12 +65,37 @@ static void MyDeleteFunc(void* kernel) { delete s; } +namespace tensorflow { + +static std::unique_ptr GetFakeKernel(const char* device_name, + const char* op_name, + Status* status) { + NodeDef def; + def.set_op(op_name); + def.set_device(device_name); + def.add_input("input1"); + def.add_input("input2"); + + AttrValue v; + v.set_type(DataType::DT_FLOAT); + (*def.mutable_attr())["SomeDataTypeAttr"] = v; + + return CreateOpKernel(DeviceType(device_name), nullptr, nullptr, def, 1, + status); +} + // Tests registration of a single C kernel and checks that calls through the // C/C++ boundary are being made. 
TEST(TestKernel, TestRegisterKernelBuilder) { const char* kernel_name = "SomeKernelName"; const char* op_name = "FooOp"; - const char* device_name = "barDev"; + const char* device_name = "FakeDeviceName1"; + + REGISTER_OP(op_name) + .Input("input1: double") + .Input("input2: uint8") + .Output("output1: uint8") + .Attr("SomeDataTypeAttr: type"); TF_KernelBuilder* builder = TF_NewKernelBuilder( op_name, device_name, &MyCreateFunc, &MyComputeFunc, &MyDeleteFunc); @@ -65,35 +104,128 @@ TEST(TestKernel, TestRegisterKernelBuilder) { TF_Status* status = TF_NewStatus(); TF_RegisterKernelBuilder(kernel_name, builder, status); EXPECT_EQ(TF_OK, TF_GetCode(status)); - TF_Buffer* buf = TF_GetRegisteredKernelsForOp("FooOp", status); + TF_Buffer* buf = TF_GetRegisteredKernelsForOp(op_name, status); EXPECT_EQ(TF_OK, TF_GetCode(status)); - ::tensorflow::KernelList list; + KernelList list; list.ParseFromArray(buf->data, buf->length); ASSERT_EQ(1, list.kernel_size()); - ASSERT_EQ("barDev", list.kernel(0).device_type()); + ASSERT_EQ(device_name, list.kernel(0).device_type()); TF_DeleteBuffer(buf); TF_DeleteStatus(status); } - REGISTER_OP("FooOp") + { + Status status; + std::unique_ptr kernel = + GetFakeKernel(device_name, op_name, &status); + TF_EXPECT_OK(status); + ASSERT_NE(nullptr, kernel.get()); + kernel->Compute(nullptr); + } + + ASSERT_TRUE(delete_called); +} + +class DummyDevice : public DeviceBase { + public: + DummyDevice(Env* env, bool save) : DeviceBase(env), save_(save) {} + bool RequiresRecordingAccessedTensors() const override { return save_; } + Allocator* GetAllocator(AllocatorAttributes /*attr*/) override { + return cpu_allocator(); + } + + private: + bool save_; +}; + +TEST(TestKernel, TestInputAndOutputCount) { + const char* kernel_name = "InputOutputCounterKernel"; + const char* op_name = "BarOp"; + const char* device_name = "FakeDeviceName2"; + + REGISTER_OP(op_name) .Input("input1: double") .Input("input2: uint8") - .Output("output1: uint8"); + .Output("output1: 
uint8") + .Attr("SomeDataTypeAttr: type"); + + static int num_inputs = 0; + static int num_outputs = 0; + + // A kernel whose Compute function has a side-effect of updating num_inputs + // and num_outputs. Various functions on TF_OpKernelContext are also + // exercised. + auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) { + num_inputs = TF_NumInputs(ctx); + num_outputs = TF_NumOutputs(ctx); + + TF_Tensor* input = nullptr; + TF_Status* s = TF_NewStatus(); + TF_GetInput(ctx, 0, &input, s); + EXPECT_EQ(TF_OK, TF_GetCode(s)) << "Failed to get input: " << TF_Message(s); + EXPECT_EQ(123, *static_cast(TF_TensorData(input))); + TF_GetInput(ctx, -1, &input, s); + EXPECT_EQ(TF_OUT_OF_RANGE, TF_GetCode(s)); + TF_GetInput(ctx, 3, &input, s); + EXPECT_EQ(TF_OUT_OF_RANGE, TF_GetCode(s)); + + // Copy the input tensor to output. + TF_SetOutput(ctx, 0, input, s); + EXPECT_EQ(TF_OK, TF_GetCode(s)); + + TF_SetOutput(ctx, 24, input, s); + EXPECT_EQ(TF_OUT_OF_RANGE, TF_GetCode(s)); + + EXPECT_EQ(TF_UINT8, TF_ExpectedOutputDataType(ctx, 0)); + + TF_DeleteStatus(s); + if (input != nullptr) { + TF_DeleteTensor(input); + } + }; + + TF_KernelBuilder* builder = TF_NewKernelBuilder(op_name, device_name, nullptr, + my_compute_func, nullptr); { - ::tensorflow::NodeDef def; - def.set_op("FooOp"); - def.set_device("bar"); - def.add_input("input1"); - def.add_input("input2"); - ::tensorflow::Status status; - std::unique_ptr<::tensorflow::OpKernel> kernel = - ::tensorflow::CreateOpKernel(::tensorflow::DeviceType("barDev"), - nullptr, nullptr, def, 1, &status); + TF_Status* status = TF_NewStatus(); + TF_RegisterKernelBuilder(kernel_name, builder, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)); + TF_DeleteStatus(status); + } + + { + OpKernelContext::Params p; + DummyDevice dummy_device(nullptr, false); + p.device = &dummy_device; + p.step_id = 43; + + Tensor t(tensorflow::uint8(123)); + + gtl::InlinedVector inputs; + // Simulate 2 inputs + inputs.emplace_back(&t); + 
inputs.emplace_back(); + p.inputs = &inputs; + + Status status; + std::unique_ptr kernel = + GetFakeKernel(device_name, op_name, &status); TF_EXPECT_OK(status); ASSERT_NE(nullptr, kernel.get()); - kernel->Compute(nullptr); + + p.op_kernel = kernel.get(); + OpKernelContext ctx(&p); + kernel->Compute(&ctx); + + ASSERT_EQ(2, num_inputs); + ASSERT_EQ(1, num_outputs); + ASSERT_EQ(123, ctx.mutable_output(0)->scalar()()); } +} - ASSERT_TRUE(delete_called); +TEST(TestKernel, DeleteKernelBuilderIsOkOnNull) { + TF_DeleteKernelBuilder(nullptr); } + +} // namespace tensorflow diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc index 247236b760dd8c07bbb08426100b6a4d34296d2e..98d8393332269ae349cf8aa5c0b612c6f17172e6 100644 --- a/tensorflow/c/python_api.cc +++ b/tensorflow/c/python_api.cc @@ -160,4 +160,17 @@ void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto, ic->set_output_handle_shapes_and_types(output.index, shapes_and_types); } +void AddWhileInputHack(TF_Graph* graph, TF_Output new_src, TF_Operation* dst, + TF_Status* status) { + mutex_lock l(graph->mu); + status->status = graph->graph.AddWhileInputHack(&new_src.oper->node, + new_src.index, &dst->node); + if (status->status.ok()) { + // This modification only updates the destination node for + // the purposes of running this graph in a session. Thus, we don't + // record the source node as being modified. + RecordMutation(graph, *dst, "adding input tensor"); + } +} + } // namespace tensorflow diff --git a/tensorflow/c/python_api.h b/tensorflow/c/python_api.h index 5cce84020bc68d912d259f51512341eb5f464a2c..44779ca656165dd65590cb5e9ea3ccf71165ed63 100644 --- a/tensorflow/c/python_api.h +++ b/tensorflow/c/python_api.h @@ -34,6 +34,7 @@ void SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name, void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device); +// Updates 'dst' to consume 'new_src'. 
void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst, TF_Status* status); @@ -65,6 +66,13 @@ std::string GetHandleShapeAndType(TF_Graph* graph, TF_Output output); // because I couldn't get SWIG to work otherwise. void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto, size_t proto_len, TF_Status* status); + +// This method is used to add a new input edge to 'dst', which must be a While +// op. The While op's "T" attribute must have already been updated to include +// the new edge. This is used to construct tf.while_loop gradients. +void AddWhileInputHack(TF_Graph* graph, TF_Output new_src, TF_Operation* dst, + TF_Status* status); + } // namespace tensorflow #endif // TENSORFLOW_C_PYTHON_API_H_ diff --git a/tensorflow/cc/gradients/image_grad.cc b/tensorflow/cc/gradients/image_grad.cc index 882709e1e2817431a32c453fe0f35f2b2e6c69b0..05c287bdc62cdb8be7208ce3975f280aaa816766 100644 --- a/tensorflow/cc/gradients/image_grad.cc +++ b/tensorflow/cc/gradients/image_grad.cc @@ -69,6 +69,23 @@ Status ResizeBicubicGradHelper(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("ResizeBicubic", ResizeBicubicGradHelper); +Status ScaleAndTranslateGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + string kernel_type; + TF_RETURN_IF_ERROR( + GetNodeAttr(op.node()->attrs(), "kernel_type", &kernel_type)); + grad_outputs->push_back(internal::ScaleAndTranslateGrad( + scope, grad_inputs[0], op.input(0), op.input(2), op.input(3), + internal::ScaleAndTranslateGrad::KernelType(kernel_type))); + + grad_outputs->push_back(NoGradient()); + grad_outputs->push_back(NoGradient()); + grad_outputs->push_back(NoGradient()); + return scope.status(); +} +REGISTER_GRADIENT_OP("ScaleAndTranslate", ScaleAndTranslateGradHelper); + } // anonymous namespace } // namespace ops } // namespace tensorflow diff --git a/tensorflow/cc/gradients/image_grad_test.cc 
b/tensorflow/cc/gradients/image_grad_test.cc index 2e55c7561b030c50bd67bd53fd0d55710085c5d2..1d150226538093467e092e02f38090a327f9c9b6 100644 --- a/tensorflow/cc/gradients/image_grad_test.cc +++ b/tensorflow/cc/gradients/image_grad_test.cc @@ -30,6 +30,7 @@ using ops::Const; using ops::ResizeBicubic; using ops::ResizeBilinear; using ops::ResizeNearestNeighbor; +using ops::ScaleAndTranslate; class ImageGradTest : public ::testing::Test { protected: @@ -153,5 +154,45 @@ TEST_F(ImageGradTest, TestBicubic) { TestResize(RESIZE_BICUBIC); } +class ScaleAndTranslateGradTest : public ::testing::Test { + protected: + ScaleAndTranslateGradTest() : scope_(Scope::NewRootScope()) {} + + template + Tensor MakeData(const TensorShape& data_shape) { + DataType data_type = DataTypeToEnum::v(); + Tensor data(data_type, data_shape); + auto data_flat = data.flat(); + for (int i = 0; i < data_flat.size(); ++i) { + data_flat(i) = T(i); + } + return data; + } + + template + void MakeOp(const Tensor& x_data, const Input& y_shape, Output* x, + Output* y) { + *x = Const(scope_, x_data); + *y = ScaleAndTranslate(scope_, *x, y_shape, {1.8f, 2.1f}, {0.5f, 0.7f}); + TF_ASSERT_OK(scope_.status()); + } + + template + void TestResize() { + TensorShape x_shape({1, 2, 3, 1}); + Tensor x_data = MakeData(x_shape); + Output x, y; + MakeOp(x_data, {4, 6}, &x, &y); + JAC_T max_error; + TF_ASSERT_OK((ComputeGradientError( + scope_, x, x_data, y, {1, 4, 6, 1}, &max_error))); + EXPECT_LT(max_error, 1e-3); + } + + Scope scope_; +}; + +TEST_F(ScaleAndTranslateGradTest, Works) { TestResize(); } + } // namespace } // namespace tensorflow diff --git a/tensorflow/cc/profiler/BUILD b/tensorflow/cc/profiler/BUILD index cf65fe1ab99b49207a64e86310178141b30d07d7..e9838d9aba6554b40082187057851e9c896f8352 100644 --- a/tensorflow/cc/profiler/BUILD +++ b/tensorflow/cc/profiler/BUILD @@ -10,7 +10,7 @@ tf_cuda_cc_test( name = "profiler_test", srcs = ["profiler_test.cc"], tags = [ - "noguitar", # b/77649654 + "nogpu", # 
b/77649654 ], deps = [ ":profiler", diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc index 85d3dd01fa51b3c3ba6fcbf5faac03f1ff5630e2..10f7abf09e925c0c31cfd595ecee4605f189476f 100644 --- a/tensorflow/cc/saved_model/loader.cc +++ b/tensorflow/cc/saved_model/loader.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/cc/saved_model/reader.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/monitoring/counter.h" +#include "tensorflow/core/lib/monitoring/sampler.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/env.h" @@ -42,9 +43,28 @@ auto* load_latency = monitoring::Counter<1>::New( "/tensorflow/cc/saved_model/load_latency", "Latency in microseconds for SavedModels that were successfully loaded.", "model_path"); +auto* load_latency_by_stage = monitoring::Sampler<2>::New( + { + "/tensorflow/cc/saved_model/load_latency_by_stage", // metric name + "Distribution of wall time spent (in microseconds) in each stage " + "(restore graph from disk, run init graph op, etc) when loading the " + "model", + "model_path", + "stage", + }, + // Scale of 10, power of 1.8 with bucket count 33 (~20 minutes). + monitoring::Buckets::Exponential(10, 1.8, 33)); + constexpr char kLoadAttemptFail[] = "fail"; constexpr char kLoadAttemptSuccess[] = "success"; +uint64 GetLatencyMicroseconds(const uint64 start_microseconds) { + const uint64 end_microseconds = Env::Default()->NowMicros(); + // Avoid clock skew. 
+ if (end_microseconds < start_microseconds) return 0; + return end_microseconds - start_microseconds; +} + Status LoadMetaGraphIntoSession(const MetaGraphDef& meta_graph_def, const SessionOptions& session_options, std::unique_ptr* session) { @@ -242,6 +262,7 @@ Status LoadSavedModelInternal(const SessionOptions& session_options, const string& export_dir, const std::unordered_set& tags, SavedModelBundle* const bundle) { + const uint64 read_start_microseconds = Env::Default()->NowMicros(); TF_RETURN_IF_ERROR(ReadMetaGraphDefFromSavedModel(export_dir, tags, &bundle->meta_graph_def)); @@ -256,12 +277,23 @@ Status LoadSavedModelInternal(const SessionOptions& session_options, bundle->meta_graph_def.saver_def().restore_op_name(), bundle->meta_graph_def.saver_def().filename_tensor_name(), asset_file_defs, bundle->session.get())); + // Record walltime spent in restoring graph from disk, but postpone metric + // increments until graph init finishes. + const uint64 restore_graph_walltime = + GetLatencyMicroseconds(read_start_microseconds); + + const uint64 graph_init_start_microseconds = Env::Default()->NowMicros(); string init_op_name; TF_RETURN_IF_ERROR( GetInitOp(export_dir, bundle->meta_graph_def, &init_op_name)); TF_RETURN_IF_ERROR(RunInitOp(run_options, export_dir, bundle->meta_graph_def, asset_file_defs, bundle->session.get(), init_op_name)); + load_latency_by_stage->GetCell(export_dir, "restore_graph") + ->Add(restore_graph_walltime); + // Record wall time spent in init op. 
+ load_latency_by_stage->GetCell(export_dir, "init_graph") + ->Add(GetLatencyMicroseconds(graph_init_start_microseconds)); return Status::OK(); } @@ -275,16 +307,10 @@ Status LoadSavedModel(const SessionOptions& session_options, const uint64 start_microseconds = Env::Default()->NowMicros(); const Status status = LoadSavedModelInternal(session_options, run_options, export_dir, tags, bundle); - const uint64 load_latency_microsecs = [&]() -> uint64 { - const uint64 end_microseconds = Env::Default()->NowMicros(); - // Avoid clock skew. - if (end_microseconds < start_microseconds) return 0; - return end_microseconds - start_microseconds; - }(); auto log_and_count = [&](const string& status_str) { LOG(INFO) << "SavedModel load for tags { " << str_util::Join(tags, " ") << " }; Status: " << status_str << ". Took " - << load_latency_microsecs << " microseconds."; + << GetLatencyMicroseconds(start_microseconds) << " microseconds."; load_attempt_count->GetCell(export_dir, status_str)->IncrementBy(1); }; if (status.ok()) { @@ -292,7 +318,8 @@ Status LoadSavedModel(const SessionOptions& session_options, } else { log_and_count(kLoadAttemptFail); } - load_latency->GetCell(export_dir)->IncrementBy(load_latency_microsecs); + load_latency->GetCell(export_dir) + ->IncrementBy(GetLatencyMicroseconds(start_microseconds)); return status; } diff --git a/tensorflow/cc/tools/freeze_saved_model.cc b/tensorflow/cc/tools/freeze_saved_model.cc index 23e9dc40d23899b9cef168c9128b6d8ed1be3ee9..eeb910178902ca883ed211379ba3f188c139f92e 100644 --- a/tensorflow/cc/tools/freeze_saved_model.cc +++ b/tensorflow/cc/tools/freeze_saved_model.cc @@ -124,7 +124,9 @@ Status GetVariableNameToTensorMap( return Status::OK(); } std::vector variable_names; + variable_names.reserve(variable_names_set.size()); std::vector tensor_names; + tensor_names.reserve(variable_names_set.size()); for (const string& node_name : variable_names_set) { variable_names.push_back(node_name); NodeDef* node_def = 
name_to_node_map.at(node_name); diff --git a/tensorflow/compat_template.__init__.py b/tensorflow/compat_template.__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1833b6a65eef9baa2e92a13d9c4d44b79620de2f --- /dev/null +++ b/tensorflow/compat_template.__init__.py @@ -0,0 +1,43 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Bring in all of the public TensorFlow interface into this module.""" + +from __future__ import absolute_import as _absolute_import +from __future__ import division as _division +from __future__ import print_function as _print_function + +import os as _os + +# pylint: disable=g-bad-import-order +from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import + +# API IMPORTS PLACEHOLDER + +from tensorflow.python.tools import component_api_helper as _component_api_helper +_component_api_helper.package_hook( + parent_package_str=__name__, + child_package_str=( + 'tensorflow_estimator.python.estimator.api._v2.estimator')) +_component_api_helper.package_hook( + parent_package_str=__name__, + child_package_str=('tensorflow.python.keras.api._v2.keras')) + +# We would like the following to work for fully enabling 2.0 in a 1.0 install: +# +# import tensorflow.compat.v2 as tf +# tf.enable_v2_behavior() +# +# This make this one symbol available directly. 
+from tensorflow.python.compat.v2_compat import enable_v2_behavior # pylint: disable=g-import-not-at-top diff --git a/tensorflow/compat_template_v1.__init__.py b/tensorflow/compat_template_v1.__init__.py index 7df80ec01245a7fe820c79d5879458c4cd0a93cb..9549a71c41a0ba2aac58abd8cfb182aa4eaf3b4f 100644 --- a/tensorflow/compat_template_v1.__init__.py +++ b/tensorflow/compat_template_v1.__init__.py @@ -23,12 +23,15 @@ import os as _os # pylint: disable=g-bad-import-order from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import +# API IMPORTS PLACEHOLDER + from tensorflow.python.tools import component_api_helper as _component_api_helper _component_api_helper.package_hook( parent_package_str=__name__, - child_package_str=('tensorflow_estimator.python.estimator.api.estimator')) - -# API IMPORTS PLACEHOLDER - + child_package_str=( + 'tensorflow_estimator.python.estimator.api._v1.estimator')) +_component_api_helper.package_hook( + parent_package_str=__name__, + child_package_str=('tensorflow.python.keras.api._v1.keras')) from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top app.flags = flags # pylint: disable=undefined-variable diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc index e0ac7130a64d3928c39440c0e10a2d2e1990b9cd..d016632da2a9d7c2c2f81c02dd573787a0502923 100644 --- a/tensorflow/compiler/aot/codegen.cc +++ b/tensorflow/compiler/aot/codegen.cc @@ -129,7 +129,7 @@ Status AddRewritesForShape(int i, const xla::Shape& shape, TF_RETURN_IF_ERROR(XLATypeToCpp(shape.element_type(), &type)); std::vector dim_vars; string dim_sizes, indices; - if (xla::ShapeUtil::Rank(shape) == 0 || + if (shape.rank() == 0 || (shape.dimensions_size() == 1 && shape.dimensions(0) == 1)) { dim_sizes = "[1]"; indices = "[0]"; @@ -178,7 +178,7 @@ Status GenArgMethods(const tf2xla::Config& config, TF_RETURN_IF_ERROR( AddRewritesForShape(i, xla::Shape(ps.parameters(i)), &rewrites)); const string code = R"( - void 
set_arg{{NAME}}_data(void* data) { + void set_arg{{NAME}}_data(const void* data) { set_arg_data({{I}}, data); } {{TYPE}}* arg{{NAME}}_data() { @@ -384,8 +384,9 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config, // calling HloProfilePrinter::profile_counters_size. const string assign_profile_counters_size = opts.gen_hlo_profile_printer_data - ? "data->set_profile_counters_size(" - "data->hlo_profile_printer_data()->profile_counters_size());" + ? "set_static_data_profile_counters_size(data, " + "get_static_data_hlo_profile_printer_data(data)->" + "profile_counters_size());" : ""; // Use a poor-man's text templating mechanism; first populate the full header @@ -449,7 +450,7 @@ extern "C" void {{ENTRY}}( // arg bytes aligned: {{ARG_BYTES_ALIGNED}} // temp bytes total: {{TEMP_BYTES_TOTAL}} // temp bytes aligned: {{TEMP_BYTES_ALIGNED}} -class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { +class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction { public: // Number of input arguments for the compiled computation. 
static constexpr size_t kNumArgs = {{ARG_NUM}}; @@ -464,16 +465,17 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { static XlaCompiledCpuFunction::StaticData* kStaticData = [](){ XlaCompiledCpuFunction::StaticData* data = new XlaCompiledCpuFunction::StaticData; - data->set_raw_function({{ENTRY}}); - data->set_buffer_infos(BufferInfos()); - data->set_num_buffers(kNumBuffers); - data->set_arg_index_table(ArgIndexToBufferIndex()); - data->set_num_args(kNumArgs); - data->set_result_index(kResultIndex); - data->set_arg_names(StaticArgNames()); - data->set_result_names(StaticResultNames()); - data->set_program_shape(StaticProgramShape()); - data->set_hlo_profile_printer_data(StaticHloProfilePrinterData()); + set_static_data_raw_function(data, {{ENTRY}}); + set_static_data_buffer_infos(data, BufferInfos()); + set_static_data_num_buffers(data, kNumBuffers); + set_static_data_arg_index_table(data, ArgIndexToBufferIndex()); + set_static_data_num_args(data, kNumArgs); + set_static_data_result_index(data, kResultIndex); + set_static_data_arg_names(data, StaticArgNames()); + set_static_data_result_names(data, StaticResultNames()); + set_static_data_program_shape(data, StaticProgramShape()); + set_static_data_hlo_profile_printer_data( + data, StaticHloProfilePrinterData()); {{ASSIGN_PROFILE_COUNTERS_SIZE}} return data; }(); diff --git a/tensorflow/compiler/aot/codegen_test_h.golden b/tensorflow/compiler/aot/codegen_test_h.golden index a2cdab5d1a8e72504ca11b789287d4efd07a59e9..35994fc785d3e1d5e883c49bec96de315e189d2e 100644 --- a/tensorflow/compiler/aot/codegen_test_h.golden +++ b/tensorflow/compiler/aot/codegen_test_h.golden @@ -59,7 +59,7 @@ namespace bar { // arg bytes aligned: 192 // temp bytes total: 126 // temp bytes aligned: 320 -class MyClass : public tensorflow::XlaCompiledCpuFunction { +class MyClass final : public tensorflow::XlaCompiledCpuFunction { public: // Number of input arguments for the compiled computation. 
static constexpr size_t kNumArgs = 2; @@ -74,16 +74,17 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction { static XlaCompiledCpuFunction::StaticData* kStaticData = [](){ XlaCompiledCpuFunction::StaticData* data = new XlaCompiledCpuFunction::StaticData; - data->set_raw_function(entry_point); - data->set_buffer_infos(BufferInfos()); - data->set_num_buffers(kNumBuffers); - data->set_arg_index_table(ArgIndexToBufferIndex()); - data->set_num_args(kNumArgs); - data->set_result_index(kResultIndex); - data->set_arg_names(StaticArgNames()); - data->set_result_names(StaticResultNames()); - data->set_program_shape(StaticProgramShape()); - data->set_hlo_profile_printer_data(StaticHloProfilePrinterData()); + set_static_data_raw_function(data, entry_point); + set_static_data_buffer_infos(data, BufferInfos()); + set_static_data_num_buffers(data, kNumBuffers); + set_static_data_arg_index_table(data, ArgIndexToBufferIndex()); + set_static_data_num_args(data, kNumArgs); + set_static_data_result_index(data, kResultIndex); + set_static_data_arg_names(data, StaticArgNames()); + set_static_data_result_names(data, StaticResultNames()); + set_static_data_program_shape(data, StaticProgramShape()); + set_static_data_hlo_profile_printer_data( + data, StaticHloProfilePrinterData()); return data; }(); @@ -114,7 +115,7 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction { // with dim indices specifying which value. No bounds checking is performed // on dim indices. 
- void set_arg0_data(void* data) { + void set_arg0_data(const void* data) { set_arg_data(0, data); } float* arg0_data() { @@ -132,7 +133,7 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction { arg_data(0)))[dim0][dim1]; } - void set_arg_myfeed_data(void* data) { + void set_arg_myfeed_data(const void* data) { set_arg_data(0, data); } float* arg_myfeed_data() { @@ -150,7 +151,7 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction { arg_data(0)))[dim0][dim1]; } - void set_arg1_data(void* data) { + void set_arg1_data(const void* data) { set_arg_data(1, data); } tensorflow::int64* arg1_data() { @@ -256,7 +257,7 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction { static const xla::ProgramShapeProto* StaticProgramShape() { static const xla::ProgramShapeProto* kShape = []() { xla::ProgramShapeProto* proto = new xla::ProgramShapeProto; - proto->ParseFromArray(&__tfcompile_foo_bar_MyClass_ProgramShapeProto_protobuf_array_contents[0], 52); + proto->ParseFromArray(&__tfcompile_foo_bar_MyClass_ProgramShapeProto_protobuf_array_contents[0], 64); return proto; }(); return kShape; diff --git a/tensorflow/compiler/aot/codegen_test_o.golden b/tensorflow/compiler/aot/codegen_test_o.golden index ce8e5ec8c96a2c3696f14b8eea206d648182ecb5..7f7b96428572705f30144e6c95cd4cf9c44ce2a3 100644 Binary files a/tensorflow/compiler/aot/codegen_test_o.golden and b/tensorflow/compiler/aot/codegen_test_o.golden differ diff --git a/tensorflow/compiler/aot/tests/make_test_graphs.py b/tensorflow/compiler/aot/tests/make_test_graphs.py index 64b861a73091642b03573543a5c55618bf33915d..7bac79ec062af7e790134286e34eda4e123e138a 100644 --- a/tensorflow/compiler/aot/tests/make_test_graphs.py +++ b/tensorflow/compiler/aot/tests/make_test_graphs.py @@ -50,7 +50,7 @@ def tfadd_with_ckpt(out_dir): y = variables.VariableV1(constant_op.constant([0]), name='y_saved') math_ops.add(x, y, name='x_y_sum') - init_op = variables.initialize_all_variables() + init_op = 
variables.global_variables_initializer() saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V1) with session.Session() as sess: sess.run(init_op) @@ -65,7 +65,7 @@ def tfadd_with_ckpt_saver(out_dir): y = variables.VariableV1(constant_op.constant([0]), name='y_saved') math_ops.add(x, y, name='x_y_sum') - init_op = variables.initialize_all_variables() + init_op = variables.global_variables_initializer() saver = saver_lib.Saver(name='abcprefix', write_version=saver_pb2.SaverDef.V1) with session.Session() as sess: sess.run(init_op) diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl index 2dc3e8c9113b37bf9d575ad66783f4ab49478af4..2abe3e29b78dbbe719637b13418704acc213d050 100644 --- a/tensorflow/compiler/aot/tfcompile.bzl +++ b/tensorflow/compiler/aot/tfcompile.bzl @@ -207,7 +207,7 @@ def tf_library( # # Note that setting the local=1 attribute on a *test target* causes the # test infrastructure to skip that test. However this is a genrule, not - # a test target, and runs with --genrule_strategy=forced_forge, meaning + # a test target, and runs with --strategy=Genrule=forced_forge, meaning # the local=1 attribute is ignored, and the genrule is still run. # # https://www.bazel.io/versions/master/docs/be/general.html#genrule @@ -283,7 +283,7 @@ def tf_library( ) # Variables used for gen_test and gen_benchmark. 
- cpp_class_split = cpp_class.rsplit("::", maxsplit = 2) + cpp_class_split = cpp_class.rsplit("::", 2) if len(cpp_class_split) == 1: no_ns_name = cpp_class_split[0] else: diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc index d548de8c44285f6d21dd778db464a31e1b19645b..0b6ab7e723d6e3a55da2f1c30b75f44cbdaa75bb 100644 --- a/tensorflow/compiler/aot/tfcompile_main.cc +++ b/tensorflow/compiler/aot/tfcompile_main.cc @@ -136,6 +136,10 @@ int main(int argc, char** argv) { tensorflow::string usage = tensorflow::tfcompile::kUsageHeader; usage += tensorflow::Flags::Usage(argv[0], flag_list); + if (argc > 1 && absl::string_view(argv[1]) == "--help") { + std::cerr << usage << "\n"; + return 0; + } bool parsed_flags_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); QCHECK(parsed_flags_ok) << "\n" << usage; diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index be91ed4f432b1890c22900f293fd4196e5c9d970..3cae081ce7c78226390a82d222d57ac653c14321 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -76,6 +76,7 @@ cc_library( srcs = ["xla_cpu_device.cc"], visibility = [":friends"], deps = [ + ":create_xla_launch_op", # buildcleaner: keep ":flags", ":jit_compilation_passes", ":xla_device", @@ -95,6 +96,7 @@ cc_library( srcs = ["xla_gpu_device.cc"], visibility = [":friends"], deps = [ + ":create_xla_launch_op", # buildcleaner: keep ":jit_compilation_passes", ":xla_device", "//tensorflow/compiler/jit/kernels:xla_ops", @@ -104,6 +106,7 @@ cc_library( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:lib", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], alwayslink = 1, ) @@ -172,12 +175,22 @@ cc_library( "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/service:stream_pool", + "//tensorflow/core:array_ops_op_lib", + "//tensorflow/core:control_flow_ops_op_lib", 
"//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:dataset_ops_op_lib", "//tensorflow/core:framework", + "//tensorflow/core:functional_ops_op_lib", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core:math_ops_op_lib", + "//tensorflow/core:nn_ops_op_lib", + "//tensorflow/core:no_op_op_lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/core:resource_variable_ops_op_lib", + "//tensorflow/core:sendrecv_ops_op_lib", + "//tensorflow/core:state_ops_op_lib", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core/kernels:cast_op", "//tensorflow/core/kernels:constant_op", @@ -198,6 +211,7 @@ cc_library( "//tensorflow/core/kernels/data:prefetch_dataset_op", "@com_google_absl//absl/memory", "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:optional", ], ) @@ -512,6 +526,7 @@ cc_library( "//tensorflow/compiler/jit/ops:xla_ops", "//tensorflow/compiler/tf2xla:dump_graph", "//tensorflow/compiler/tf2xla:resource_operation_table", + "//tensorflow/compiler/tf2xla:side_effect_util", "//tensorflow/compiler/tf2xla:tf2xla_util", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/cc:xla_jit_ops", @@ -610,6 +625,7 @@ tf_cc_test( "//tensorflow/cc:cc_ops", "//tensorflow/cc:cc_ops_internal", "//tensorflow/cc:function_ops", + "//tensorflow/cc:functional_ops", "//tensorflow/cc:ops", "//tensorflow/cc:resource_variable_ops", "//tensorflow/cc:scope", @@ -622,15 +638,16 @@ tf_cc_test( "//tensorflow/compiler/tf2xla/cc:xla_ops", "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops", + "//tensorflow/compiler/xla:test", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/core:session_options", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", - 
"//tensorflow/core/grappler/optimizers/data:graph_utils", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", diff --git a/tensorflow/compiler/jit/build_xla_ops_pass_test.cc b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc index 48a23a4c1711ac88a329723c46559112d5a39dbd..390ffa694b6f127544d92f3024a02d877556aacd 100644 --- a/tensorflow/compiler/jit/build_xla_ops_pass_test.cc +++ b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc @@ -24,7 +24,6 @@ limitations under the License. #include "tensorflow/compiler/jit/node_matchers.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/graph/algorithm.h" -#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/compiler/jit/deadness_analysis.cc b/tensorflow/compiler/jit/deadness_analysis.cc index 0562838f628c66b1eb03af9d2a5139c01dca31c5..0ef0d3db8c16e4b3f78d29aad5a2ae75a81d96f6 100644 --- a/tensorflow/compiler/jit/deadness_analysis.cc +++ b/tensorflow/compiler/jit/deadness_analysis.cc @@ -20,7 +20,10 @@ limitations under the License. #include "absl/strings/str_join.h" #include "tensorflow/compiler/jit/deadness_analysis_internal.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/control_flow.h" #include "tensorflow/core/graph/tensor_id.h" #include "tensorflow/core/lib/hash/hash.h" @@ -222,29 +225,40 @@ class NotPredicate : public Predicate { std::array operands_; }; -// Represents an infinite list of predicates. +// Represents the liveness of an induction variable. For users inside the loop +// this represents the "current" liveness of the induction variable. 
For users +// outside the loop it represents the "last" liveness of the induction variable. // -// An AndRecurrence with start = S and step = X is printed as {S,&,X} and stands -// for the list of predicates: +// More concretely, an and recurrence {S,&,X} represents the liveness of V +// in the following graph: // -// S, S & GenSym(X,1), S & GenSym(X,1) & GenSym(X,2), ... +// V = Merge(S', V_NextIt) +// V = Op(V, X') +// V_NextIt = NextIteration(V) // -// where GenSym(, ) renames every SymbolPredicate in -// by appending to it, in effect creating a "fresh" symbol. -// This means {P,&,Q} is not equal to "P on the first iteration; P&Q on -// subsequent iterations". +// where Predicate(S') = S and Predicate(X') = X. +// +// `X` may contain symbolic predicates and the operations corresponding to these +// symbolic predicates are either in frame `loop` or outside it. The symbols +// that are inside frame `loop` are loop variant (i.e. can have different +// liveness in each loop iteration) and the symbols that are outside frame +// `loop` are loop invariant (i.e. have the same liveness across all +// iterations). 
class AndRecurrencePredicate : public Predicate { public: - explicit AndRecurrencePredicate(Predicate* start, Predicate* step) - : Predicate(HashPredicateSequence(Kind::kAndRecurrence, {start, step})), - operands_({start, step}) {} + explicit AndRecurrencePredicate(Predicate* start, Predicate* step, + std::vector frame) + : Predicate(Hash(start, step, frame)), + operands_({start, step}), + frame_(std::move(frame)) {} Predicate* start() const { return operands_[0]; } Predicate* step() const { return operands_[1]; } + absl::Span frame() const { return frame_; } string ToString() const override { return absl::StrCat("{", start()->ToString(), ",&,", step()->ToString(), - "}"); + "}<", absl::StrJoin(frame(), ";"), ">"); } Kind kind() const override { return Kind::kAndRecurrence; } @@ -255,6 +269,17 @@ class AndRecurrencePredicate : public Predicate { private: std::array operands_; + std::vector frame_; + + static int64 Hash(Predicate* start, Predicate* step, + const std::vector& frame) { + uint64 frame_hash = 0; + for (const string& sub_frame : frame) { + frame_hash = Hash64Combine(Hash64(sub_frame), frame_hash); + } + return Hash64Combine( + HashPredicateSequence(Kind::kAndRecurrence, {start, step}), frame_hash); + } }; // Represents an uninterpreted symbol in a logical predicate. @@ -281,7 +306,7 @@ class SymbolPredicate : public Predicate { // "tensor_id() is live and evaluates to true". 
// // If `must_be_true()` is false then this SymbolPredicate represents the - // proposition "tensor_id() is live (and may evalutate to any value)" + // proposition "tensor_id() is live (and may evaluate to any value)" TensorId tensor_id() const { return tensor_id_; } bool must_be_true() const { return must_be_true_; } @@ -333,34 +358,58 @@ class PredicateFactory { } Predicate* MakeNotPredicate(Predicate* pred) { - SignatureForNot signature = pred; - auto it = interned_not_instances_.find(signature); - if (it == interned_not_instances_.end()) { - std::unique_ptr new_pred = Make(pred); - Predicate* new_pred_ptr = new_pred.get(); - interned_not_instances_.emplace(signature, std::move(new_pred)); - return new_pred_ptr; - } else { - return it->second.get(); + auto it = make_not_predicate_cache_.find(pred); + if (it != make_not_predicate_cache_.end()) { + return it->second; } + + Predicate* result = MakeNotPredicateImpl(pred); + + bool insert_successful = + make_not_predicate_cache_.insert({pred, result}).second; + (void)insert_successful; + DCHECK(insert_successful); + + return result; } - Predicate* MakeAndRecurrencePredicate(Predicate* start, Predicate* step) { - auto it = interned_and_rec_instances_.find({start, step}); + Predicate* MakeAndRecurrencePredicate(Predicate* start, Predicate* step, + std::vector frame) { + SignatureForAndRec signature(start, step, std::move(frame)); + auto it = interned_and_rec_instances_.find(signature); if (it != interned_and_rec_instances_.end()) { return it->second.get(); } - std::unique_ptr new_pred = - Make(start, step); + std::unique_ptr new_pred = Make( + std::get<0>(signature), std::get<1>(signature), std::get<2>(signature)); Predicate* new_pred_ptr = new_pred.get(); - CHECK(interned_and_rec_instances_ - .emplace(SignatureForAndRec(start, step), std::move(new_pred)) - .second); + bool inserted = + interned_and_rec_instances_.emplace(signature, std::move(new_pred)) + .second; + (void)inserted; + DCHECK(inserted); return 
new_pred_ptr; } - Predicate* MakeSymbolPredicate(TensorId tensor_id, bool must_be_true) { + Status MakeSymbolPredicate(Node* node, int output_idx, bool must_be_true, + Predicate** predicate) { + TensorId tensor_id(node->name(), output_idx); + + bool is_boolean_tensor = node->output_type(tensor_id.index()) == DT_BOOL; + TF_RET_CHECK(!must_be_true || is_boolean_tensor); + + if (node->type_string() == "Const" && must_be_true) { + const TensorProto* proto = nullptr; + TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), "value", &proto)); + + Tensor tensor(proto->dtype()); + TF_RET_CHECK(tensor.FromProto(*proto)); + + *predicate = tensor.scalar()() ? MakeTrue() : MakeFalse(); + return Status::OK(); + } + SignatureForSymbol signature = {tensor_id, must_be_true}; auto it = interned_symbol_instances_.find(signature); if (it == interned_symbol_instances_.end()) { @@ -369,16 +418,63 @@ class PredicateFactory { Predicate* new_pred_ptr = new_pred.get(); interned_symbol_instances_.emplace(std::move(signature), std::move(new_pred)); - return new_pred_ptr; + *predicate = new_pred_ptr; } else { - return it->second.get(); + *predicate = it->second.get(); } + + return Status::OK(); } Predicate* MakeTrue() { return MakeAndPredicate({}); } Predicate* MakeFalse() { return MakeOrPredicate({}); } + ~PredicateFactory() { + DCHECK_EQ(stack_depth_, 0) << "Unnested IncrementStackDepth?"; + } + private: + Predicate* MakeNotPredicateImpl(Predicate* pred) { + IncrementStackDepth stack_frame(this); + if (!stack_frame.HasOverflowed()) { + if (Predicate* simplified = SimplifyUsingDeMorgan(pred)) { + return simplified; + } + + // ~~A => A + if (auto* not_pred = dynamic_cast(pred)) { + return not_pred->operand(); + } + } + + SignatureForNot signature = pred; + auto it = interned_not_instances_.find(signature); + if (it == interned_not_instances_.end()) { + std::unique_ptr new_pred = Make(pred); + Predicate* new_pred_ptr = new_pred.get(); + interned_not_instances_.emplace(signature, std::move(new_pred)); + 
return new_pred_ptr; + } else { + return it->second.get(); + } + } + + Predicate* SimplifyUsingDeMorgan(Predicate* pred) { + // ~(A & B & C & ...) => ~A | ~B | ~C | ~... + // ~(A | B | C | ...) -> ~A & ~B & ~C & ~... + Predicate::Kind kind = pred->kind(); + + if (kind == Predicate::Kind::kAnd || kind == Predicate::Kind::kOr) { + std::vector new_operands; + absl::c_transform(pred->GetOperands(), std::back_inserter(new_operands), + [&](Predicate* p) { return MakeNotPredicate(p); }); + return kind == Predicate::Kind::kOr ? MakeAndPredicate(new_operands) + : MakeOrPredicate(new_operands); + } + + return nullptr; + } + template std::unique_ptr Make(Args&&... args) { return std::unique_ptr( @@ -402,7 +498,8 @@ class PredicateFactory { using SignatureForAndOr = std::pair>; using SignatureForNot = Predicate*; - using SignatureForAndRec = std::pair; + using SignatureForAndRec = + std::tuple>; using SignatureForSymbol = std::pair; struct HashSignatureForAndOr { @@ -422,6 +519,36 @@ class PredicateFactory { } }; + // Used to limit recursion to avoid blowing up the stack and cap compile time. + class IncrementStackDepth { + public: + explicit IncrementStackDepth(PredicateFactory* parent) : parent_(parent) { + parent_->stack_depth_++; + } + + bool HasOverflowed() const { + const int kMaxStackDepth = 8; + return parent_->stack_depth_ >= kMaxStackDepth; + } + + ~IncrementStackDepth() { parent_->stack_depth_--; } + + private: + PredicateFactory* parent_; + }; + + // A cache for the MakeNotPredicate function. + // + // NB! This is *not* the same as `interned_not_instances_`. + // `interned_not_instances_` ensures pointer identity for `NotPredicate` + // instances, i.e., it ensures there is at most one instance of Not(predicate) + // for any given predicate whereas `make_not_predicate_cache_` simply caches + // the result of the `MakeNotPredicate` function.
The values in + `interned_not_instances_` are always instances of `NotPredicate` whereas the + values in `make_not_predicate_cache_` may not be (for instance it will map + Not(Not(A)) to A). + absl::flat_hash_map make_not_predicate_cache_; + absl::flat_hash_map, HashSignatureForAndOr> interned_and_or_instances_; @@ -432,6 +559,7 @@ class PredicateFactory { absl::flat_hash_map, HashSignatureForSymbol> interned_symbol_instances_; + int stack_depth_ = 0; }; Predicate* PredicateFactory::MakeInternedAndOr( @@ -466,6 +594,13 @@ Predicate* PredicateFactory::MakeAndOrImpl( absl::Span operands, bool is_and) { Predicate::Kind pred_kind = is_and ? Predicate::Kind::kAnd : Predicate::Kind::kOr; + + IncrementStackDepth stack_frame(this); + if (stack_frame.HasOverflowed()) { + return MakeInternedAndOr( + std::vector(operands.begin(), operands.end()), pred_kind); + } + Predicate::Kind other_pred_kind = is_and ? Predicate::Kind::kOr : Predicate::Kind::kAnd; absl::flat_hash_set simplified_ops_set; @@ -494,16 +629,31 @@ Predicate* PredicateFactory::MakeAndOrImpl( // Simplify "A&~A=>False" and "A|~A=>True". absl::flat_hash_set negated_ops; - for (Predicate* op : simplified_ops) { - if (op->kind() == Predicate::Kind::kNot) { - negated_ops.insert(dynamic_cast(*op).operand()); - } - } - for (Predicate* op : simplified_ops) { if (negated_ops.count(op)) { + // Simple case: + // + // A & ~A & ... == False + // A | ~A | ... == True return is_and ? MakeFalse() : MakeTrue(); } + + Predicate* negated_op = MakeNotPredicate(op); + if (negated_op->kind() == pred_kind) { + // Slightly more complicated case: + // + // (~A | ~B | ~C) & A & B & C & ... == + // ~(A & B & C) & (A & B & C) & ... == False + // + // (~A & ~B & ~C) | A | B | C | ... == + // ~(A | B | C) | (A | B | C) | ... == True + if (absl::c_all_of(negated_op->GetOperands(), [&](Predicate* p) { + return simplified_ops_set.contains(p); + })) { + return is_and ?
MakeFalse() : MakeTrue(); + } + } + negated_ops.insert(negated_op); } // If all ops contain the same subop, then factor it out thanks to the @@ -619,6 +769,7 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis { const Graph& graph_; absl::flat_hash_map predicate_map_; PredicateFactory predicate_factory_; + std::vector control_flow_info_; bool vlog_; }; @@ -661,9 +812,12 @@ Status DeadnessAnalysisImpl::HandleSwitch(Node* n, TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataAndControl, &input_preds)); const Edge* pred_edge; TF_RETURN_IF_ERROR(n->input_edge(1, &pred_edge)); - Predicate* true_switch = predicate_factory_.MakeSymbolPredicate( - TensorId(pred_edge->src()->name(), pred_edge->src_output()), - /*must_be_true=*/true); + + Predicate* true_switch; + TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate( + pred_edge->src(), pred_edge->src_output(), + /*must_be_true=*/true, &true_switch)); + Predicate* false_switch = predicate_factory_.MakeNotPredicate(true_switch); // Output 0 is alive iff all inputs are alive and the condition is false. @@ -761,6 +915,23 @@ Predicate* DeduceStepPredicate(PredicateFactory* predicate_factory, return found_sym ? predicate_factory->MakeAndPredicate(and_ops) : nullptr; } + +Status GetFullFrame(const Node* n, absl::Span cfi_infos, + std::vector* frame) { + int depth = 0; + for (const ControlFlowInfo* cfi_iter = &cfi_infos[n->id()]; !n->IsSource(); + n = cfi_iter->parent_frame, cfi_iter = &cfi_infos[n->id()]) { + frame->push_back(cfi_iter->frame_name); + + if (depth++ > 5000) { + return errors::Internal( + "Frame of depth > 5000: Probably malformed graph or a bug in " + "BuildControlFlowInfo"); + } + } + + return Status::OK(); +} } // namespace Status DeadnessAnalysisImpl::HandleMerge(Node* n, @@ -783,8 +954,10 @@ Status DeadnessAnalysisImpl::HandleMerge(Node* n, if (has_unvisited_backedge) { // We're visiting this merge for the first time and it has an unvisited // backedge. 
- Predicate* input_data_pred = predicate_factory_.MakeSymbolPredicate( - TensorId(n->name(), 0), /*must_be_true=*/false); + Predicate* input_data_pred; + TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate( + n, /*output_idx=*/0, /*must_be_true=*/false, &input_data_pred)); + SetPredicate(n, {0, 1, Graph::kControlSlot}, input_data_pred, should_revisit); return Status::OK(); @@ -825,8 +998,10 @@ Status DeadnessAnalysisImpl::HandleMerge(Node* n, Predicate* start = predicate_factory_.MakeOrPredicate(non_recurrent_inputs); - Predicate* and_rec = - predicate_factory_.MakeAndRecurrencePredicate(start, step); + std::vector frame; + TF_RETURN_IF_ERROR(GetFullFrame(n, control_flow_info_, &frame)); + Predicate* and_rec = predicate_factory_.MakeAndRecurrencePredicate( + start, step, std::move(frame)); SetPredicate(n, {0, 1, Graph::kControlSlot}, and_rec, should_revisit); return Status::OK(); } @@ -841,8 +1016,10 @@ Status DeadnessAnalysisImpl::HandleRecv(Node* n, // acquire a dead signal from a _Send. std::vector input_preds; TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataAndControl, &input_preds)); - input_preds.push_back(predicate_factory_.MakeSymbolPredicate( - TensorId(n->name(), 0), /*must_be_true=*/false)); + Predicate* signal_is_alive; + TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate( + n, /*output_idx=*/0, /*must_be_true=*/false, &signal_is_alive)); + input_preds.push_back(signal_is_alive); SetPredicate(n, {0, Graph::kControlSlot}, predicate_factory_.MakeAndPredicate(input_preds), should_revisit); @@ -892,6 +1069,24 @@ Status DeadnessAnalysisImpl::Populate() { Status DeadnessAnalysisImpl::PopulateWithReversePostOrder( absl::Span rpo) { + std::vector unreachable_nodes; + // Compute the loop structure of the graph. 
+ TF_RETURN_IF_ERROR( + BuildControlFlowInfo(&graph_, &control_flow_info_, &unreachable_nodes)); + + // Do some opportunistic error checking: + if (!unreachable_nodes.empty()) { + if (unreachable_nodes.size() > 5) { + unreachable_nodes.erase(unreachable_nodes.begin() + 5, + unreachable_nodes.end()); + } + + return errors::InvalidArgument( + "Found unreachable nodes, most likely source and sink nodes not " + "connected: ", + absl::StrJoin(unreachable_nodes, ", ")); + } + // This an abstract interpretation over the deadness propagation semantics of // the graph executor. // diff --git a/tensorflow/compiler/jit/deadness_analysis_test.cc b/tensorflow/compiler/jit/deadness_analysis_test.cc index 8a73101c184e6190921fd7729742922bd96f4bcf..16ee8f86d55c72785368ac2fd67635eba2fa7cd7 100644 --- a/tensorflow/compiler/jit/deadness_analysis_test.cc +++ b/tensorflow/compiler/jit/deadness_analysis_test.cc @@ -123,10 +123,9 @@ InductionVarInfo CreateInductionVariable(const Scope& root, Output increment_by = ops::Const(root.WithOpName(prefix + "/incr"), 1); Output final_value = ops::Const(root.WithOpName(prefix + "/final"), 10); Output loop_cond_expr = - ops::Less(root.WithOpName(prefix + "/less"), iv.output, final_value); - Output loop_cond = - ops::LoopCond(root.WithOpName(prefix + "/cond"), loop_cond_expr); - ops::Switch latch(root.WithOpName(prefix + "/latch"), iv.output, loop_cond); + ops::Less(root.WithOpName(prefix + "/cond"), iv.output, final_value); + ops::Switch latch(root.WithOpName(prefix + "/latch"), iv.output, + loop_cond_expr); ops::internal::Exit exit(root.WithOpName(prefix + "/exit"), latch.output_false); Output iv_next = ops::Add(root.WithOpName(prefix + "/ivnext"), @@ -140,7 +139,7 @@ InductionVarInfo CreateInductionVariable(const Scope& root, root.graph()->AddControlEdge(iv.output.node(), increment_by.node()); root.graph()->AddControlEdge(iv.output.node(), final_value.node()); - return {iv.output, loop_cond}; + return {iv.output, loop_cond_expr}; } 
InductionVarInfo CreateInductionVariable(const Scope& root, @@ -515,24 +514,27 @@ TEST(DeadnessAnalysisTest, Loop) { // In theory we should be able to tell that iv0/cond:0 and iv1/cond:0 // produce the same deadness. But we're not that smart today. - EXPECT_EQ(predicate_map[ControlOutputFor(iv0)], "{#true,&,*iv0/cond:0}"); - EXPECT_EQ(predicate_map[ControlOutputFor(iv1)], "{#true,&,*iv1/cond:0}"); - EXPECT_EQ(predicate_map[ControlOutputFor(iv2)], "{#true,&,*iv2/cond:0}"); + EXPECT_EQ(predicate_map[ControlOutputFor(iv0)], + "{#true,&,*iv0/cond:0}"); + EXPECT_EQ(predicate_map[ControlOutputFor(iv1)], + "{#true,&,*iv1/cond:0}"); + EXPECT_EQ(predicate_map[ControlOutputFor(iv2)], + "{#true,&,*iv2/cond:0}"); EXPECT_EQ(predicate_map[ControlOutputFor(add0)], - "({#true,&,*iv1/cond:0} & {#true,&,*iv0/cond:0})"); + "({#true,&,*iv1/cond:0} & {#true,&,*iv0/cond:0})"); EXPECT_EQ(predicate_map[ControlOutputFor(add1)], - "({#true,&,*iv1/cond:0} & {#true,&,*iv2/cond:0})"); + "({#true,&,*iv1/cond:0} & {#true,&,*iv2/cond:0})"); } } TEST(DeadnessAnalysisTest, ControlEquivalentLoopBodies) { Scope root = Scope::NewRootScope().ExitOnError(); - InductionVarInfo iv = CreateInductionVariable(root, "iv0", "frame", 0); + InductionVarInfo iv = CreateInductionVariable(root, "iv0", "loop", 0); Output dependent_iv0 = - CreateDependentLoopInvariantValue(root, "div0", "frame", iv.loop_cond, 0) + CreateDependentLoopInvariantValue(root, "div0", "loop", iv.loop_cond, 0) .induction_var; Output dependent_iv1 = - CreateDependentLoopInvariantValue(root, "div1", "frame", iv.loop_cond, 0) + CreateDependentLoopInvariantValue(root, "div1", "loop", iv.loop_cond, 0) .induction_var; Output add0 = ops::Add(root.WithOpName("add0"), dependent_iv0, dependent_iv1); @@ -549,13 +551,13 @@ TEST(DeadnessAnalysisTest, ControlEquivalentLoopBodies) { TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map)); EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)], - "{#true,&,*iv0/cond:0}"); + 
"{#true,&,*iv0/cond:0}"); EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv0)], - "{#true,&,(*iv0/cond:0 & iv0/iv:0)}"); + "{#true,&,(*iv0/cond:0 & iv0/iv:0)}"); EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv1)], - "{#true,&,(*iv0/cond:0 & iv0/iv:0)}"); + "{#true,&,(*iv0/cond:0 & iv0/iv:0)}"); EXPECT_EQ(predicate_map[ControlOutputFor(add0)], - "{#true,&,(*iv0/cond:0 & iv0/iv:0)}"); + "{#true,&,(*iv0/cond:0 & iv0/iv:0)}"); } } @@ -595,32 +597,33 @@ TEST(DeadnessAnalysisTest, LoopInvariantPredicateOnBackedge) { TEST(DeadnessAnalysisTest, ControlEquivalentNestedLoopBodies) { Scope root = Scope::NewRootScope().ExitOnError(); InductionVarInfo iv_outer = - CreateInductionVariable(root, "iv_outer", "frame", 0); + CreateInductionVariable(root, "iv_outer", "outer_loop", 0); + Output enter_constant_outer_loop = ops::internal::Enter( + root.WithOpName("constant_enter_outer_loop"), + ops::Const(root.WithOpName("constant"), 5), "outer_loop", + ops::internal::Enter::Attrs().IsConstant(true)); ops::Switch inner_value(root.WithOpName("outer_is_live"), - ops::Const(root.WithOpName("constant"), 5), - iv_outer.loop_cond); + enter_constant_outer_loop, iv_outer.loop_cond); InductionVarInfo iv_inner = CreateInductionVariable( - root, "iv_inner", "frame", - ops::internal::Enter(root.WithOpName("iv_inner/enter"), - inner_value.output_true, "frame_inner")); + root, "iv_inner", "inner_loop", inner_value.output_true); Output dependent_outer_iv0 = - CreateDependentLoopInvariantValue(root, "dependent_outer_iv0", "frame", - iv_outer.loop_cond, 0) + CreateDependentLoopInvariantValue(root, "dependent_outer_iv0", + "outer_loop", iv_outer.loop_cond, 0) .induction_var; Output dependent_outer_iv1 = - CreateDependentLoopInvariantValue(root, "dependent_outer_iv1", "frame", - iv_outer.loop_cond, 0) + CreateDependentLoopInvariantValue(root, "dependent_outer_iv1", + "outer_loop", iv_outer.loop_cond, 0) .induction_var; - Output dependent_inner_iv0 = - CreateDependentLoopInvariantValue(root, 
"dependent_inner_iv0", "frame", - iv_inner.loop_cond, dependent_outer_iv0) - .induction_var; - Output dependent_inner_iv1 = - CreateDependentLoopInvariantValue(root, "dependent_inner_iv1", "frame", - iv_inner.loop_cond, dependent_outer_iv1) - .induction_var; + Output dependent_inner_iv0 = CreateDependentLoopInvariantValue( + root, "dependent_inner_iv0", "inner_loop", + iv_inner.loop_cond, dependent_outer_iv0) + .induction_var; + Output dependent_inner_iv1 = CreateDependentLoopInvariantValue( + root, "dependent_inner_iv1", "inner_loop", + iv_inner.loop_cond, dependent_outer_iv1) + .induction_var; Output add0 = ops::Add(root.WithOpName("add0"), dependent_inner_iv0, dependent_inner_iv1); @@ -638,46 +641,50 @@ TEST(DeadnessAnalysisTest, ControlEquivalentNestedLoopBodies) { TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map)); EXPECT_EQ(predicate_map[ControlOutputFor(iv_outer.induction_var)], - "{#true,&,*iv_outer/cond:0}"); + "{#true,&,*iv_outer/cond:0}"); EXPECT_EQ(predicate_map[ControlOutputFor(iv_inner.induction_var)], - "{(*iv_outer/cond:0 & {#true,&,*iv_outer/cond:0}),&," - "*iv_inner/cond:0}"); + "{({#true,&,*iv_outer/cond:0} & " + "*iv_outer/cond:0),&,*iv_inner/cond:0}"); EXPECT_EQ(predicate_map[ControlOutputFor(dependent_inner_iv0)], - "{{#true,&,(iv_outer/iv:0 & *iv_outer/cond:0)},&," - "(*iv_inner/cond:0 & iv_inner/iv:0)}"); + "{{#true,&,(iv_outer/iv:0 & " + "*iv_outer/cond:0)},&,(*iv_inner/cond:0 & " + "iv_inner/iv:0)}"); + EXPECT_EQ(predicate_map[ControlOutputFor(dependent_inner_iv1)], - "{{#true,&,(iv_outer/iv:0 & *iv_outer/cond:0)},&," - "(*iv_inner/cond:0 & iv_inner/iv:0)}"); + "{{#true,&,(iv_outer/iv:0 & " + "*iv_outer/cond:0)},&,(*iv_inner/cond:0 & " + "iv_inner/iv:0)}"); EXPECT_EQ(predicate_map[ControlOutputFor(add0)], - "{{#true,&,(iv_outer/iv:0 & *iv_outer/cond:0)},&," - "(*iv_inner/cond:0 & iv_inner/iv:0)}"); + "{{#true,&,(iv_outer/iv:0 & " + "*iv_outer/cond:0)},&,(*iv_inner/cond:0 & " + "iv_inner/iv:0)}"); } } TEST(DeadnessAnalysisTest, 
ControlNonEquivalentNestedLoopBodies) { Scope root = Scope::NewRootScope().ExitOnError(); - InductionVarInfo iv_outer_0 = - CreateInductionVariable(root, "iv_outer_0", "frame", 0); - ops::Switch inner_value_0(root.WithOpName("outer_0_is_live"), - ops::Const(root.WithOpName("constant"), 5), - iv_outer_0.loop_cond); - InductionVarInfo iv_inner_0 = CreateInductionVariable( - root, "iv_inner_0", "frame", - ops::internal::Enter(root.WithOpName("iv_inner_0/enter"), - inner_value_0.output_true, "frame_inner")); - - InductionVarInfo iv_outer_1 = - CreateInductionVariable(root, "iv_outer_1", "frame", 1); - ops::Switch inner_init_value_1(root.WithOpName("outer_1_is_live"), - ops::Const(root.WithOpName("constant"), 5), - iv_outer_1.loop_cond); - InductionVarInfo iv_inner_1 = CreateInductionVariable( - root, "iv_inner_1", "frame", - ops::internal::Enter(root.WithOpName("iv_inner_1/enter"), - inner_init_value_1.output_true, "frame_inner")); - Output add0 = ops::Add(root.WithOpName("add0"), iv_inner_0.induction_var, - iv_inner_1.induction_var); + + std::array outer_iv; + std::array inner_iv; + + for (int i : {0, 1}) { + InductionVarInfo iv_outer = + CreateInductionVariable(root, "iv_outer", "outer_loop", 0); + Output enter_constant_outer_loop = ops::internal::Enter( + root.WithOpName("constant_enter_outer_loop"), + ops::Const(root.WithOpName("constant"), 5), "outer_loop", + ops::internal::Enter::Attrs().IsConstant(true)); + ops::Switch inner_value(root.WithOpName("outer_is_live"), + enter_constant_outer_loop, iv_outer.loop_cond); + InductionVarInfo iv_inner = CreateInductionVariable( + root, "iv_inner", "inner_loop", inner_value.output_true); + + outer_iv[i] = iv_outer.induction_var; + inner_iv[i] = iv_inner.induction_var; + } + + Output add0 = ops::Add(root.WithOpName("add0"), inner_iv[0], inner_iv[1]); VLogGraphIfAsked(*root.graph()); @@ -692,21 +699,76 @@ TEST(DeadnessAnalysisTest, ControlNonEquivalentNestedLoopBodies) { PredicateMapTy predicate_map; 
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map)); - EXPECT_EQ(predicate_map[ControlOutputFor(iv_outer_0.induction_var)], - "{#true,&,*iv_outer_0/cond:0}"); - EXPECT_EQ(predicate_map[ControlOutputFor(iv_inner_0.induction_var)], - "{(*iv_outer_0/cond:0 & {#true,&,*iv_outer_0/cond:0}),&," - "*iv_inner_0/cond:0}"); - EXPECT_EQ(predicate_map[ControlOutputFor(iv_outer_1.induction_var)], - "{#true,&,*iv_outer_1/cond:0}"); - EXPECT_EQ(predicate_map[ControlOutputFor(iv_inner_1.induction_var)], - "{(*iv_outer_1/cond:0 & {#true,&,*iv_outer_1/cond:0}),&," - "*iv_inner_1/cond:0}"); - EXPECT_EQ(predicate_map[ControlOutputFor(add0)], - "({(*iv_outer_1/cond:0 & {#true,&,*iv_outer_1/cond:0}),&," - "*iv_inner_1/cond:0} & " - "{(*iv_outer_0/cond:0 & {#true,&,*iv_outer_0/cond:0}),&," - "*iv_inner_0/cond:0})"); + EXPECT_EQ(predicate_map[ControlOutputFor(outer_iv[0])], + "{#true,&,*iv_outer/cond:0}"); + EXPECT_EQ(predicate_map[ControlOutputFor(inner_iv[0])], + "{({#true,&,*iv_outer/cond:0} & " + "*iv_outer/cond:0),&,*iv_inner/cond:0}"); + EXPECT_EQ(predicate_map[ControlOutputFor(outer_iv[1])], + "{#true,&,*iv_outer/cond_1:0}"); + EXPECT_EQ( + predicate_map[ControlOutputFor(inner_iv[1])], + "{({#true,&,*iv_outer/cond_1:0} & " + "*iv_outer/cond_1:0),&,*iv_inner/cond_1:0}"); + EXPECT_EQ( + predicate_map[ControlOutputFor(add0)], + "({({#true,&,*iv_outer/cond:0} & " + "*iv_outer/cond:0),&,*iv_inner/cond:0} & " + "{({#true,&,*iv_outer/cond_1:0} & " + "*iv_outer/cond_1:0),&,*iv_inner/cond_1:0})"); + } +} + +TEST(DeadnessAnalysisTest, AndRecurrenceNeedsFrameName) { + Scope root = Scope::NewRootScope().ExitOnError(); + InductionVarInfo iv_0 = CreateInductionVariable(root, "iv_0", "frame_0", 10); + InductionVarInfo iv_1 = CreateInductionVariable(root, "iv_1", "frame_1", 9); + + Output init = CreateSwitch(root, "init").output_true; + Output step = CreateSwitch(root, "step").output_true; + + std::array exits; + std::array next_iterations; + + for (int i : {0, 1}) { + Output init_enter 
= ops::internal::Enter( + root.WithOpName(absl::StrCat("init_enter_frame_", i)), init, + absl::StrCat("frame_", i), + ops::internal::Enter::Attrs().IsConstant(true)); + Output step_enter = ops::internal::Enter( + root.WithOpName(absl::StrCat("step_enter_frame_", i)), step, + absl::StrCat("frame_", i), + ops::internal::Enter::Attrs().IsConstant(true)); + + ops::Merge iv(root.WithOpName(absl::StrCat("expr_", i)), + {init_enter, init_enter}); + Output add = ops::Add(root.WithOpName(absl::StrCat("add_", i)), iv.output, + step_enter); + next_iterations[i] = ops::NextIteration( + root.WithOpName(absl::StrCat("expr_", i, "_next_iteration")), add); + EXPECT_TRUE( + root.graph() + ->UpdateEdge(next_iterations[i].node(), 0, iv.output.node(), 1) + .ok()); + exits[i] = ops::internal::Exit(root.WithOpName(absl::StrCat("exit_", i)), + iv.output); + } + + FixupSourceAndSinkEdges(root.graph()); + + { + PredicateMapTy predicate_map; + TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map)); + + EXPECT_NE(predicate_map[ControlOutputFor(exits[0])], + predicate_map[ControlOutputFor(exits[1])]); + EXPECT_NE(predicate_map[ControlOutputFor(exits[0])], ""); + EXPECT_NE(predicate_map[ControlOutputFor(exits[1])], ""); + + EXPECT_NE(predicate_map[ControlOutputFor(next_iterations[0])], + predicate_map[ControlOutputFor(next_iterations[1])]); + EXPECT_NE(predicate_map[ControlOutputFor(next_iterations[0])], ""); + EXPECT_NE(predicate_map[ControlOutputFor(next_iterations[1])], ""); } } @@ -818,5 +880,82 @@ TEST(DeadnessAnalysisTest, RecvVsSwitchText) { EXPECT_EQ(predicate_map[logical_and_output_0], "(recv:0 & *recv:0)"); } +TEST(DeadnessAnalysisTest, DeMorgan) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Output cond_0 = ops::Placeholder(root.WithOpName("cond_0"), DT_BOOL); + Output cond_1 = ops::Placeholder(root.WithOpName("cond_1"), DT_BOOL); + Output value = ops::Placeholder(root.WithOpName("value"), DT_FLOAT); + + ops::Switch sw_0(root.WithOpName("switch_0"), value, 
cond_0); + ops::Switch sw_1(root.WithOpName("switch_1"), value, cond_1); + + Output and_0_1 = + ops::Add(root.WithOpName("and_0_1"), sw_0.output_true, sw_1.output_true); + + Output or_not0_not1 = ops::Merge(root.WithOpName("or_not0_not1"), + {sw_0.output_false, sw_1.output_false}) + .output; + + // Predicate(should_always_be_dead) = + // (A & B) & (~A | ~B) = (A & B) & ~(A & B) = False + Output should_always_be_dead = + ops::Add(root.WithOpName("should_always_be_dead"), and_0_1, or_not0_not1); + + // Predicate(should_always_be_alive) = + // (A & B) | (~A | ~B) = (A & B) | ~(A & B) = True + Output should_always_be_alive = + ops::Merge(root.WithOpName("should_always_be_alive"), + {and_0_1, or_not0_not1}) + .output; + + std::unique_ptr result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + PredicateMapTy predicate_map; + TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map)); + + EXPECT_EQ(predicate_map[ControlOutputFor(should_always_be_dead)], "#false"); + EXPECT_EQ(predicate_map[ControlOutputFor(should_always_be_alive)], "#true"); +} + +TEST(DeadnessAnalysisTest, ConstantTrueSwitchCondition) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Output constant_true = ops::Const(root.WithOpName("const_true"), true); + Output value = ops::Placeholder(root.WithOpName("value"), DT_FLOAT); + ops::Switch sw(root.WithOpName("switch"), value, constant_true); + + Output id_false = ops::Identity(root.WithOpName("id_false"), sw.output_false); + Output id_true = ops::Identity(root.WithOpName("id_true"), sw.output_true); + + FixupSourceAndSinkEdges(root.graph()); + + PredicateMapTy predicate_map; + TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map)); + + EXPECT_EQ(predicate_map[ControlOutputFor(id_false)], "#false"); + EXPECT_EQ(predicate_map[ControlOutputFor(id_true)], "#true"); +} + +TEST(DeadnessAnalysisTest, ConstantFalseSwitchCondition) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Output constant_false =
ops::Const(root.WithOpName("const_false"), false); + Output value = ops::Placeholder(root.WithOpName("value"), DT_FLOAT); + ops::Switch sw(root.WithOpName("switch"), value, constant_false); + + Output id_false = ops::Identity(root.WithOpName("id_false"), sw.output_false); + Output id_true = ops::Identity(root.WithOpName("id_true"), sw.output_true); + + FixupSourceAndSinkEdges(root.graph()); + + PredicateMapTy predicate_map; + TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map)); + + EXPECT_EQ(predicate_map[ControlOutputFor(id_false)], "#true"); + EXPECT_EQ(predicate_map[ControlOutputFor(id_true)], "#false"); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index f478832781cb1dc045d9163d4a6f5e5f64a8a705..d0d7a3f3785469acd79a83b6897668f94fc6ea2e 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -779,7 +779,8 @@ Status Encapsulator::Subgraph::RecordArg( if (inserted) { NodeDef arg_def; NodeDefBuilder builder( - absl::StrCat(src_node->name(), "_", src_slot, "_arg"), kArgOp); + absl::StrCat(src_node->name(), "_", src_slot, "_arg"), kArgOp, + NodeDebugInfo(src_node->def())); DataType dtype = edge->dst()->input_type(edge->dst_input()); builder.Attr("T", dtype); builder.Attr("index", arg_index); @@ -814,7 +815,8 @@ Status Encapsulator::Subgraph::RecordResult( if (inserted) { NodeDef ret_def; NodeDefBuilder builder( - absl::StrCat(src_node->name(), "_", src_slot, "_retval"), kRetValOp); + absl::StrCat(src_node->name(), "_", src_slot, "_retval"), kRetValOp, + NodeDebugInfo(src_node->def())); DataType dtype = src_node->output_type(src_slot); builder.Attr("T", dtype); builder.Attr("index", ret_index); @@ -974,6 +976,7 @@ Status Encapsulator::Subgraph::AddHostComputes( } NodeDef host_compute_def; + // TODO(shikharagarwal): What source node should we use for errors? 
NodeDefBuilder builder(absl::StrCat("outside_compilation_", oc_subgraph_name, "_host_compute"), kHostComputeOp); @@ -1005,13 +1008,15 @@ Status Encapsulator::Subgraph::AddHostComputes( // subgraph. for (const auto& src_node : oc_subgraph.control_inputs) { Node* src_image = node_images.at(src_node); - graph_->AddControlEdge(src_image, host_compute); + graph_->AddControlEdge(src_image, host_compute, + /* allow_duplicates= */ true); } // Connect the _HostCompute node to its ancestor host compute nodes. for (const auto& ancestor_name : host_compute_ancestors) { Node* ancestor = host_compute_node[ancestor_name]; - graph_->AddControlEdge(ancestor, host_compute); + graph_->AddControlEdge(ancestor, host_compute, + /* allow_duplicates= */ true); } // Connect the consumers in the subgraph to the _HostCompute node. @@ -1028,7 +1033,8 @@ Status Encapsulator::Subgraph::AddHostComputes( // node. for (const auto& dst_node : oc_subgraph.control_outputs) { Node* dst_image = node_images.at(dst_node); - graph_->AddControlEdge(host_compute, dst_image); + graph_->AddControlEdge(host_compute, dst_image, + /* allow_duplicates= */ true); } } } @@ -1040,6 +1046,7 @@ Status Encapsulator::Subgraph::MakeSequencingNode(const string& subgraph_name, Graph* graph_out) { if (sequencer_ == nullptr) { NodeDef seq_def; + // TODO(shikharagarwal): What source node should we use for errors? 
NodeDefBuilder builder(absl::StrCat(subgraph_name, "_sequencer"), "NoOp"); builder.Attr(kXlaHostTransferSequencerAttr, subgraph_name); builder.Device(device_); @@ -1055,7 +1062,8 @@ Status Encapsulator::Subgraph::MakeSequencingNode(const string& subgraph_name, void Encapsulator::Subgraph::ConnectSequencerToCallNode(Graph* graph_out) { if (sequencer_ != nullptr) { VLOG(2) << "ConnectSequencerToCallNode"; - graph_out->AddControlEdge(sequencer_, call_node_); + graph_out->AddControlEdge(sequencer_, call_node_, + /* allow_duplicates= */ true); } } @@ -1214,7 +1222,8 @@ Status Encapsulator::Subgraph::AddHostComputeKeyPlaceholder( GraphDefBuilder::Options options(graph_out, /*status=*/nullptr); NodeDef key_def; NodeDefBuilder builder( - absl::StrCat(call_node_def_.name(), "_key_placeholder"), "Placeholder"); + absl::StrCat(call_node_def_.name(), "_key_placeholder"), "Placeholder", + NodeDebugInfo(call_node_def_)); builder.Attr("dtype", DT_STRING); builder.Attr("shape", shape_proto); builder.Attr("_host_compute_call_node", call_node_def_.name()); @@ -1248,6 +1257,7 @@ Status Encapsulator::Subgraph::AddRecvAtHostNode( } NodeDef recv_def; + // TODO(shikharagarwal): What source node should we use for errors? NodeDefBuilder builder(absl::StrCat("outside_compilation_", subgraph_name, "_", oc_subgraph_name, "_recv"), kRecvAtHostOp); @@ -1273,7 +1283,8 @@ Status Encapsulator::Subgraph::AddRecvAtHostNode( // completes. This has no effect on execution order but prevents the // RecvAtHost being pruned. TF_RETURN_IF_ERROR(MakeSequencingNode(subgraph_name, graph_out)); - graph_out->AddControlEdge(oc_subgraph->recv_at_host, sequencer_); + graph_out->AddControlEdge(oc_subgraph->recv_at_host, sequencer_, + true /* skip duplicates check */); return Status::OK(); } @@ -1303,6 +1314,7 @@ Status Encapsulator::Subgraph::AddSendFromHostNode( } NodeDef send_def; + // TODO(shikharagarwal): What source node should we use for errors? 
NodeDefBuilder builder(absl::StrCat("outside_compilation_", subgraph_name, "_", oc_subgraph_name, "_send"), kSendFromHostOp); @@ -1329,7 +1341,8 @@ Status Encapsulator::Subgraph::AddSendFromHostNode( // subgraph completes. This has no effect on execution order but prevents the // RecvAtHost being pruned. TF_RETURN_IF_ERROR(MakeSequencingNode(subgraph_name, graph_out)); - graph_out->AddControlEdge(oc_subgraph->send_from_host, sequencer_); + graph_out->AddControlEdge(oc_subgraph->send_from_host, sequencer_, + /* allow_duplicates= */ true); return Status::OK(); } @@ -1439,7 +1452,8 @@ Status Encapsulator::CopySubgraphEdges( src_func_id == dst_func_id) { Graph* g = subgraphs_[src_func_id].GetGraph(); if (edge->IsControlEdge()) { - g->AddControlEdge(src_image, dst_image); + g->AddControlEdge(src_image, dst_image, + /* allow_duplicates= */ true); } else { g->AddEdge(src_image, edge->src_output(), dst_image, edge->dst_input()); } @@ -1725,7 +1739,8 @@ Status Encapsulator::CopyEdgeToOutputGraph( if (edges_added ->emplace(OutputTensor(src_image, -1), InputTensor(dst_image, -1)) .second) { - graph_out->AddControlEdge(src_image, dst_image); + graph_out->AddControlEdge(src_image, dst_image, + /* allow_duplicates= */ true); } return Status::OK(); @@ -1754,7 +1769,8 @@ Status Encapsulator::AddCallNodeDependencies(Graph* graph_out) { const string& subgraph = ancestors.first; for (const string& ancestor : ancestors.second) { graph_out->AddControlEdge(subgraphs_[ancestor].GetCallNode(), - subgraphs_[subgraph].GetCallNode()); + subgraphs_[subgraph].GetCallNode(), + /* allow_duplicates= */ true); } } return Status::OK(); @@ -1833,8 +1849,9 @@ Node* AddDummyShapedNode(const Node* src_node, int src_port, // Add any Enter nodes required to bring the constant to the correct control // flow frame. 
while (!control_flow_info[src_node->id()].frame_name.empty()) { + NodeDebugInfo debug_info(*src_node); NodeBuilder enter_builder(options.GetNameForOp("Enter"), "Enter", - options.op_registry()); + options.op_registry(), &debug_info); enter_builder.Attr("frame_name", control_flow_info[src_node->id()].frame_name); enter_builder.Attr("is_constant", true); @@ -2018,7 +2035,8 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend( return errors::InvalidArgument( "Shape inference is not possible for outside_compilation " "SendFromHost node ", - send_node->name(), " because shape of node ", n->name(), + send_node->name(), " because shape of node ", + FormatNodeForError(*n), " will not be known at compilation time."); } } @@ -2047,8 +2065,7 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend( return errors::Internal( "Internal assumption failed while rewriting an outside_compilation " "cluster that contains a while loop. Logic assumes back-edge is to " - "port 1 of a 2-input " - "Merge node."); + "port 1 of a 2-input Merge node."); } // Connect the existing edge to both inputs of the Merge node so that the // graph will be well-formed. @@ -2121,7 +2138,8 @@ Status CheckClusterDependencyForCycles( const string& ancestor, const string& successor, const std::unordered_map>& ancestors, const std::unordered_map& node_ancestors_map, - GraphCycles* cycle_detector, std::map* cycle_detector_map) { + GraphCycles* cycle_detector, + std::unordered_map* cycle_detector_map) { if (cycle_detector_map->find(ancestor) == cycle_detector_map->end()) { (*cycle_detector_map)[ancestor] = cycle_detector->NewNode(); } @@ -2165,7 +2183,7 @@ Status Encapsulator::FindClusterDependencies() { // We check that clusters are acyclic using this cycle detector. GraphCycles cycle_detector; // Map from cluster name to cycle detector node id. - std::map cycle_detector_map; + std::unordered_map cycle_detector_map; // Process the nodes in topologically-sorted order. 
std::vector nodes; GetReversePostOrder(*graph_in_, &nodes); @@ -2527,7 +2545,33 @@ Status EncapsulateSubgraphsPass::Run( std::vector* input_permutation, std::vector* output_permutation, NodeDef* node) { // Optimize the subgraph. - OptimizeGraph(flr, subgraph); + // Do not constant fold nodes that output DT_VARIANT type tensors. + // XLA does not support Const nodes of Variant type since it needs + // to know the original ops to be able to compile them to the relevant + // XLA form. + // TODO(srbs): This filter is a little conservative. E.g. a subgraph of + // the form: + // Const + // | + // EmptyTensorList -> TensorListPushBack -> TensorListPopBack -> Op + // | + // (Discard popped list) + // + // Would have been reduced to "Const -> Op" without this filter. + // However since we are only allowed to specify the filter at the "Node" + // level there is no good way to allow the above behavior. So we + // disallow any sort of constant folding on Variant nodes for now. + auto cf_consider_fn = [](const Node* n) { + for (const auto& output_arg : n->op_def().output_arg()) { + if (output_arg.type() == DT_VARIANT) { + return false; + } + } + return true; + }; + GraphOptimizer::Options graph_optimizer_options; + graph_optimizer_options.cf_consider_fn = cf_consider_fn; + OptimizeGraph(flr, subgraph, graph_optimizer_options); const int num_args = input_permutation->size(); std::vector const_args(num_args); diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc index de89be9a3555960dabe7bacd17226c15ae888ae6..1f8ec09e19c01d0a8b2a3761135ed53dfb2ad3b0 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc @@ -25,6 +25,8 @@ limitations under the License. 
#include "tensorflow/compiler/jit/encapsulate_util.h" #include "tensorflow/compiler/jit/extract_outside_compilation_pass.h" #include "tensorflow/compiler/tf2xla/side_effect_util.h" +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/graph/graph_constructor.h" @@ -32,6 +34,8 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/public/session_options.h" +#include "tensorflow/core/public/version.h" #include "tensorflow/core/util/equal_graph_def.h" namespace tensorflow { @@ -299,7 +303,7 @@ REGISTER_OP("XlaHostCompute") .Attr("Toutputs: list(type) >= 0") .Attr("ancestors: list(string) >= 0") .Attr("key: string") - .Attr("shape_inference_graph: string = ''") + .Attr("shape_inference_graph: func") .Attr("shapes: list(shape) >= 0") .SetShapeFn(::tensorflow::shape_inference::UnknownShape); @@ -510,12 +514,20 @@ Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library, s = ConvertGraphDefToGraph(options, *graphdef, graph.get()); if (!s.ok()) return s; - s = PerformStaticShapeInferenceBeforeEncapsulation( - graph.get(), "_encapsulate", "_outside"); + s = PerformStaticShapeInferenceBeforeEncapsulation(graph.get()); if (!s.ok()) return s; - s = PreprocessForEncapsulation(graph.get(), "_encapsulate", "_outside"); - if (!s.ok()) return s; + // Create FunctionLibraryRuntime. 
+ SessionOptions session_options; + std::vector> devices; + TF_CHECK_OK(DeviceFactory::AddDevices( + session_options, "/job:localhost/replica:0/task:0", &devices)); + OptimizerOptions opts; + auto device_mgr = absl::make_unique(std::move(devices)); + auto pflr = absl::make_unique( + device_mgr.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def.get(), + opts, /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr); + auto flr = pflr->GetFLR("/job:localhost/replica:0/task:0/cpu:0"); std::unique_ptr graph_out; s = EncapsulateSubgraphsInFunctions( @@ -542,7 +554,7 @@ Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library, std::map{}}); } s = ExtractOutsideCompilation("_encapsulate", "_outside", clusters, - graph_out.get(), lib_def.get()); + graph_out.get(), flr, lib_def.get()); if (!s.ok()) return s; GraphDef graphdef_out; @@ -550,6 +562,14 @@ Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library, graphdef->Swap(&graphdef_out); *library = lib_def->ToProto(); + // Remove "_xla_inferred_shapes" attr. They are added by + // `PerformStaticShapeInferenceBeforeEncapsulation`. 
+ for (FunctionDef& fdef : *library->mutable_function()) { + for (NodeDef& node_def : *fdef.mutable_node_def()) { + node_def.mutable_attr()->erase("_xla_inferred_shapes"); + } + } + return s; } @@ -901,18 +921,22 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) { { GraphDefBuilder shape(GraphDefBuilder::kFailImmediately); Node* key_constant = KeyPlaceholder("F1", shape.opts()); - Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", - {DT_FLOAT, DT_FLOAT}, shape.opts()); + Node* recv = RecvAtHost( + ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT, DT_FLOAT}, + shape.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1), shape.opts() .WithName("E") .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O1")); - SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, shape.opts()); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + shape.opts().WithAttr(kXlaHasHostTransferAttrName, true)); TF_EXPECT_OK( AddGraphDefToFunctionLibrary(shape, "F1_O1", &library_expected)); } + NameAttrList shape_inference_graph; + shape_inference_graph.set_name("_outside_compilation_shape_inference_F1_O1"); *library_expected.add_function() = test::function::XTimesTwo(); *library_expected.add_function() = FunctionDefHelper::Create( "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval_retval:float"}, {}, @@ -931,10 +955,11 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) { {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", - "_outside_compilation_shape_inference_F1_O1"}, + {"shape_inference_graph", shape_inference_graph}, {"shapes", absl::Span({})}, - {"_outside_compilation_subgraph", "O1"}}, + {"_outside_compilation_subgraph", "O1"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}, {"c"}}, }, {{"f_0_retval_retval", "F:o:0"}}); @@ -948,16 +973,18 @@ 
TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) { Node* key_constant = KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); - Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", - {DT_FLOAT, DT_FLOAT}, b2.opts()); + Node* recv = RecvAtHost( + ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT, DT_FLOAT}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1), b2.opts() .WithName("E") - .WithControlInputs({recv, b}) + .WithControlInputs({recv}) .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O1")); Node* send = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, - b2.opts().WithControlInput(e)); + b2.opts().WithControlInput(e).WithAttr( + kXlaHasHostTransferAttrName, true)); Node* s = Sequencer( b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}), @@ -966,9 +993,9 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) { NodeBuilder node_builder("F1", "F1", lib_def.get()); node_builder.Input(a).Input(b); Node* call = - b2.opts().WithControlInputs({s}).FinalizeBuilder(&node_builder); + b2.opts().WithControlInputs({s, b}).FinalizeBuilder(&node_builder); - Binary(a, call, b2.opts().WithName("G").WithControlInputs({e})); + Binary(a, call, b2.opts().WithName("G").WithControlInputs({call})); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -1022,14 +1049,16 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { { GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately); Node* key_constant = KeyPlaceholder("F1", shape1.opts()); - Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", - {DT_FLOAT, DT_FLOAT}, shape1.opts()); + Node* recv = RecvAtHost( + ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT, DT_FLOAT}, + shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1), shape1.opts() .WithName("E") .WithAttr("_encapsulate", "F1") .WithAttr("_outside", 
"O1")); - SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, shape1.opts()); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); TF_EXPECT_OK( AddGraphDefToFunctionLibrary(shape1, "F1_O1", &library_expected)); } @@ -1037,33 +1066,45 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { { GraphDefBuilder shape2(GraphDefBuilder::kFailImmediately); Node* key_constant = KeyPlaceholder("F1", shape2.opts()); - Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", - {DT_FLOAT, DT_FLOAT}, shape2.opts()); + Node* recv1 = RecvAtHost( + ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT, DT_FLOAT}, + shape2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1), shape2.opts() .WithName("E") .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O1")); - Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", - {DT_FLOAT, DT_FLOAT}, shape2.opts()); + Node* recv2 = RecvAtHost( + ops::NodeOut(key_constant, 0), "F1", "O2", {DT_FLOAT, DT_FLOAT}, + shape2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* g = Binary(e, ops::NodeOut(recv2, 0), + shape2.opts() + .WithName("G") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O2")); Node* h = Binary(ops::NodeOut(recv2, 1), e, shape2.opts() .WithName("H") .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O2")); - SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {h}, shape2.opts()); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {g, h}, + shape2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); TF_EXPECT_OK( AddGraphDefToFunctionLibrary(shape2, "F1_O2", &library_expected)); } + NameAttrList shape_inference_graph1, shape_inference_graph2; + shape_inference_graph1.set_name("_outside_compilation_shape_inference_F1_O1"); + shape_inference_graph2.set_name("_outside_compilation_shape_inference_F1_O2"); 
*library_expected.add_function() = FunctionDefHelper::Create( - "F1", {"a_0_arg:float", "b_0_arg:float"}, {"i_0_retval_retval:float"}, {}, + "F1", {"a_0_arg:float", "b_0_arg:float"}, + {"g_0_retval_retval:float", "i_0_retval_retval:float"}, {}, { {{"C"}, "UnaryTest", {"a_0_arg"}}, {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}, {}}, {{"I"}, "UnaryTest", - {"outside_compilation_O2_host_compute:outputs:0"}}, + {"outside_compilation_O2_host_compute:outputs:1"}}, {{"F"}, "BinaryTest", {"C:o:0", "outside_compilation_O1_host_compute:outputs:0"}, @@ -1073,13 +1114,14 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { "XlaHostCompute", {"F:o:0", "D:o:0"}, {{"Tinputs", absl::Span({DT_FLOAT, DT_FLOAT})}, - {"Toutputs", absl::Span({DT_FLOAT})}, + {"Toutputs", absl::Span({DT_FLOAT, DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O2"}, - {"shape_inference_graph", - "_outside_compilation_shape_inference_F1_O2"}, + {"shape_inference_graph", shape_inference_graph2}, {"shapes", absl::Span({})}, - {"_outside_compilation_subgraph", "O2"}}, + {"_outside_compilation_subgraph", "O2"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}, {"F"}}, {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", @@ -1088,13 +1130,15 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", - "_outside_compilation_shape_inference_F1_O1"}, + {"shape_inference_graph", shape_inference_graph1}, {"shapes", absl::Span({})}, - {"_outside_compilation_subgraph", "O1"}}, + {"_outside_compilation_subgraph", "O1"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}, {"D"}}, }, - {{"i_0_retval_retval", "I:o:0"}}); + {{"g_0_retval_retval", "outside_compilation_O2_host_compute:outputs:0"}, + {"i_0_retval_retval", "I:o:0"}}); { std::unique_ptr lib_def( @@ -1105,19 +1149,22 @@ TEST(EncapsulateSubgraphsTest, 
OneFunctionTwoOutside) { Node* key_constant = KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); - Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", - {DT_FLOAT, DT_FLOAT}, b2.opts()); + Node* recv1 = RecvAtHost( + ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT, DT_FLOAT}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1), b2.opts() .WithName("E") - .WithControlInputs({recv1, b}) + .WithControlInputs({recv1}) .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O1")); Node* send1 = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, - b2.opts().WithControlInput(e)); + b2.opts().WithControlInput(e).WithAttr( + kXlaHasHostTransferAttrName, true)); - Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", - {DT_FLOAT, DT_FLOAT}, b2.opts()); + Node* recv2 = RecvAtHost( + ops::NodeOut(key_constant, 0), "F1", "O2", {DT_FLOAT, DT_FLOAT}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* g = Binary(e, ops::NodeOut(recv2, 0), b2.opts() .WithName("G") @@ -1130,7 +1177,8 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O2")); Node* send2 = - SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {h}, b2.opts()); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {g, h}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* s = Sequencer(b2.opts() .WithName("F1_sequencer") @@ -1139,12 +1187,13 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { NodeBuilder node_builder("F1", "F1", lib_def.get()); node_builder.Input(a).Input(b); - Node* call = b2.opts().WithControlInput(s).FinalizeBuilder(&node_builder); + Node* call = + b2.opts().WithControlInputs({s, b}).FinalizeBuilder(&node_builder); - Binary(g, call, b2.opts().WithName("J")); + Binary(ops::NodeOut(call, 0), ops::NodeOut(call, 1), + b2.opts().WithName("J")); 
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } - TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef); TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library); } @@ -1196,7 +1245,9 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { *library_expected.add_function() = FunctionDefHelper::Create( "F1", {"a_0_arg:float", "b_0_arg:float"}, - {"f_0_retval_retval:float", "d_0_retval_retval:float"}, {}, + {"e_0_retval_retval:float", "f_0_retval_retval:float", + "d_0_retval_retval:float"}, + {}, { {{"C"}, "UnaryTest", {"a_0_arg"}}, {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, @@ -1212,35 +1263,41 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", ""}, + {"shape_inference_graph", NameAttrList()}, {"shapes", absl::Span({shape_proto_expected})}, - {"_outside_compilation_subgraph", "O1"}}, + {"_outside_compilation_subgraph", "O1"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}, {"D"}}, }, - {{"d_0_retval_retval", "D:o:0"}, {"f_0_retval_retval", "F:o:0"}}); + {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"}, + {"d_0_retval_retval", "D:o:0"}, + {"f_0_retval_retval", "F:o:0"}}); *library_expected.add_function() = FunctionDefHelper::Create( - "F2", {"f_0_arg:float", "bridge_e_g_0_arg:float"}, - {"i_0_retval_retval:float", "g_0_retval_retval:float"}, {}, + "F2", {"e_0_arg:float", "f_0_arg:float", "d_0_arg:float"}, + {"g_0_retval_retval:float", "i_0_retval_retval:float"}, {}, { - {{"G"}, "BinaryTest", {"bridge_e_g_0_arg", "f_0_arg"}}, + {{"G"}, "BinaryTest", {"e_0_arg", "f_0_arg"}}, {{"I"}, "BinaryTest", {"f_0_arg", "outside_compilation_O1_host_compute:outputs:0"}}, {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", - {"G:o:0"}, - {{"Tinputs", absl::Span({DT_FLOAT})}, + {"d_0_arg", "G:o:0"}, + {{"Tinputs", absl::Span({DT_FLOAT, DT_FLOAT})}, {"Toutputs", absl::Span({DT_FLOAT})}, 
{"ancestors", absl::Span({})}, {"key", "host_compute_channel_F2_O1"}, - {"shape_inference_graph", ""}, + {"shape_inference_graph", NameAttrList()}, {"shapes", absl::Span({shape_proto_expected})}, - {"_outside_compilation_subgraph", "O1"}}}, + {"_outside_compilation_subgraph", "O1"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}}, }, - {{"i_0_retval_retval", "I:o:0"}, {"g_0_retval_retval", "G:o:0"}}); + {{"g_0_retval_retval", "G:o:0"}, {"i_0_retval_retval", "I:o:0"}}); { std::unique_ptr lib_def( @@ -1251,16 +1308,18 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { Node* key_constant1 = KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); - Node* recv1 = RecvAtHost(ops::NodeOut(key_constant1, 0), "F1", "O1", - {DT_FLOAT, DT_FLOAT}, b2.opts()); + Node* recv1 = RecvAtHost( + ops::NodeOut(key_constant1, 0), "F1", "O1", {DT_FLOAT, DT_FLOAT}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1), b2.opts() .WithName("E") - .WithControlInputs({recv1, b}) + .WithControlInputs({recv1}) .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O1")); Node* send1 = SendFromHost(ops::NodeOut(key_constant1, 0), "F1", "O1", {e}, - b2.opts().WithControlInput(e)); + b2.opts().WithControlInput(e).WithAttr( + kXlaHasHostTransferAttrName, true)); Node* s1 = Sequencer( b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}), "F1"); @@ -1268,29 +1327,33 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { NodeBuilder node_builder1("F1", "F1", lib_def.get()); node_builder1.Input(a).Input(b); Node* call1 = - b2.opts().WithControlInput(s1).FinalizeBuilder(&node_builder1); + b2.opts().WithControlInputs({s1, b}).FinalizeBuilder(&node_builder1); Node* key_constant2 = KeyPlaceholder("F2", b2.opts().WithName("F2_key_placeholder")); - Node* recv2 = RecvAtHost(ops::NodeOut(key_constant2, 0), "F2", "O1", - {DT_FLOAT}, b2.opts()); - Node* h = Binary(ops::NodeOut(call1, 
1), recv2, + Node* recv2 = RecvAtHost( + ops::NodeOut(key_constant2, 0), "F2", "O1", {DT_FLOAT, DT_FLOAT}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* h = Binary(recv2, ops::NodeOut(recv2, 1), b2.opts() .WithName("H") .WithAttr("_encapsulate", "F2") .WithAttr("_outside", "O1")); - Node* send2 = SendFromHost(ops::NodeOut(key_constant2, 0), "F2", "O1", {h}, - b2.opts()); + Node* send2 = + SendFromHost(ops::NodeOut(key_constant2, 0), "F2", "O1", {h}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* s2 = Sequencer( b2.opts().WithName("F2_sequencer").WithControlInputs({recv2, send2}), "F2"); NodeBuilder node_builder2("F2", "F2", lib_def.get()); - node_builder2.Input(call1).Input(e); + node_builder2.Input(call1) + .Input(ops::NodeOut(call1, 1)) + .Input(ops::NodeOut(call1, 2)); Node* call2 = b2.opts() - .WithControlInputs({s2, e, call1}) + .WithControlInputs({s2, call1}) .FinalizeBuilder(&node_builder2); - Binary(ops::NodeOut(call2, 1), call2, b2.opts().WithName("J")); + Binary(call2, ops::NodeOut(call2, 1), b2.opts().WithName("J")); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -1326,8 +1389,7 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) { Node* h = Unary(g, b1.opts() .WithName("H") .WithAttr("_encapsulate", "F2") - .WithAttr("_outside", "O1") - .WithControlInput(e)); + .WithAttr("_outside", "O1")); Node* i = Unary(h, b1.opts().WithName("I").WithAttr("_encapsulate", "F2")); Binary(f, i, b1.opts().WithName("J")); TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); @@ -1358,10 +1420,12 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) { {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", ""}, + {"shape_inference_graph", NameAttrList()}, {"shapes", absl::Span({shape_proto_expected})}, - {"_outside_compilation_subgraph", "O1"}}, + {"_outside_compilation_subgraph", "O1"}, + 
{"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}, {"D"}}, }, {{"f_0_retval_retval", "F:o:0"}}); @@ -1380,10 +1444,12 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) { {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F2_O1"}, - {"shape_inference_graph", ""}, + {"shape_inference_graph", NameAttrList()}, {"shapes", absl::Span({shape_proto_expected})}, - {"_outside_compilation_subgraph", "O1"}}}, + {"_outside_compilation_subgraph", "O1"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}}, }, {{"i_0_retval_retval", "I:o:0"}}); @@ -1401,7 +1467,7 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) { Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1), b2.opts() .WithName("E") - .WithControlInputs({recv1, b}) + .WithControlInputs({recv1}) .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O1")); Node* send1 = SendFromHost(ops::NodeOut(key_constant1, 0), "F1", "O1", {e}, @@ -1413,7 +1479,7 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) { NodeBuilder node_builder1("F1", "F1", lib_def.get()); node_builder1.Input(a).Input(b); Node* call1 = - b2.opts().WithControlInput(s1).FinalizeBuilder(&node_builder1); + b2.opts().WithControlInputs({s1, b}).FinalizeBuilder(&node_builder1); Node* key_constant2 = KeyPlaceholder("F2", b2.opts().WithName("F2_key_placeholder")); @@ -1422,8 +1488,7 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) { Node* h = Unary(recv2, b2.opts() .WithName("H") .WithAttr("_encapsulate", "F2") - .WithAttr("_outside", "O1") - .WithControlInput(e)); + .WithAttr("_outside", "O1")); Node* send2 = SendFromHost(ops::NodeOut(key_constant2, 0), "F2", "O1", {h}, b2.opts()); @@ -1484,15 +1549,17 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) { {"D:o:0", "outside_compilation_O1_host_compute:outputs:0"}}, 
{{"outside_compilation_O1_host_compute"}, "XlaHostCompute", - {}, - {{"Tinputs", absl::Span({})}, + {"a_0_arg"}, + {{"Tinputs", absl::Span({DT_FLOAT})}, {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", ""}, + {"shape_inference_graph", NameAttrList()}, {"shapes", absl::Span({shape_proto_expected})}, - {"_outside_compilation_subgraph", "O1"}}}, + {"_outside_compilation_subgraph", "O1"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}}, }, {{"f_0_retval_retval", "F:o:0"}}); @@ -1503,16 +1570,19 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) { Node* a = InputShaped(b2.opts().WithName("A")); Node* b = Input(b2.opts().WithName("B")); - Node* e = Unary(a, b2.opts() - .WithName("E") - .WithAttr("_encapsulate", "F1") - .WithAttr("_outside", "O1")); Node* key_constant = KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); + Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", + {DT_FLOAT}, b2.opts()); + Node* e = Unary(recv1, b2.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); Node* send1 = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, b2.opts()); Node* s1 = Sequencer( - b2.opts().WithName("F1_sequencer").WithControlInput(send1), "F1"); + b2.opts().WithName("F1_sequencer").WithControlInputs({send1, recv1}), + "F1"); NodeBuilder node_builder1("F1", "F1", lib_def.get()); node_builder1.Input(a).Input(b); Node* call1 = @@ -1569,15 +1639,17 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) { {"D:o:0", "outside_compilation_O1_host_compute:outputs:0"}}, {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", - {}, - {{"Tinputs", absl::Span({})}, + {"a_0_arg"}, + {{"Tinputs", absl::Span({DT_FLOAT})}, {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", ""}, + 
{"shape_inference_graph", NameAttrList()}, {"shapes", absl::Span({shape_proto_expected})}, - {"_outside_compilation_subgraph", "O1"}}, + {"_outside_compilation_subgraph", "O1"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}, {"D"}}, }, {{"f_0_retval_retval", "F:o:0"}}); @@ -1591,13 +1663,13 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) { Node* key_constant = KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); - Node* recv1 = - RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {}, b2.opts()); - Node* e = Unary(a, b2.opts() - .WithName("E") - .WithControlInput(recv1) - .WithAttr("_encapsulate", "F1") - .WithAttr("_outside", "O1")); + Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", + {DT_FLOAT}, b2.opts()); + Node* e = Unary(recv1, b2.opts() + .WithName("E") + .WithControlInput(recv1) + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); Node* send1 = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, b2.opts()); Node* s1 = Sequencer( @@ -1644,8 +1716,27 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) { FunctionDefLibrary library_expected; GraphDef graphdef_expected; + { + GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately); + Node* key_constant = KeyPlaceholder("F1", shape1.opts()); + Node* recv1 = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT}, + shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* e = Unary(ops::NodeOut(recv1, 0), shape1.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + TF_EXPECT_OK( + AddGraphDefToFunctionLibrary(shape1, "F1_O1", &library_expected)); + } + + NameAttrList shape_inference_graph; + shape_inference_graph.set_name("_outside_compilation_shape_inference_F1_O1"); *library_expected.add_function() = 
FunctionDefHelper::Create( - "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval_retval:float"}, {}, + "F1", {"a_0_arg:float", "b_0_arg:float"}, + {"e_0_retval_retval:float", "f_0_retval_retval:float"}, {}, { {{"C"}, "UnaryTest", {"a_0_arg"}}, {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, @@ -1654,14 +1745,17 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) { "XlaHostCompute", {"D:o:0"}, {{"Tinputs", absl::Span({DT_FLOAT})}, - {"Toutputs", absl::Span({})}, + {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", ""}, + {"shape_inference_graph", shape_inference_graph}, {"shapes", absl::Span({})}, - {"_outside_compilation_subgraph", "O1"}}}, + {"_outside_compilation_subgraph", "O1"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}}, }, - {{"f_0_retval_retval", "F:o:0"}}); + {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"}, + {"f_0_retval_retval", "F:o:0"}}); { std::unique_ptr lib_def( @@ -1678,14 +1772,17 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) { .WithName("E") .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O1")); + Node* send1 = + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, b2.opts()); Node* s1 = Sequencer( - b2.opts().WithName("F1_sequencer").WithControlInput(recv1), "F1"); + b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}), + "F1"); NodeBuilder node_builder1("F1", "F1", lib_def.get()); node_builder1.Input(a).Input(b); Node* call1 = b2.opts().WithControlInput(s1).FinalizeBuilder(&node_builder1); - Binary(e, call1, b2.opts().WithName("G")); + Binary(call1, ops::NodeOut(call1, 1), b2.opts().WithName("G")); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -1722,8 +1819,27 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) { FunctionDefLibrary library_expected; GraphDef graphdef_expected; + { + GraphDefBuilder 
shape1(GraphDefBuilder::kFailImmediately); + Node* key_constant = KeyPlaceholder("F1", shape1.opts()); + Node* recv1 = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT}, + shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* e = Unary(ops::NodeOut(recv1, 0), shape1.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + TF_EXPECT_OK( + AddGraphDefToFunctionLibrary(shape1, "F1_O1", &library_expected)); + } + + NameAttrList shape_inference_graph; + shape_inference_graph.set_name("_outside_compilation_shape_inference_F1_O1"); *library_expected.add_function() = FunctionDefHelper::Create( - "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval_retval:float"}, {}, + "F1", {"a_0_arg:float", "b_0_arg:float"}, + {"e_0_retval_retval:float", "f_0_retval_retval:float"}, {}, { {{"C"}, "UnaryTest", {"a_0_arg"}}, {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, @@ -1736,14 +1852,17 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) { "XlaHostCompute", {"D:o:0"}, {{"Tinputs", absl::Span({DT_FLOAT})}, - {"Toutputs", absl::Span({})}, + {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", ""}, + {"shape_inference_graph", shape_inference_graph}, {"shapes", absl::Span({})}, - {"_outside_compilation_subgraph", "O1"}}}, + {"_outside_compilation_subgraph", "O1"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}}, }, - {{"f_0_retval_retval", "F:o:0"}}); + {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"}, + {"f_0_retval_retval", "F:o:0"}}); { std::unique_ptr lib_def( @@ -1760,7 +1879,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) { .WithName("E") .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O1")); - Node* send1 = 
SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {}, + Node* send1 = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, b2.opts().WithControlInput(e)); Node* s1 = Sequencer( b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}), @@ -1770,7 +1889,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) { Node* call1 = b2.opts().WithControlInput(s1).FinalizeBuilder(&node_builder1); - Binary(e, call1, b2.opts().WithName("G")); + Binary(call1, ops::NodeOut(call1, 1), b2.opts().WithName("G")); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -1813,22 +1932,45 @@ TEST(EncapsulateSubgraphsTest, FunctionDefLibrary library_expected; GraphDef graphdef_expected; + { + GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately); + Node* key_constant = KeyPlaceholder("F1", shape1.opts()); + Node* recv1 = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT}, + shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* e = Unary(ops::NodeOut(recv1, 0), shape1.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + TF_EXPECT_OK( + AddGraphDefToFunctionLibrary(shape1, "F1_O1", &library_expected)); + } + { GraphDefBuilder shape2(GraphDefBuilder::kFailImmediately); Node* key_constant = KeyPlaceholder("F1", shape2.opts()); - Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", - {DT_FLOAT}, shape2.opts()); + Node* recv2 = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", {DT_FLOAT}, + shape2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* g = Unary(ops::NodeOut(recv2, 0), shape2.opts() .WithName("G") .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O2")); - SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {g}, shape2.opts()); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {g}, + 
shape2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); TF_EXPECT_OK( AddGraphDefToFunctionLibrary(shape2, "F1_O2", &library_expected)); } + NameAttrList shape_inference_graph1; + shape_inference_graph1.set_name("_outside_compilation_shape_inference_F1_O1"); + NameAttrList shape_inference_graph2; + shape_inference_graph2.set_name("_outside_compilation_shape_inference_F1_O2"); *library_expected.add_function() = FunctionDefHelper::Create( - "F1", {"a_0_arg:float", "b_0_arg:float"}, {"h_0_retval_retval:float"}, {}, + "F1", {"a_0_arg:float", "b_0_arg:float"}, + {"e_0_retval_retval:float", "h_0_retval_retval:float"}, {}, { {{"C"}, "UnaryTest", {"a_0_arg"}}, {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, @@ -1836,6 +1978,18 @@ TEST(EncapsulateSubgraphsTest, {{"H"}, "UnaryTest", {"outside_compilation_O2_host_compute:outputs:0"}}, + {{"outside_compilation_O1_host_compute"}, + "XlaHostCompute", + {"a_0_arg"}, + {{"Tinputs", absl::Span({DT_FLOAT})}, + {"Toutputs", absl::Span({DT_FLOAT})}, + {"ancestors", absl::Span({})}, + {"key", "host_compute_channel_F1_O1"}, + {"shape_inference_graph", shape_inference_graph1}, + {"shapes", absl::Span({})}, + {"_outside_compilation_subgraph", "O1"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}}, {{"outside_compilation_O2_host_compute"}, "XlaHostCompute", {"F:o:0"}, @@ -1843,12 +1997,14 @@ TEST(EncapsulateSubgraphsTest, {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O2"}, - {"shape_inference_graph", - "_outside_compilation_shape_inference_F1_O2"}, + {"shape_inference_graph", shape_inference_graph2}, {"shapes", absl::Span({})}, - {"_outside_compilation_subgraph", "O2"}}}, + {"_outside_compilation_subgraph", "O2"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}}, }, - {{"h_0_retval_retval", "H:o:0"}}); + {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"}, + {"h_0_retval_retval", "H:o:0"}}); { std::unique_ptr lib_def( @@ 
-1856,30 +2012,39 @@ TEST(EncapsulateSubgraphsTest, GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get()); Node* a = Input(b2.opts().WithName("A")); Node* b = Input(b2.opts().WithName("B")); - - Node* e = Unary(a, b2.opts() - .WithName("E") - .WithAttr("_encapsulate", "F1") - .WithAttr("_outside", "O1")); Node* key_constant = KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); - Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", - {DT_FLOAT}, b2.opts()); - Node* g = Unary(recv, b2.opts() - .WithName("G") - .WithAttr("_encapsulate", "F1") - .WithAttr("_outside", "O2") - .WithControlInput(e)); - Node* send = - SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {g}, b2.opts()); - Node* s1 = Sequencer( - b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}), - "F1"); + Node* recv1 = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + + Node* e = Unary(recv1, b2.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* send1 = + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* recv2 = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", {DT_FLOAT}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* g = Unary(recv2, b2.opts() + .WithName("G") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O2") + .WithControlInput(e)); + Node* send2 = + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {g}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* s1 = Sequencer(b2.opts() + .WithName("F1_sequencer") + .WithControlInputs({recv1, send1, recv2, send2}), + "F1"); NodeBuilder node_builder1("F1", "F1", lib_def.get()); node_builder1.Input(a).Input(b).ControlInput(s1); Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); - Binary(e, call1, b2.opts().WithName("I")); + Binary(call1, 
ops::NodeOut(call1, 1), b2.opts().WithName("I")); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -1925,19 +2090,24 @@ TEST(EncapsulateSubgraphsTest, { GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately); Node* key_constant = KeyPlaceholder("F1", shape1.opts()); - Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", - {DT_FLOAT}, shape1.opts()); + Node* recv2 = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT}, + shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* e = Unary(ops::NodeOut(recv2, 0), shape1.opts() .WithName("E") .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O1")); - SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, shape1.opts()); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); TF_EXPECT_OK( AddGraphDefToFunctionLibrary(shape1, "F1_O1", &library_expected)); } + NameAttrList shape_inference_graph; + shape_inference_graph.set_name("_outside_compilation_shape_inference_F1_O1"); *library_expected.add_function() = FunctionDefHelper::Create( - "F1", {"a_0_arg:float", "b_0_arg:float"}, {"h_0_retval_retval:float"}, {}, + "F1", {"a_0_arg:float", "b_0_arg:float"}, + {"e_0_retval_retval:float", "h_0_retval_retval:float"}, {}, { {{"C"}, "UnaryTest", {"a_0_arg"}}, {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, @@ -1945,6 +2115,18 @@ TEST(EncapsulateSubgraphsTest, "UnaryTest", {"outside_compilation_O1_host_compute:outputs:0"}}, {{"H"}, "UnaryTest", {"F:o:0"}}, + {{"outside_compilation_O2_host_compute"}, + "XlaHostCompute", + {"a_0_arg"}, + {{"Tinputs", absl::Span({DT_FLOAT})}, + {"Toutputs", absl::Span({})}, + {"ancestors", absl::Span({})}, + {"key", "host_compute_channel_F1_O2"}, + {"shape_inference_graph", NameAttrList()}, + {"shapes", absl::Span({})}, + {"_outside_compilation_subgraph", "O2"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}}, {{"outside_compilation_O1_host_compute"}, 
"XlaHostCompute", {"D:o:0"}, @@ -1952,12 +2134,14 @@ TEST(EncapsulateSubgraphsTest, {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", - "_outside_compilation_shape_inference_F1_O1"}, + {"shape_inference_graph", shape_inference_graph}, {"shapes", absl::Span({})}, - {"_outside_compilation_subgraph", "O1"}}}, + {"_outside_compilation_subgraph", "O1"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}}, }, - {{"h_0_retval_retval", "H:o:0"}}); + {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"}, + {"h_0_retval_retval", "H:o:0"}}); { std::unique_ptr lib_def( @@ -1968,27 +2152,33 @@ TEST(EncapsulateSubgraphsTest, Node* key_constant = KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); - Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", - {DT_FLOAT}, b2.opts()); - Node* e = Unary(recv, b2.opts() - .WithName("E") - .WithAttr("_encapsulate", "F1") - .WithAttr("_outside", "O1")); + Node* recv1 = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* e = Unary(recv1, b2.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); Node* send = - SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, b2.opts()); - /*Node* g =*/Unary(a, b2.opts() - .WithName("G") - .WithAttr("_encapsulate", "F1") - .WithAttr("_outside", "O2") - .WithControlInput(e)); - Node* s1 = Sequencer( - b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}), - "F1"); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* recv2 = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", {DT_FLOAT}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + /*Node* g =*/Unary(recv2, b2.opts() + .WithName("G") + .WithAttr("_encapsulate", "F1") + 
.WithAttr("_outside", "O2") + .WithControlInput(e)); + Node* s1 = Sequencer(b2.opts() + .WithName("F1_sequencer") + .WithControlInputs({recv1, recv2, send}), + "F1"); NodeBuilder node_builder1("F1", "F1", lib_def.get()); node_builder1.Input(a).Input(b).ControlInput(s1); Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); - Binary(e, call1, b2.opts().WithName("I")); + Binary(call1, ops::NodeOut(call1, 1), b2.opts().WithName("I")); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -2039,19 +2229,24 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) { { GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately); Node* key_constant = KeyPlaceholder("F1", shape1.opts()); - Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", - {DT_FLOAT}, shape1.opts()); + Node* recv2 = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT}, + shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* e = Unary(ops::NodeOut(recv2, 0), shape1.opts() .WithName("E") .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O1")); - SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, shape1.opts()); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); TF_EXPECT_OK( AddGraphDefToFunctionLibrary(shape1, "F1_O1", &library_expected)); } + NameAttrList shape_inference_graph; + shape_inference_graph.set_name("_outside_compilation_shape_inference_F1_O1"); *library_expected.add_function() = FunctionDefHelper::Create( - "F1", {"a_0_arg:float", "b_0_arg:float"}, {"h_0_retval_retval:float"}, {}, + "F1", {"a_0_arg:float", "b_0_arg:float"}, + {"e_0_retval_retval:float", "h_0_retval_retval:float"}, {}, {{{"C"}, "UnaryTest", {"a_0_arg"}}, {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, {{"F"}, "UnaryTest", {"outside_compilation_O1_host_compute:outputs:0"}}, @@ -2063,10 +2258,11 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) { {"Toutputs", 
absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", - "_outside_compilation_shape_inference_F1_O1"}, + {"shape_inference_graph", shape_inference_graph}, {"shapes", absl::Span({})}, - {"_outside_compilation_subgraph", "O1"}}}, + {"_outside_compilation_subgraph", "O1"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}}, {{"outside_compilation_O2_host_compute"}, "XlaHostCompute", {"D:o:0"}, @@ -2074,9 +2270,11 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) { {"Toutputs", absl::Span({})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O2"}, - {"shape_inference_graph", ""}, + {"shape_inference_graph", NameAttrList()}, {"shapes", absl::Span({})}, - {"_outside_compilation_subgraph", "O2"}}, + {"_outside_compilation_subgraph", "O2"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}, {}}, {{"outside_compilation_O3_host_compute"}, "XlaHostCompute", @@ -2085,11 +2283,14 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) { {"Toutputs", absl::Span({})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O3"}, - {"shape_inference_graph", ""}, + {"shape_inference_graph", NameAttrList()}, {"shapes", absl::Span({})}, - {"_outside_compilation_subgraph", "O3"}}, + {"_outside_compilation_subgraph", "O3"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}, {}}}, - {{"h_0_retval_retval", "H:o:0"}}); + {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"}, + {"h_0_retval_retval", "H:o:0"}}); { std::unique_ptr lib_def( @@ -2100,23 +2301,27 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) { Node* key_constant = KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); - Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", - {DT_FLOAT}, b2.opts()); + Node* recv1 = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT}, 
+ b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* e = Unary(recv1, b2.opts() .WithName("E") .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O1")); Node* send = - SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, b2.opts()); - Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", - {DT_FLOAT}, b2.opts()); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* recv2 = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", {DT_FLOAT}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); Node* g = Unary(recv2, b2.opts() .WithName("G") .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O2") .WithControlInput(e)); - Node* recv3 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O3", - {DT_FLOAT}, b2.opts()); + Node* recv3 = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O3", {DT_FLOAT}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); /*Node* i =*/Binary(recv3, e, b2.opts() .WithName("I") @@ -2131,7 +2336,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) { node_builder1.Input(a).Input(b).ControlInput(s1); Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); - Binary(e, call1, b2.opts().WithName("J")); + Binary(call1, ops::NodeOut(call1, 1), b2.opts().WithName("J")); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -2167,14 +2372,46 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputsOrOutputs) { FunctionDefLibrary library_expected; GraphDef graphdef_expected; + { + GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately); + Node* key_constant = KeyPlaceholder("F1", shape1.opts()); + Node* recv2 = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT}, + shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* e = Unary(ops::NodeOut(recv2, 0), shape1.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + 
SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + TF_EXPECT_OK( + AddGraphDefToFunctionLibrary(shape1, "F1_O1", &library_expected)); + } + + NameAttrList shape_inference_graph; + shape_inference_graph.set_name("_outside_compilation_shape_inference_F1_O1"); *library_expected.add_function() = FunctionDefHelper::Create( - "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval_retval:float"}, {}, + "F1", {"a_0_arg:float", "b_0_arg:float"}, + {"e_0_retval_retval:float", "f_0_retval_retval:float"}, {}, { {{"C"}, "UnaryTest", {"a_0_arg"}}, {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, {{"F"}, "UnaryTest", {"D:o:0"}}, + {{"outside_compilation_O1_host_compute"}, + "XlaHostCompute", + {"a_0_arg"}, + {{"Tinputs", absl::Span({DT_FLOAT})}, + {"Toutputs", absl::Span({DT_FLOAT})}, + {"ancestors", absl::Span({})}, + {"key", "host_compute_channel_F1_O1"}, + {"shape_inference_graph", shape_inference_graph}, + {"shapes", absl::Span({})}, + {"_outside_compilation_subgraph", "O1"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}}, }, - {{"f_0_retval_retval", "F:o:0"}}); + {{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"}, + {"f_0_retval_retval", "F:o:0"}}); { std::unique_ptr lib_def( @@ -2183,15 +2420,26 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputsOrOutputs) { Node* a = Input(b2.opts().WithName("A")); Node* b = Input(b2.opts().WithName("B")); - Node* e = Unary(a, b2.opts() - .WithName("E") - .WithAttr("_encapsulate", "F1") - .WithAttr("_outside", "O1")); + Node* key_constant = + KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); + Node* recv = + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* e = Unary(recv, b2.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* send = + SendFromHost(ops::NodeOut(key_constant, 0), 
"F1", "O1", {e}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* s = Sequencer( + b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}), + "F1"); NodeBuilder node_builder1("F1", "F1", lib_def.get()); - node_builder1.Input(a).Input(b); + node_builder1.Input(a).Input(b).ControlInput(s); Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); - Binary(e, call1, b2.opts().WithName("G")); + Binary(call1, ops::NodeOut(call1, 1), b2.opts().WithName("G")); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -2236,20 +2484,22 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) { { GraphDefBuilder shape(GraphDefBuilder::kFailImmediately); Node* key_constant = KeyPlaceholder("F1", shape.opts()); - Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", - {DT_FLOAT}, shape.opts()); - Node* a = InputShaped(shape.opts().WithName("A")); - Node* c = Unary(a, shape.opts().WithName("C")); - Node* e = BinaryUnknownShape(c, recv, + Node* recv = RecvAtHost( + ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT, DT_FLOAT}, + shape.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* e = BinaryUnknownShape(recv, ops::NodeOut(recv, 1), shape.opts() .WithName("E") .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O1")); - SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, shape.opts()); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + shape.opts().WithAttr(kXlaHasHostTransferAttrName, true)); TF_EXPECT_OK( AddGraphDefToFunctionLibrary(shape, "F1_O1", &library_expected)); } + NameAttrList shape_inference_graph; + shape_inference_graph.set_name("_outside_compilation_shape_inference_F1_O1"); *library_expected.add_function() = test::function::XTimesTwo(); *library_expected.add_function() = FunctionDefHelper::Create( "F1", {"b_0_arg:float", "c_0_arg:float"}, {"f_0_retval_retval:float"}, {}, @@ -2262,15 +2512,16 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) { 
{"outside_compilation_O1_host_compute"}}, {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", - {"c:o:0"}, - {{"Tinputs", absl::Span({DT_FLOAT})}, + {"c_0_arg", "c:o:0"}, + {{"Tinputs", absl::Span({DT_FLOAT, DT_FLOAT})}, {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", - "_outside_compilation_shape_inference_F1_O1"}, + {"shape_inference_graph", shape_inference_graph}, {"shapes", absl::Span({})}, - {"_outside_compilation_subgraph", "O1"}}, + {"_outside_compilation_subgraph", "O1"}, + {"_xla_token_input_nodes", + absl::Span({"_xla_token_arg_node"})}}, {"c"}}, }, {{"f_0_retval_retval", "F:o:0"}}); @@ -2285,16 +2536,18 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) { Node* key_constant = KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); - Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", - {DT_FLOAT}, b2.opts()); - Node* e = BinaryUnknownShape(c, ops::NodeOut(recv, 0), + Node* recv = RecvAtHost( + ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT, DT_FLOAT}, + b2.opts().WithAttr(kXlaHasHostTransferAttrName, true)); + Node* e = BinaryUnknownShape(recv, ops::NodeOut(recv, 1), b2.opts() .WithName("E") - .WithControlInputs({recv, b}) + .WithControlInputs({recv}) .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O1")); Node* send = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, - b2.opts().WithControlInput(e)); + b2.opts().WithControlInput(e).WithAttr( + kXlaHasHostTransferAttrName, true)); Node* s = Sequencer( b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}), @@ -2303,9 +2556,9 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) { NodeBuilder node_builder("F1", "F1", lib_def.get()); node_builder.Input(b).Input(c); Node* call = - b2.opts().WithControlInputs({s, c}).FinalizeBuilder(&node_builder); + b2.opts().WithControlInputs({s, b, c}).FinalizeBuilder(&node_builder); - 
Binary(a, call, b2.opts().WithName("G").WithControlInputs({e})); + Binary(a, call, b2.opts().WithName("G").WithControlInputs({call})); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } diff --git a/tensorflow/compiler/jit/encapsulate_util.cc b/tensorflow/compiler/jit/encapsulate_util.cc index 1f4b9c90a4ff0b1166cdb7b5942771b350740ef3..2264806d6bdabd9f26d9f83b681524399f996317 100644 --- a/tensorflow/compiler/jit/encapsulate_util.cc +++ b/tensorflow/compiler/jit/encapsulate_util.cc @@ -62,517 +62,6 @@ void ReplaceAttr(Node* n, const string& attr_name, const T& value) { n->AddAttr(attr_name, value); } -// Step 1a ~ 1d for PreprocessForEncapsulation(). See comments of -// PreprocessForEncapsulation() for details. -Status ProcessControlEdges(Graph* g, const string& xla_computation_attr_name, - const string& outside_compilation_attr_name) { - // Gather edges to remove. We should not remove the edge while iterating. - std::vector edges_to_remove; - for (const Edge* e : g->edges()) { - if (!e->IsControlEdge()) { - continue; - } - - auto src_xla_computation = - GetStringAttr(*e->src(), xla_computation_attr_name); - auto dst_xla_computation = - GetStringAttr(*e->dst(), xla_computation_attr_name); - auto src_outside_compilation = - GetStringAttr(*e->src(), outside_compilation_attr_name); - auto dst_outside_compilation = - GetStringAttr(*e->dst(), outside_compilation_attr_name); - - if (!src_xla_computation && !dst_xla_computation) { - continue; - } else if (src_xla_computation && !dst_xla_computation) { - if (src_outside_compilation) { - // Case 1c: outside compilation to host computation control edge. - edges_to_remove.push_back(e); - - TF_RETURN_IF_ERROR(AppendToListAttr( - e->dst(), kXlaControlDependenciesAttrName, e->src()->name())); - } - } else if (!src_xla_computation && dst_xla_computation) { - if (dst_outside_compilation) { - // Case 1c: host computation control to outside compilation edge. 
- edges_to_remove.push_back(e); - - TF_RETURN_IF_ERROR(AppendToListAttr( - e->dst(), kXlaControlDependenciesAttrName, e->src()->name())); - } - } else { // src_xla_computation && dst_xla_computation - if (*src_xla_computation != *dst_xla_computation) { - if (src_outside_compilation && dst_outside_compilation) { - // Case 1b: outside compilation to outside compilation control edge. - edges_to_remove.push_back(e); - - TF_RETURN_IF_ERROR(AppendToListAttr( - e->dst(), kXlaControlDependenciesAttrName, e->src()->name())); - } else if (src_outside_compilation && !dst_outside_compilation) { - // Case 1a: outside compilation to another XLA computaition control - // edge. - TF_RETURN_IF_ERROR(AppendToListAttr( - e->src(), kXlaConnectedToOtherXlaComputationAttrName, - *dst_xla_computation)); - } else if (!src_outside_compilation && dst_outside_compilation) { - // Case 1a: another XLA computaition to outside compilation control - // edge. - TF_RETURN_IF_ERROR(AppendToListAttr( - e->dst(), kXlaConnectedFromOtherXlaComputationAttrName, - *src_xla_computation)); - } - } - } - } - - for (auto e : edges_to_remove) { - g->RemoveEdge(e); - } - return Status::OK(); -} - -// Step 2 for PreprocessForEncapsulation(). See comments of -// PreprocessForEncapsulation() for details. -Status ProcessXlaToXlaDataEdges(Graph* g, - const string& xla_computation_attr_name, - const string& outside_compilation_attr_name) { - // Gather edges between XLA computations. Notice that we do not store `Edge*` - // directly because we remove some nodes while adding Identity nodes, and - // those Edge pointers might be invalidated. 
- struct EdgeInfo { - int dst_input, dst_node_id; - }; - std::vector edges; - for (const Edge* e : g->edges()) { - if (e->IsControlEdge()) { - continue; - } - - auto src_xla_computation = - GetStringAttr(*e->src(), xla_computation_attr_name); - auto dst_xla_computation = - GetStringAttr(*e->dst(), xla_computation_attr_name); - auto src_outside_compilation = - GetStringAttr(*e->src(), outside_compilation_attr_name); - auto dst_outside_compilation = - GetStringAttr(*e->dst(), outside_compilation_attr_name); - if (!src_xla_computation || !dst_xla_computation) { - continue; - } - - if (*src_xla_computation != *dst_xla_computation) { - if (src_outside_compilation || dst_outside_compilation) { - edges.push_back(EdgeInfo{e->dst_input(), e->dst()->id()}); - VLOG(4) << "XLA -> XLA edge: " << e->DebugString(); - } - } - } - - // For each XLA -> XLA edge, add an Identity node between src and dst. - for (int i = 0; i < edges.size(); i++) { - Node* dst = g->FindNodeId(edges[i].dst_node_id); - const Edge* e; - TF_RETURN_IF_ERROR(dst->input_edge(edges[i].dst_input, &e)); - Node* src = e->src(); - int src_output = e->src_output(), dst_input = e->dst_input(); - g->RemoveEdge(e); - - // Create Identity node, and connect it between `src` and `dst`. - string identity_node_name = - absl::StrCat("bridge_", src->name(), "_", dst->name()); - DataType dtype = src->output_type(src_output); - TF_ASSIGN_OR_RETURN(Node * identity_node, - BuildIdentityNode(g, identity_node_name, dtype, src, - /*requested_device=*/absl::nullopt)); - identity_node->AddAttr(kBridgeSourceNodeAttrName, src->name()); - g->AddEdge(src, src_output, identity_node, 0); - g->AddEdge(identity_node, 0, dst, dst_input); - - // Replace `e->dst()` because its input node changed. - NodeDef new_def = dst->def(); - *new_def.mutable_input(dst_input) = identity_node->name(); - TF_ASSIGN_OR_RETURN(Node * dst_replace_node, ReplaceNode(g, dst, new_def)); - - // Other edge in `edges` might have `e->dst()` as src or dst - // node. 
Before removing `e->dst()`, replace those edges with corresponding - // edges for `dst_replace_node`. - for (int j = i + 1; j < edges.size(); j++) { - if (edges[j].dst_node_id == edges[i].dst_node_id) { - edges[j].dst_node_id = dst_replace_node->id(); - } - } - } - return Status::OK(); -} - -// Step 3 for PreprocessForEncapsulation(). See comments of -// PreprocessForEncapsulation() for details. -Status ProcessDataEdgeBetweenOutsideCompilationAndHostComputation( - Graph* g, const string& xla_computation_attr_name, - const string& outside_compilation_attr_name) { - // Gather edges between outside compilation and host computation. Notice that - // we do not store `Edge*` directly because we remove some nodes while adding - // Identity nodes, and those Edge pointers might be invalidated. - struct EdgeInfo { - int dst_input, dst_node_id; - bool is_host_to_outside_compilation; - }; - std::vector edges; - for (const Edge* e : g->edges()) { - if (e->IsControlEdge()) { - continue; - } - - if (e->src()->attrs().Find(xla_computation_attr_name) == nullptr && - e->dst()->attrs().Find(xla_computation_attr_name) != nullptr && - e->dst()->attrs().Find(outside_compilation_attr_name) != nullptr) { - edges.push_back(EdgeInfo{e->dst_input(), e->dst()->id(), - /*is_host_to_outside_compilation=*/true}); - VLOG(4) << "Host -> oc edge: " << e->DebugString(); - } else if (e->dst()->attrs().Find(xla_computation_attr_name) == nullptr && - e->src()->attrs().Find(xla_computation_attr_name) != nullptr && - e->src()->attrs().Find(outside_compilation_attr_name) != - nullptr) { - edges.push_back(EdgeInfo{e->dst_input(), e->dst()->id(), - /*is_host_to_outside_compilation=*/false}); - VLOG(4) << "Oc -> host edge: " << e->DebugString(); - } - } - - // Remove the edge from host to outside compilation. Add a placeholder as - // outside compilation node input. 
- std::map, Node*> placeholders; - for (int i = 0; i < edges.size(); i++) { - Node* dst = g->FindNodeId(edges[i].dst_node_id); - const Edge* e; - TF_RETURN_IF_ERROR(dst->input_edge(edges[i].dst_input, &e)); - Node* src = e->src(); - int src_output = e->src_output(), dst_input = e->dst_input(); - g->RemoveEdge(e); - - // Find or create placeholder node. - string new_name = - edges[i].is_host_to_outside_compilation - ? absl::StrCat(src->name(), "_host_to_oc_placeholder_", src_output) - : absl::StrCat(src->name(), "_oc_to_host_placeholder_", src_output); - auto placeholder_index = std::make_pair(src->name(), src_output); - auto iter = placeholders.find(placeholder_index); - Node* placeholder_node; - if (iter == placeholders.end()) { - NodeDefBuilder placeholder_builder(new_name, "Placeholder"); - placeholder_builder.Attr("dtype", src->output_type(src_output)); - if (edges[i].is_host_to_outside_compilation) { - placeholder_builder.Attr(kHostToOutsideCompilationOriginalNodeAttrName, - src->name()); - placeholder_builder.Attr(kHostToOutsideCompilationSrcOutputAttrName, - src_output); - // If this placeholder node is in outside compilation, we need to set - // `xla_computation_attr_name` and `outside_compilation_attr_name`. 
- string xla_computation_attr, outside_compilation_attr; - TF_RETURN_IF_ERROR(GetNodeAttr(dst->attrs(), xla_computation_attr_name, - &xla_computation_attr)); - TF_RETURN_IF_ERROR(GetNodeAttr(dst->attrs(), - outside_compilation_attr_name, - &outside_compilation_attr)); - placeholder_builder.Attr(xla_computation_attr_name, - xla_computation_attr); - placeholder_builder.Attr(outside_compilation_attr_name, - outside_compilation_attr); - } else { - placeholder_builder.Attr(kOutsideCompilationToHostOriginalNodeAttrName, - src->name()); - placeholder_builder.Attr(kOutsideCompilationToHostSrcOutputAttrName, - src_output); - } - NodeDef placeholder_def; - TF_RETURN_IF_ERROR(placeholder_builder.Finalize(&placeholder_def)); - Status s; - placeholder_node = g->AddNode(placeholder_def, &s); - TF_RETURN_IF_ERROR(s); - placeholders[placeholder_index] = placeholder_node; - } else { - placeholder_node = iter->second; - } - g->AddEdge(placeholder_node, 0, dst, dst_input); - - // Replace `e->dst()` because its input node changed. - NodeDef new_def = dst->def(); - *new_def.mutable_input(dst_input) = placeholder_node->name(); - TF_ASSIGN_OR_RETURN(Node * dst_replace_node, ReplaceNode(g, dst, new_def)); - - // Other edge in `edges` might have `e->dst()` as src or dst - // node. Before removing `e->dst()`, replace those edges with corresponding - // edges for `dst_replace_node`. - for (int j = i + 1; j < edges.size(); j++) { - if (edges[j].dst_node_id == edges[i].dst_node_id) { - edges[j].dst_node_id = dst_replace_node->id(); - } - } - } - return Status::OK(); -} - -// Step 1 for `PostprocessForEncapsulation`. See comments of -// `PostprocessForEncapsulation` for details. -Status RemovePlaceholderBetweenOutsideCompilationAndHostComputation(Graph* g) { - // Gather all outside compilation to host computation nodes. 
- struct PlaceHolderNodeInfo { - Node* n; - bool is_host_to_oc; - }; - std::vector placeholder_nodes; - for (Node* n : g->nodes()) { - if (n->type_string() == "Placeholder") { - if (HasNodeAttr(n->def(), - kOutsideCompilationToHostOriginalNodeAttrName)) { - placeholder_nodes.push_back({n, false}); - } else if (HasNodeAttr(n->def(), - kHostToOutsideCompilationOriginalNodeAttrName)) { - placeholder_nodes.push_back({n, true}); - } - } - } - - // Remove the placeholder nodes, and reconnect original edge. - auto node_name_index = g->BuildNodeNameIndex(); - for (auto placeholder_iter : placeholder_nodes) { - Node* n = placeholder_iter.n; - - string node_name; - int node_src_output; - if (placeholder_iter.is_host_to_oc) { - TF_RETURN_IF_ERROR( - GetNodeAttr(n->attrs(), kHostToOutsideCompilationOriginalNodeAttrName, - &node_name)); - TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), - kHostToOutsideCompilationSrcOutputAttrName, - &node_src_output)); - } else { - TF_RETURN_IF_ERROR( - GetNodeAttr(n->attrs(), kOutsideCompilationToHostOriginalNodeAttrName, - &node_name)); - TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), - kOutsideCompilationToHostSrcOutputAttrName, - &node_src_output)); - } - auto iter = node_name_index.find(node_name); - if (iter == node_name_index.end()) { - return errors::Internal( - "Cannot find original node for oc -> host placeholder node ", - node_name); - } - - // Change all usage node to use the original node instead. 
- Node* original_node = iter->second; - std::vector control_edges; - std::vector data_edges; - for (auto e : n->out_edges()) { - if (e->IsControlEdge()) { - control_edges.push_back(e); - } else { - data_edges.push_back({e->dst(), e->src_output(), e->dst_input()}); - } - } - for (const Edge* e : control_edges) { - g->AddControlEdge(original_node, e->dst()); - g->RemoveEdge(e); - } - for (int i = 0; i < data_edges.size(); i++) { - Node* dst = data_edges[i].dst; - NodeDef new_def = dst->def(); - int dst_input = data_edges[i].dst_input; - *new_def.mutable_input(dst_input) = - absl::StrCat(original_node->name(), ":", node_src_output); - TF_ASSIGN_OR_RETURN(Node * replace_node, ReplaceNode(g, dst, new_def)); - - const Edge* edge_to_replace = nullptr; - TF_RETURN_IF_ERROR(replace_node->input_edge(dst_input, &edge_to_replace)); - g->RemoveEdge(edge_to_replace); - g->AddEdge(original_node, node_src_output, replace_node, dst_input); - - // Other edges might have `dst` as dst node. Update those edges with - // `replace_node`. - for (int j = i + 1; j < data_edges.size(); j++) { - if (data_edges[j].dst == dst) { - data_edges[j].dst = replace_node; - } - } - - // Other placeholder node might have `dst` as original node. Update - // `node_name_index` with `replace_node`. - node_name_index[replace_node->name()] = replace_node; - } - - // Remove placeholder node. - g->RemoveNode(n); - } - return Status::OK(); -} - -// Step 2 for `PostprocessForEncapsulation`. See comments of -// `PostprocessForEncapsulation` for details. -Status RemoveIdentityBetweenDifferentXlaComputation(Graph* g) { - // Gather Identity nodes to remove. - std::vector bridge_nodes; - for (Node* n : g->nodes()) { - if (n->type_string() == "Identity" && - HasNodeAttr(n->def(), kBridgeSourceNodeAttrName)) { - bridge_nodes.push_back(n); - } - } - - // Remove the identity nodes, and reconnect the original edge. 
- for (int i = 0; i < bridge_nodes.size(); i++) { - Node* n = bridge_nodes[i]; - const Edge* src_edge = nullptr; - TF_RETURN_IF_ERROR(n->input_edge(0, &src_edge)); - - // Change all usage node to use the original node instead. - std::vector control_edges; - std::vector data_edges; - for (auto e : n->out_edges()) { - if (e->IsControlEdge()) { - control_edges.push_back(e); - } else { - data_edges.push_back({e->dst(), e->src_output(), e->dst_input()}); - } - } - for (const Edge* e : control_edges) { - g->AddControlEdge(src_edge->src(), e->dst()); - g->RemoveEdge(e); - } - for (int j = 0; j < data_edges.size(); j++) { - Node* dst = data_edges[j].dst; - NodeDef new_def = dst->def(); - int dst_input = data_edges[j].dst_input; - *new_def.mutable_input(dst_input) = - absl::StrCat(src_edge->src()->name(), ":", src_edge->src_output()); - TF_ASSIGN_OR_RETURN(Node * replace_node, ReplaceNode(g, dst, new_def)); - - const Edge* edge_to_replace = nullptr; - TF_RETURN_IF_ERROR(replace_node->input_edge(dst_input, &edge_to_replace)); - g->RemoveEdge(edge_to_replace); - g->AddEdge(src_edge->src(), src_edge->src_output(), replace_node, - dst_input); - - // Other edges might have `dst` as dst node. Update those edges with - // `replace_node`. - for (int k = j + 1; k < data_edges.size(); k++) { - if (data_edges[k].dst == dst) { - data_edges[k].dst = replace_node; - } - } - - // The node we replaced might be in `bridge_nodes`. If so, update - // `bridge_nodes` to use the replaced node. - for (int k = i + 1; k < bridge_nodes.size(); k++) { - if (bridge_nodes[k] == dst) { - bridge_nodes[k] = replace_node; - } - } - } - - // Remove Identity node. - g->RemoveNode(n); - } - return Status::OK(); -} - -// Step 3 for `PostprocessForEncapsulation`. See comments of -// `PostprocessForEncapsulation` for details. -// We do not need to worry about removed nodes in step 1 and 2; -// `PreprocessForEncapsulation` will not record control dependencies for those -// remvoed nodes in the first place. 
-Status AddControlDependencies( - Graph* g, const std::unordered_map& cluster_node_names) { - auto node_name_index = g->BuildNodeNameIndex(); - - // Reconnect outside compilation to outside compilation control edge. - for (Node* n : g->nodes()) { - std::vector control_deps; - Status s = - GetNodeAttr(n->attrs(), kXlaControlDependenciesAttrName, &control_deps); - if (!s.ok()) { - if (s.code() != error::NOT_FOUND) { - return s; - } else { - continue; - } - } else { - n->ClearAttr(kXlaControlDependenciesAttrName); - for (const string& control_input : control_deps) { - auto iter = node_name_index.find(control_input); - if (iter == node_name_index.end()) { - return errors::Internal("Cannot find original node for ", - control_input); - } - g->AddControlEdge(iter->second, n); - } - } - } - - // Reconnect outside compilation to XLA computation control edge. - for (Node* n : g->nodes()) { - std::vector control_deps; - Status s = GetNodeAttr( - n->attrs(), kXlaConnectedToOtherXlaComputationAttrName, &control_deps); - if (!s.ok()) { - if (s.code() != error::NOT_FOUND) { - return s; - } else { - continue; - } - } else { - n->ClearAttr(kXlaConnectedToOtherXlaComputationAttrName); - for (const string& control_input : control_deps) { - auto iter = cluster_node_names.find(control_input); - if (iter == cluster_node_names.end()) { - return errors::Internal("Cannot find cluster node for ", - control_input); - } - auto iter2 = node_name_index.find(iter->second); - if (iter2 == node_name_index.end()) { - return errors::Internal("Cannot find cluster node for ", - iter->second); - } - g->AddControlEdge(n, iter2->second); - } - } - } - - // Reconnect XLA computation to outside compilation control edge. 
- for (Node* n : g->nodes()) { - std::vector control_deps; - Status s = - GetNodeAttr(n->attrs(), kXlaConnectedFromOtherXlaComputationAttrName, - &control_deps); - if (!s.ok()) { - if (s.code() != error::NOT_FOUND) { - return s; - } else { - continue; - } - } else { - n->ClearAttr(kXlaConnectedFromOtherXlaComputationAttrName); - for (const string& control_input : control_deps) { - auto iter = cluster_node_names.find(control_input); - if (iter == cluster_node_names.end()) { - return errors::Internal("Cannot find cluster node for ", - control_input); - } - auto iter2 = node_name_index.find(iter->second); - if (iter2 == node_name_index.end()) { - return errors::Internal("Cannot find cluster node for ", - iter->second); - } - g->AddControlEdge(iter2->second, n); - } - } - } - - return Status::OK(); -} - // Step 1 for `PreprocessEdgesBetweenOutsideCompilations`. See comments of // `PreprocessEdgesBetweenOutsideCompilations` for details. Status PreprocessControlEdgesBetweenOutsideCompilations( @@ -811,20 +300,6 @@ Status PostprocessControlEdgesBetweenOutsideCompilations( const char kXlaInferredShapesAttrName[] = "_xla_inferred_shapes"; -const char kXlaConnectedToOtherXlaComputationAttrName[] = - "_xla_connected_to_other_xla_computation"; -const char kXlaConnectedFromOtherXlaComputationAttrName[] = - "_xla_connected_from_other_xla_computation"; -const char kXlaControlDependenciesAttrName[] = "_xla_control_dependencies"; -const char kBridgeSourceNodeAttrName[] = "_xla_bridge_src"; -const char kOutsideCompilationToHostOriginalNodeAttrName[] = - "_xla_oc_to_host_node_name"; -const char kOutsideCompilationToHostSrcOutputAttrName[] = - "_xla_oc_to_host_src_output"; -const char kHostToOutsideCompilationOriginalNodeAttrName[] = - "_xla_host_to_oc_node_name"; -const char kHostToOutsideCompilationSrcOutputAttrName[] = - "_xla_host_to_oc_src_output"; const char kXlaConnectedToXlaComputationAttrName[] = "_xla_connected_to_xla_computation"; const char 
kXlaConnectedFromXlaComputationAttrName[] = @@ -835,32 +310,7 @@ const char kOutsideCompilationSrcOutputAttrName[] = "_xla_oc_to_oc_src_output"; const char kXlaControlDependenciesWithinXlaClusterAttrName[] = "_xla_control_dependencies_within_xla_cluster"; -Status PerformStaticShapeInferenceBeforeEncapsulation( - Graph* g, const string& xla_computation_attr_name, - const string& outside_compilation_attr_name) { - // Find all outside compilation to XLA computation data edges. - std::unordered_set outside_compilation_send_nodes; - for (auto e : g->edges()) { - if (e->IsControlEdge()) { - continue; - } - - auto src_computation = GetStringAttr(*e->src(), xla_computation_attr_name); - auto dst_computation = GetStringAttr(*e->dst(), xla_computation_attr_name); - if (!src_computation || !dst_computation || - *src_computation != *dst_computation) { - continue; - } - - auto src_outside_compilation = - GetStringAttr(*e->src(), outside_compilation_attr_name); - auto dst_outside_compilation = - GetStringAttr(*e->dst(), outside_compilation_attr_name); - if (src_outside_compilation && !dst_outside_compilation) { - outside_compilation_send_nodes.insert(e->src()); - } - } - +Status PerformStaticShapeInferenceBeforeEncapsulation(Graph* g) { // Perform shape inference. std::map arg_shapes; GraphShapeInfo shape_info; @@ -868,55 +318,21 @@ Status PerformStaticShapeInferenceBeforeEncapsulation( InferShapes(g, arg_shapes, /*fnlib_def=*/nullptr, &shape_info)); // Add attribute for output shapes. 
- for (Node* n : outside_compilation_send_nodes) { - auto iter = shape_info.find(n->name()); - if (iter == shape_info.end()) { - continue; - } - + auto node_name_index = g->BuildNodeNameIndex(); + for (auto iter : shape_info) { std::vector output_shapes; - std::transform(iter->second.begin(), iter->second.end(), + std::transform(iter.second.begin(), iter.second.end(), std::back_inserter(output_shapes), [](const InferredShape& inferred_shape) { return inferred_shape.shape; }); + Node* n = node_name_index[iter.first]; n->AddAttr(kXlaInferredShapesAttrName, output_shapes); } return Status::OK(); } -Status PreprocessForEncapsulation(Graph* g, - const string& xla_computation_attr_name, - const string& outside_compilation_attr_name) { - TF_RETURN_IF_ERROR(ProcessControlEdges(g, xla_computation_attr_name, - outside_compilation_attr_name)); - TF_RETURN_IF_ERROR(ProcessXlaToXlaDataEdges(g, xla_computation_attr_name, - outside_compilation_attr_name)); - TF_RETURN_IF_ERROR(ProcessDataEdgeBetweenOutsideCompilationAndHostComputation( - g, xla_computation_attr_name, outside_compilation_attr_name)); - return Status::OK(); -} - -Status PostprocessForEncapsulation( - Graph* g, const string& xla_computation_attr_name, - const string& outside_compilation_attr_name, - const std::unordered_map& clusters) { - // The `node` pointer in `XlaClusterInfo` might be invalidated in step 1/2, - // but the node name won't change. Record cluster node name for - // `AddControlDependencies`. 
- std::unordered_map cluster_node_names; - for (const auto& iter : clusters) { - cluster_node_names[iter.first] = iter.second.node->name(); - } - - TF_RETURN_IF_ERROR( - RemovePlaceholderBetweenOutsideCompilationAndHostComputation(g)); - TF_RETURN_IF_ERROR(RemoveIdentityBetweenDifferentXlaComputation(g)); - TF_RETURN_IF_ERROR(AddControlDependencies(g, cluster_node_names)); - return Status::OK(); -} - Status PreprocessEdgesBetweenOutsideCompilations( Graph* g, const string& outside_compilation_attr_name) { // Remove edges from source node to outside compilation nodes, and edges diff --git a/tensorflow/compiler/jit/encapsulate_util.h b/tensorflow/compiler/jit/encapsulate_util.h index e363bc5754ac395bae262dc67a780a0173efaf5e..c9f16d14168163e11bb19092f566f1de8724aca3 100644 --- a/tensorflow/compiler/jit/encapsulate_util.h +++ b/tensorflow/compiler/jit/encapsulate_util.h @@ -27,51 +27,13 @@ namespace tensorflow { // a list of PartialTensorShape objects. extern const char kXlaInferredShapesAttrName[]; -// Infer output shapes for outside compilation nodes which have output data -// edges to XLA computation nodes. These shapes will be used later by XLA -// compiler as output shapes of the outside compilation's XlaHostCompute op. -// XLA computation nodes will be mark by attr `xla_computation_attr_name`; -// outside compilation nodes will be marked by both attr -// `xla_computation_attr_name` and `outside_compilation_attr_name`. -// -// Those outside compilation nodes will be marked with attribute -// `kXlaInferredShapesAttrName`. +// Infers output shapes for all nodes in graph `g`. The output shapes will be +// stored in node attribute `kXlaInferredShapesAttrName`. // // We have to perform shape inference before encapsulation because after // encapsulation, some nodes will be encapsulated into function call, and shape // inference does not handle function call at the moment. 
-Status PerformStaticShapeInferenceBeforeEncapsulation( - Graph* g, const string& xla_computation_attr_name, - const string& outside_compilation_attr_name); - -// Attribute indicating that some ops in other XLA computation has control -// dependency on this node. Attribute value will be a list of string (XLA -// computation names). -extern const char kXlaConnectedToOtherXlaComputationAttrName[]; - -// Attribute indicating that this node has control dependency on some ops in -// other XLA computation. Attribute value will be a list of string (XLA -// computation names). -extern const char kXlaConnectedFromOtherXlaComputationAttrName[]; - -// Attribute indicating that this node has control dependencies on some other -// nodes. Attribute value will be a list of string (node names). -extern const char kXlaControlDependenciesAttrName[]; - -// Attribute indicating that this is an Identity node added to act as a bridge -// between different XLA computations. Attribute value will be string (source -// node name). -extern const char kBridgeSourceNodeAttrName[]; - -// Attribute indicating that this is an Placeholder node added to act as a -// temporary input node for an outside compilation node. Attribute value will be -// string (original input node name). -extern const char kOutsideCompilationToHostOriginalNodeAttrName[]; - -// Attribute indicating that this is an Placeholder node added to act as a -// temporary input node for an outside compilation node. Attribute value will be -// int (src_output for original edge). -extern const char kOutsideCompilationToHostSrcOutputAttrName[]; +Status PerformStaticShapeInferenceBeforeEncapsulation(Graph* g); // Attribute indicating that some ops in this node's XLA computation has control // dependency on this node. Attribute value will always be "true". @@ -81,16 +43,6 @@ extern const char kXlaConnectedToXlaComputationAttrName[]; // this node's XLA computation. Attribute value will always be "true". 
extern const char kXlaConnectedFromXlaComputationAttrName[]; -// Attribute indicating that this is an Placeholder node added to act as a -// temporary input node for an host node. Attribute value will be string -// (original input node name). -extern const char kHostToOutsideCompilationOriginalNodeAttrName[]; - -// Attribute indicating that this is an Placeholder node added to act as a -// temporary input node for a host node. Attribute value will be int (src_output -// for original edge). -extern const char kHostToOutsideCompilationSrcOutputAttrName[]; - // Attribute indicating that this is an Placeholder node added to act as a // temporary input node for an outside compilation node. Attribute value will be // string (original input node name). @@ -106,27 +58,6 @@ extern const char kOutsideCompilationSrcOutputAttrName[]; // (node names). extern const char kXlaControlDependenciesWithinXlaClusterAttrName[]; -// Preprocesses edges between different XLA clusters for encapsulation. It will -// perform the following operations in order: -// -// 1a. For control edges between outside compilation and another XLA -// computation, add attr "kXlaConnected{From, To}OtherXlaComputationAttrName -// = XLA computation node name" to the outside compilation node. -// 1b. For control edges between different outside compilations (in different -// XLA computations), remove the edge and add attr -// "kXlaControlDependenciesAttrName = src node name" to dst node. -// 1c. For control edges between outside compilation and host computation, -// remove the edge and add attr "kXlaControlDependenciesAttrName = src node -// name" to dst node. -// 2. For data edges between different XLA computations, if either src or dst -// is outside compilation, add an Identity node in between the edge. The -// identity node will have attr kBridgeSourceNodeAttrName. -// 3. For data edges between outside compilation and host computation, remove -// the edge and create a Placeholder node as dst node's input. 
-Status PreprocessForEncapsulation(Graph* g, - const string& xla_computation_attr_name, - const string& outside_compilation_attr_name); - // Information for XLA computation. struct XlaClusterInfo { // Add an explicitly-defined default constructor for this class. @@ -158,24 +89,6 @@ struct XlaClusterInfo { const std::map host_compute_core; }; -// Postprocesses edges between different XLA clusters for encapsulation. This -// function reverts what `PreprocessForEncapsulation` did. It will perform the -// following operations in order: -// -// 1. Remove Placeholder nodes between outside compilation and host computation -// (created in `PreprocessForEncapsulation` step 3). -// 2. Remove Identity nodes created in `PreprocessForEncapsulation` step 2. -// 3a. Reconnect control edges between outside compilation and another XLA -// computation (marked by `PreprocessForEncapsulation` step 1a). -// 3b. Reconnect control edges between different outside compilations (marked by -// `PreprocessForEncapsulation` step 1b). -// 3c. Reconnect control edges between outside compilation and host computation -// (marked by `PreprocessForEncapsulation` step 1c). -Status PostprocessForEncapsulation( - Graph* g, const string& xla_computation_attr_name, - const string& outside_compilation_attr_name, - const std::unordered_map& clusters); - // Preprocesses edges within the same XLA cluster. It will perform the following // operations in order: // diff --git a/tensorflow/compiler/jit/encapsulate_util_test.cc b/tensorflow/compiler/jit/encapsulate_util_test.cc index 3b8b49cb92f3e453883a8e64e12ce3748a5173f6..3bb979e0698d2d6be42ed5bae66c25267928192c 100644 --- a/tensorflow/compiler/jit/encapsulate_util_test.cc +++ b/tensorflow/compiler/jit/encapsulate_util_test.cc @@ -38,24 +38,11 @@ TEST(PerformStaticShapeInferenceBeforeEncapsulationTest, Basic) { Graph g(OpRegistry::Global()); TF_CHECK_OK(s.ToGraph(&g)); - // "add" node is outside compilation node, "identity" node is XLA node. 
- auto node_index = g.BuildNodeNameIndex(); - Node *add_node = node_index["add"], *identity_node = node_index["identity"]; - add_node->AddAttr("_xla", "cluster"); - add_node->AddAttr("_oc", "cluster"); - identity_node->AddAttr("_xla", "cluster"); - TF_CHECK_OK( - PerformStaticShapeInferenceBeforeEncapsulation(&g, "_xla", "_oc")); + TF_CHECK_OK(PerformStaticShapeInferenceBeforeEncapsulation(&g)); - // Check that only "add" node now has _xla_inferred_shapes attr. - std::vector nodes_with_inferred_shape; - for (Node *n : g.nodes()) { - if (HasNodeAttr(n->def(), kXlaInferredShapesAttrName)) { - nodes_with_inferred_shape.push_back(n); - } - } - EXPECT_EQ(nodes_with_inferred_shape.size(), 1); - EXPECT_EQ(nodes_with_inferred_shape[0], add_node); + // Check that "add" node now has _xla_inferred_shapes attr. + auto node_index = g.BuildNodeNameIndex(); + Node *add_node = node_index["add"]; std::vector output_shapes; TF_CHECK_OK(GetNodeAttr(add_node->attrs(), kXlaInferredShapesAttrName, &output_shapes)); @@ -66,329 +53,4 @@ TEST(PerformStaticShapeInferenceBeforeEncapsulationTest, Basic) { EXPECT_EQ(shape_proto.dim(0).size(), 2); } -TEST(PreprocessForEncapsulationTest, ControlEdges) { - // Build the graph: - // "const_0" and "const_1" in host computation - // "add" = "const_0" + "const_1" in XLA computation 0 - // "identity0" = "add" in XLA computation 0 & outside compilation 0 - // "identity1" = "identity0" in XLA computation 0 - // "identity2" = "identity1" in host computation - // "identity3" = "identity2" in XLA computation 1 - // "identity4" = "identity3" in XLA computation 1 & outside compilation 1 - // "identity5" = "identity4" in XLA computation 1 - // "identity6" = "identity5" in host computation - tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - Output const_0 = ops::Const(s.WithOpName("const_0"), 1, {}); - Output const_1 = ops::Const(s.WithOpName("const_1"), 2, {}); - Output add = ops::Add(s.WithOpName("add"), const_0, const_1); - Output identity0 = 
ops::Identity(s.WithOpName("identity0"), add); - Output identity1 = ops::Identity(s.WithOpName("identity1"), identity0); - Output identity2 = ops::Identity(s.WithOpName("identity2"), identity1); - Output identity3 = ops::Identity(s.WithOpName("identity3"), identity2); - Output identity4 = ops::Identity(s.WithOpName("identity4"), identity3); - Output identity5 = ops::Identity(s.WithOpName("identity5"), identity4); - Graph g(OpRegistry::Global()); - TF_CHECK_OK(s.ToGraph(&g)); - auto node_index = g.BuildNodeNameIndex(); - - // Set XLA computation/outside compilation attr, and add control edges. - Node *const0_node = node_index["const_0"], *add_node = node_index["add"], - *identity0_node = node_index["identity0"], - *identity1_node = node_index["identity1"], - *identity2_node = node_index["identity2"], - *identity3_node = node_index["identity3"], - *identity4_node = node_index["identity4"], - *identity5_node = node_index["identity5"]; - add_node->AddAttr("_xla", "0"); - identity0_node->AddAttr("_xla", "0"); - identity0_node->AddAttr("_oc", "0"); - identity1_node->AddAttr("_xla", "0"); - identity3_node->AddAttr("_xla", "1"); - identity4_node->AddAttr("_xla", "1"); - identity4_node->AddAttr("_oc", "0"); - identity5_node->AddAttr("_xla", "1"); - // Case 1a: control edges between outside compilation and another XLA - // computation. - g.AddControlEdge(identity0_node, identity3_node); - g.AddControlEdge(identity1_node, identity4_node); - // Case 1b: control edges between different outside compilations. - g.AddControlEdge(identity0_node, identity4_node); - // Case 1c: control edges between outside compilation and host computation. - g.AddControlEdge(const0_node, identity0_node); - g.AddControlEdge(identity0_node, identity2_node); - - TF_CHECK_OK(PreprocessForEncapsulation(&g, "_xla", "_oc")); - - // Case 1a: add attr "_xla_control_deps_{from/to} = XLA computation node name" - // to the outside compilation node. 
- std::vector attr; - TF_CHECK_OK(GetNodeAttr(identity0_node->def(), - kXlaConnectedToOtherXlaComputationAttrName, &attr)); - EXPECT_EQ(attr.size(), 1); - EXPECT_EQ(attr[0], "1"); - attr.clear(); - TF_CHECK_OK(GetNodeAttr(identity4_node->def(), - kXlaConnectedFromOtherXlaComputationAttrName, &attr)); - EXPECT_EQ(attr.size(), 1); - EXPECT_EQ(attr[0], "0"); - // Case 1b: add attr "_xla_control_deps = src node name" to dst node. - attr.clear(); - TF_CHECK_OK(GetNodeAttr(identity4_node->def(), - kXlaControlDependenciesAttrName, &attr)); - EXPECT_EQ(attr.size(), 1); - EXPECT_EQ(attr[0], "identity0"); - // Case 1c: add attr "_xla_control_deps = src node name" to dst node. - attr.clear(); - TF_CHECK_OK(GetNodeAttr(identity0_node->def(), - kXlaControlDependenciesAttrName, &attr)); - EXPECT_EQ(attr.size(), 1); - EXPECT_EQ(attr[0], "const_0"); - attr.clear(); - TF_CHECK_OK(GetNodeAttr(identity2_node->def(), - kXlaControlDependenciesAttrName, &attr)); - EXPECT_EQ(attr.size(), 1); - EXPECT_EQ(attr[0], "identity0"); -} - -TEST(PreprocessForEncapsulationTest, DataEdges) { - // Build the graph: - // "const_0" and "const_1" in host computation - // "identityn0" = ("const_0", "const_1") in host computation 0 - // "add0" = "const_0" + "const_1" in XLA computation 0 - // "add1" = "add0" + "const_0" in XLA computation 0 & outside compilation 0 - // "identity0" = "add1" in XLA computation 0 - // "add2" = "add1" + "identity0" in host computation - // "add3" = "add1" + "add2" in XLA computation 1 - // "add4" = "identity0" + "add2" in XLA computation 1 & outside compilation 0 - // "add5" = "identityn0"[0] + "identityn0"[1] in XLA computation 1 & - // outside compilation 0 - // "identityn1" = ("identityn0"[0], "identityn0"[1]) in XLA computation 1 & - // outside compilation 0 - // "identity1" = "add4" in XLA computation 1 - // "identity2" = "identity1" in host computation - tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - Output const_0 = ops::Const(s.WithOpName("const_0"), 1, 
{}); - Output const_1 = ops::Const(s.WithOpName("const_1"), 2, {}); - auto identityn0 = - ops::IdentityN(s.WithOpName("identityn_0"), {const_0, const_1}); - Output add0 = ops::Add(s.WithOpName("add0"), const_0, const_1); - Output add1 = ops::Add(s.WithOpName("add1"), add0, const_0); - Output identity0 = ops::Identity(s.WithOpName("identity0"), add1); - Output add2 = ops::Add(s.WithOpName("add2"), add1, identity0); - Output add3 = ops::Add(s.WithOpName("add3"), add1, add2); - Output add4 = ops::Add(s.WithOpName("add4"), identity0, add2); - Output add5 = ops::Add(s.WithOpName("add5"), identityn0[0], identityn0[1]); - auto identityn1 = ops::IdentityN(s.WithOpName("identityn_1"), - {identityn0[0], identityn0[1]}); - Output identity1 = ops::Identity(s.WithOpName("identity1"), add4); - Output identity2 = ops::Identity(s.WithOpName("identity2"), add4); - Graph g(OpRegistry::Global()); - TF_CHECK_OK(s.ToGraph(&g)); - auto node_index = g.BuildNodeNameIndex(); - - // Set XLA computation/outside compilation attr. - Node *add0_node = node_index["add0"], *add1_node = node_index["add1"], - *identity0_node = node_index["identity0"], - *add3_node = node_index["add3"], *add4_node = node_index["add4"], - *add5_node = node_index["add5"], - *identityn1_node = node_index["identityn_1"], - *identity1_node = node_index["identity1"]; - add0_node->AddAttr("_xla", "0"); - add1_node->AddAttr("_xla", "0"); - add1_node->AddAttr("_oc", "0"); - identity0_node->AddAttr("_xla", "0"); - add3_node->AddAttr("_xla", "1"); - add4_node->AddAttr("_xla", "1"); - add4_node->AddAttr("_oc", "0"); - add5_node->AddAttr("_xla", "1"); - add5_node->AddAttr("_oc", "0"); - identityn1_node->AddAttr("_xla", "1"); - identityn1_node->AddAttr("_oc", "0"); - identity1_node->AddAttr("_xla", "1"); - - TF_CHECK_OK(PreprocessForEncapsulation(&g, "_xla", "_oc")); - - // Check input nodes for related data edges. - node_index = g.BuildNodeNameIndex(); - // Step 2: add an Identity node between different XLA computations. 
- Node *bridge_add1_add3 = node_index["bridge_add1_add3"]; - EXPECT_NE(bridge_add1_add3, nullptr); - string str; - TF_CHECK_OK( - GetNodeAttr(bridge_add1_add3->attrs(), kBridgeSourceNodeAttrName, &str)); - EXPECT_EQ(str, "add1"); - Node *bridge_identity0_add4 = node_index["bridge_identity0_add4"]; - EXPECT_NE(bridge_identity0_add4, nullptr); - // Step 3: add placeholder for edges between host computation and outside - // compilation. - EXPECT_EQ(bridge_add1_add3->def().input(0), "add1_oc_to_host_placeholder_0"); - Node *add1_oc_to_host_placeholder = - node_index["add1_oc_to_host_placeholder_0"]; - TF_CHECK_OK(GetNodeAttr(add1_oc_to_host_placeholder->attrs(), - kOutsideCompilationToHostOriginalNodeAttrName, &str)); - EXPECT_EQ(str, "add1"); - int i; - TF_CHECK_OK(GetNodeAttr(add1_oc_to_host_placeholder->attrs(), - kOutsideCompilationToHostSrcOutputAttrName, &i)); - EXPECT_EQ(i, 0); - add4_node = node_index["add4"]; - ASSERT_NE(add4_node, nullptr); - EXPECT_EQ(add4_node->def().input(0), - "bridge_identity0_add4_host_to_oc_placeholder_0"); - Node *identity0_host_to_oc_placeholder = - node_index["bridge_identity0_add4_host_to_oc_placeholder_0"]; - TF_CHECK_OK(GetNodeAttr(identity0_host_to_oc_placeholder->attrs(), - kHostToOutsideCompilationOriginalNodeAttrName, &str)); - EXPECT_EQ(str, "bridge_identity0_add4"); - TF_CHECK_OK(GetNodeAttr(identity0_host_to_oc_placeholder->attrs(), - kHostToOutsideCompilationSrcOutputAttrName, &i)); - EXPECT_EQ(i, 0); - - // Check different placeholder nodes are created for different src_output. - Node *placeholder0 = node_index["identityn_0_host_to_oc_placeholder_0"], - *placeholder1 = node_index["identityn_0_host_to_oc_placeholder_1"]; - EXPECT_NE(placeholder0, nullptr); - EXPECT_NE(placeholder1, nullptr); - // Check we only have 2 placeholder nodes created for "identityn_0". 
- int placeholder_count = 0; - for (Node *n : g.nodes()) { - if (HasNodeAttr(n->def(), kHostToOutsideCompilationOriginalNodeAttrName)) { - string attr; - TF_CHECK_OK(GetNodeAttr( - n->attrs(), kHostToOutsideCompilationOriginalNodeAttrName, &attr)); - if (attr == "identityn_0") { - ++placeholder_count; - } - } - } - EXPECT_EQ(placeholder_count, 2); -} - -TEST(PostprocessForEncapsulationTest, ControlEdges) { - // Build the graph: - // "const0" - // "identity0" = "const0" (XLA computation 0) - // "identity1" = "identity0" - // "identity2" = "identity1" (XLA computation 1) - // "identity3" = "identity2" - tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - Output const0 = ops::Const(s.WithOpName("const0"), 1, {}); - Output identity0 = ops::Identity(s.WithOpName("identity0"), const0); - Output identity1 = ops::Identity(s.WithOpName("identity1"), identity0); - Output identity2 = ops::Identity(s.WithOpName("identity2"), identity1); - Output identity3 = ops::Identity(s.WithOpName("identity3"), identity2); - Graph g(OpRegistry::Global()); - TF_CHECK_OK(s.ToGraph(&g)); - auto node_index = g.BuildNodeNameIndex(); - - // Set XLA computation/outside compilation attr, and add control edges. - Node *const0_node = node_index["const0"], - *identity0_node = node_index["identity0"], - *identity1_node = node_index["identity1"], - *identity2_node = node_index["identity2"], - *identity3_node = node_index["identity3"]; - identity1_node->AddAttr(kXlaConnectedFromOtherXlaComputationAttrName, - std::vector{"0"}); - identity1_node->AddAttr(kXlaConnectedToOtherXlaComputationAttrName, - std::vector{"1"}); - identity3_node->AddAttr(kXlaControlDependenciesAttrName, - std::vector{"const0", "identity1"}); - - std::unordered_map clusters; - clusters["0"].node = identity0_node; - clusters["1"].node = identity2_node; - TF_CHECK_OK(PostprocessForEncapsulation(&g, "_xla", "_oc", clusters)); - - // Case 3a: we have control edge identity0 -> identity1, and identity1 -> - // identity2. 
- bool edge_identity0_identity1 = false, edge_identity1_identity2 = false; - for (const Edge *e : g.edges()) { - if (!e->IsControlEdge()) { - continue; - } - if (e->src() == identity0_node && e->dst() == identity1_node) { - edge_identity0_identity1 = true; - } else if (e->src() == identity1_node && e->dst() == identity2_node) { - edge_identity1_identity2 = true; - } - } - EXPECT_TRUE(edge_identity0_identity1); - EXPECT_TRUE(edge_identity1_identity2); - // Case 3b: we have control edge const0 -> identity3, and identity1 -> - // identity3. - bool edge_const0_identity3 = false, edge_identity1_identity3 = false; - for (const Edge *e : g.edges()) { - if (!e->IsControlEdge()) { - continue; - } - if (e->src() == const0_node && e->dst() == identity3_node) { - edge_const0_identity3 = true; - } else if (e->src() == identity1_node && e->dst() == identity3_node) { - edge_identity1_identity3 = true; - } - } - EXPECT_TRUE(edge_const0_identity3); - EXPECT_TRUE(edge_identity1_identity3); -} - -TEST(PostprocessForEncapsulationTest, DataEdges) { - // Build the graph: - // "const0" in outside compilation "0" - // "placeholder0" (for "const0") in host computation - // "add0" = "placeholder0" + "placeholder0" in host computation - // "placeholder1" (for "add0") in outside compilation 1 - // "add1" = "placeholder1" + "placeholder1" in outside compilation 1 - // - // "bridge" = "placeholder0" in host computation - // "placeholder2" (for "bridge") in outside compilation 1 - // "add2" = "placeholder2" + "placeholder2" in outside compilation 1 - tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - Output const0 = ops::Const(s.WithOpName("const0"), 1, {}); - Output placeholder0 = - ops::Placeholder(s.WithOpName("placeholder0"), DT_INT32); - Output add0 = ops::Add(s.WithOpName("add0"), placeholder0, placeholder0); - Output placeholder1 = - ops::Placeholder(s.WithOpName("placeholder1"), DT_INT32); - Output add1 = ops::Add(s.WithOpName("add1"), placeholder1, placeholder1); - Output bridge 
= ops::Identity(s.WithOpName("bridge"), placeholder0); - Output placeholder2 = - ops::Placeholder(s.WithOpName("placeholder2"), DT_INT32); - Output add2 = ops::Add(s.WithOpName("add2"), placeholder2, placeholder2); - Graph g(OpRegistry::Global()); - TF_CHECK_OK(s.ToGraph(&g)); - auto node_index = g.BuildNodeNameIndex(); - - // Set related attributes. - Node *placeholder0_node = node_index["placeholder0"]; - placeholder0_node->AddAttr(kOutsideCompilationToHostOriginalNodeAttrName, - "const0"); - placeholder0_node->AddAttr(kOutsideCompilationToHostSrcOutputAttrName, 0); - Node *placeholder1_node = node_index["placeholder1"]; - placeholder1_node->AddAttr(kHostToOutsideCompilationOriginalNodeAttrName, - "add0"); - placeholder1_node->AddAttr(kHostToOutsideCompilationSrcOutputAttrName, 0); - Node *bridge_node = node_index["bridge"]; - bridge_node->AddAttr(kBridgeSourceNodeAttrName, "const0"); - Node *placeholder2_node = node_index["placeholder2"]; - placeholder2_node->AddAttr(kHostToOutsideCompilationOriginalNodeAttrName, - "bridge"); - placeholder2_node->AddAttr(kHostToOutsideCompilationSrcOutputAttrName, 0); - - std::unordered_map clusters; - TF_CHECK_OK(PostprocessForEncapsulation(&g, "_xla", "_oc", clusters)); - - // Result graph should be: - // "add0" = "const0" + "const0" - // "add1" = "add0" + "add0" - // "add2" = "const0" + "const0" - node_index = g.BuildNodeNameIndex(); - EXPECT_EQ(node_index.size(), 6); - EXPECT_EQ(node_index["add0"]->def().input(0), "const0:0"); - EXPECT_EQ(node_index["add0"]->def().input(1), "const0:0"); - EXPECT_EQ(node_index["add1"]->def().input(0), "add0:0"); - EXPECT_EQ(node_index["add1"]->def().input(1), "add0:0"); - EXPECT_EQ(node_index["add2"]->def().input(0), "const0:0"); - EXPECT_EQ(node_index["add2"]->def().input(1), "const0:0"); -} - } // namespace tensorflow diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc index 
d334100aa4a915a87fb05d371e0e3379a7ee05f2..ec745cdbb7e237f8b4935dd41e9791fc75f5355d 100644 --- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc @@ -297,6 +297,7 @@ Status RewriteSubgraph(const std::vector& arg_source_tensors, NodeDef def; def.set_name(launch->name()); + MergeDebugInfo(NodeDebugInfo(launch->def()), &def); // Target the XLA CPU/GPU backends. VLOG(2) << "Replacing with XlaLaunch"; diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc index e3c7e2f89be9b37b51a633dabb099969c181013f..2a770c527b2fae91352fd17dacb13495a3a73f34 100644 --- a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc +++ b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc @@ -20,14 +20,17 @@ limitations under the License. #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" #include "tensorflow/compiler/jit/encapsulate_util.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/tf2xla/side_effect_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/cleanup.h" namespace tensorflow { @@ -98,9 +101,12 @@ xla::StatusOr BuildRecvAtHostNode( recv_at_host_builder.Attr("Toutputs", recv_at_host_dtypes); // The correct device_ordinal will be inserted during replication in a // subsequent rewrite. 
- recv_at_host_builder.Attr("device_ordinal", 0); + AttrValue device_ordinal_value; + device_ordinal_value.set_placeholder("device_ordinal"); + recv_at_host_builder.Attr("device_ordinal", device_ordinal_value); recv_at_host_builder.Attr( "key", absl::StrCat("host_compute_channel_", oc_cluster_name)); + recv_at_host_builder.Attr(kXlaHasHostTransferAttrName, true); recv_at_host_builder.Input(key_placeholder->name(), 0, DT_STRING); TF_RETURN_IF_ERROR(recv_at_host_builder.Finalize(&recv_at_host_def)); Status s; @@ -197,9 +203,12 @@ xla::StatusOr BuildSendFromHostNode( send_from_host_builder.Attr("Tinputs", send_from_host_dtypes); // The correct device_ordinal will be inserted during replication in a // subsequent rewrite. - send_from_host_builder.Attr("device_ordinal", 0); + AttrValue device_ordinal_value; + device_ordinal_value.set_placeholder("device_ordinal"); + send_from_host_builder.Attr("device_ordinal", device_ordinal_value); send_from_host_builder.Attr( "key", absl::StrCat("host_compute_channel_", oc_cluster_name)); + send_from_host_builder.Attr(kXlaHasHostTransferAttrName, true); std::vector inputs(send_from_host_dtypes.size()); for (auto* n : ret_nodes) { int index; @@ -300,6 +309,10 @@ xla::StatusOr BuildXlaHostComputeNodeDef( host_compute_builder.Attr("tpu_core", core); } + // Set input tokens. + host_compute_builder.Attr(kXlaTokenInputNodesAttrName, + std::vector{kXlaTokenArgNodeName}); + // Populate inputs. std::vector input_dtypes; TF_RETURN_IF_ERROR(GetNodeAttr(call_node->attrs(), "Tinputs", &input_dtypes)); @@ -322,6 +335,38 @@ xla::StatusOr BuildXlaHostComputeNodeDef( return new_def; } +Status ValidateOutsideCompilationCallNode(Node* call_node) { + // DT_INT64 as input/output for outside compilation is not supported yet: + // b/120809951. 
+ for (const Edge* e : call_node->in_edges()) { + if (e->IsControlEdge()) { + continue; + } + DataType dtype = e->src()->output_type(e->src_output()); + if (dtype == DT_INT64) { + return errors::Unimplemented( + "int64 input for outside compilation is not supported yet: " + "b/120809951. Please cast output of node ", + e->src()->DebugString(), + " to int32 before feeding it into outside compilation."); + } + } + for (const Edge* e : call_node->out_edges()) { + if (e->IsControlEdge()) { + continue; + } + DataType dtype = e->dst()->input_type(e->dst_input()); + if (dtype == DT_INT64) { + return errors::Unimplemented( + "int64 output for outside compilation is not supported yet: " + "b/120809951. Please cast input of node ", + e->dst()->DebugString(), + " to int32 before returning it from outside compilation."); + } + } + return Status::OK(); +} + // Replace outside compilation function call node with XlaHostCompute node. // If the function call node has no input/output edges, we will just remove it // and not create a XlaHostCompute node. @@ -357,6 +402,51 @@ Status ReplaceOrRemoveOutsideCompilationCallNode( return Status::OK(); } +// Resets "device_ordinal" attr to placeholder value for related nodes +// (XlaRecvAtHost nodes; XlaSendFromHost nodes; If/While/FuncCall nodes +// containing XlaRecvAtHost/XlaSendFromHost). 
+Status ResetDeviceOrdinalToPlaceholderValue(Graph* g) { + AttrValue device_ordinal_value; + device_ordinal_value.set_placeholder("device_ordinal"); + for (Node* n : g->nodes()) { + if (!HasNodeAttr(n->def(), kXlaHasHostTransferAttrName)) { + continue; + } + + if (n->type_string() == "_XlaRecvAtHost" || + n->type_string() == "_XlaSendFromHost") { + n->ClearAttr("device_ordinal"); + n->AddAttr("device_ordinal", device_ordinal_value); + } else if (n->type_string() == "If") { + for (const string& attr_name : + std::vector{"then_branch", "else_branch"}) { + NameAttrList branch_func; + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), attr_name, &branch_func)); + (*branch_func.mutable_attr())["device_ordinal"] = device_ordinal_value; + n->ClearAttr(attr_name); + n->AddAttr(attr_name, branch_func); + } + } else if (n->type_string() == "While") { + for (const string& attr_name : std::vector{"cond", "body"}) { + NameAttrList branch_func; + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), attr_name, &branch_func)); + (*branch_func.mutable_attr())["device_ordinal"] = device_ordinal_value; + n->ClearAttr(attr_name); + n->AddAttr(attr_name, branch_func); + } + } else if (HasNodeAttr(n->def(), "device_ordinal")) { + // Function call node containing outside compilation. + n->ClearAttr("device_ordinal"); + n->AddAttr("device_ordinal", device_ordinal_value); + } else { + return errors::Internal("Unknown node marked with ", + kXlaHasHostTransferAttrName, ": ", + n->DebugString()); + } + } + return Status::OK(); +} + // For an XLA computation, builds host side graph given all outside compilation // graphs inside it. 
The host side graph contains: // 1) a "sequencer" node (we will add control edge between XlaRecvAtHost and @@ -368,8 +458,8 @@ Status ReplaceOrRemoveOutsideCompilationCallNode( Status ConstructHostGraph( const string& xla_cluster_name, const string& outside_compilation_attr_name, const std::vector& outside_compilation_host_graphs, - FunctionLibraryDefinition* fld, std::unique_ptr* host_graph) { - host_graph->reset(new Graph(fld)); + FunctionLibraryDefinition* fld, const string& host_graph_func_name) { + Graph host_graph(fld); // Create sequencer node in host graph. NodeDefBuilder sequencer_builder(absl::StrCat(xla_cluster_name, "_sequencer"), @@ -378,24 +468,34 @@ Status ConstructHostGraph( NodeDef sequencer_def; TF_RETURN_IF_ERROR(sequencer_builder.Finalize(&sequencer_def)); Status s; - Node* sequencer = (*host_graph)->AddNode(sequencer_def, &s); + Node* sequencer = host_graph.AddNode(sequencer_def, &s); TF_RETURN_IF_ERROR(s); // Create key placeholder in host graph. TF_ASSIGN_OR_RETURN( Node * key_placeholder, - AddHostComputeKeyPlaceholder(xla_cluster_name, host_graph->get())); + AddHostComputeKeyPlaceholder(xla_cluster_name, &host_graph)); // For each outside compilation graph, copy them to host graph with the // following changes: // a) Use key_placeholder in host graph instead of its own. - // b) Add control edge from RecvAtHost/SendFromHost to sequencer. + // b) Add control edge from host transfer nodes (XlaRecvAtHost, + // XlaSendFromHost, If/While nodes containing + // XlaRecvAtHost/XlaSendFromHost) to sequencer node. // c) Clear node_def.device(), so device placer won't get confused. for (const string& host_func : outside_compilation_host_graphs) { VLOG(4) << "Expanding host graph " << host_func; + // Temporarily use "0" as "device_ordinal". It will be reset to placeholder + // value after we expanded all host graphs. We cannot just use placeholder + // value here because FunctionDef instantiation does not allow placeholder + // value for attributes. 
+ AttrValue device_ordinal_attr; + device_ordinal_attr.set_i(0); + protobuf::Map attrs; + attrs["device_ordinal"] = device_ordinal_attr; FunctionBody* host_fbody = nullptr; TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( - *fld->Find(host_func), AttrSlice(), fld, + *fld->Find(host_func), AttrSlice(&attrs), fld, [&](const string& op, const OpDef** sig) { return fld->LookUpOpDef(op, sig); }, @@ -408,8 +508,8 @@ Status ConstructHostGraph( FixupSourceAndSinkEdges(host_fbody->graph); std::map node_map; - node_map[host_fbody->graph->source_node()] = (*host_graph)->source_node(); - node_map[host_fbody->graph->sink_node()] = (*host_graph)->sink_node(); + node_map[host_fbody->graph->source_node()] = host_graph.source_node(); + node_map[host_fbody->graph->sink_node()] = host_graph.sink_node(); Status s; ReverseDFS( *host_fbody->graph, /*enter=*/nullptr, @@ -431,7 +531,7 @@ Status ConstructHostGraph( NodeDef copy_def = n->def(); // Change c). copy_def.clear_device(); - copy = (*host_graph)->AddNode(copy_def, &s); + copy = host_graph.AddNode(copy_def, &s); if (!s.ok()) { return; } @@ -446,22 +546,23 @@ Status ConstructHostGraph( e->src()->DebugString()); return; } - (*host_graph) - ->AddEdge(node_map[e->src()], e->src_output(), copy, - e->dst_input()); + host_graph.AddEdge(node_map[e->src()], e->src_output(), copy, + e->dst_input()); } // Change b). - if (copy->type_string() == "_XlaRecvAtHost" || - copy->type_string() == "_XlaSendFromHost") { - (*host_graph)->AddControlEdge(copy, sequencer); + if (HasNodeAttr(copy->def(), kXlaHasHostTransferAttrName)) { + host_graph.AddControlEdge(copy, sequencer); } }, NodeComparatorID()); + if (!s.ok()) { return s; } } + // Reset "device_ordinal" to placeholder value. + TF_RETURN_IF_ERROR(ResetDeviceOrdinalToPlaceholderValue(&host_graph)); // sequencer and key_placeholder might be dead nodes. Prune them if necessary. 
// - sequencer should be pruned iff it has no input control edges from @@ -470,21 +571,30 @@ Status ConstructHostGraph( // - key_placeholder should be pruned iff there's no RecvAtHost/SendFromHost. // We don't need to do anything special. if (!sequencer->in_edges().empty()) { - (*host_graph)->AddControlEdge(sequencer, (*host_graph)->sink_node()); + host_graph.AddControlEdge(sequencer, host_graph.sink_node()); } PruneForReverseReachability( - host_graph->get(), - std::unordered_set{(*host_graph)->sink_node()}); + &host_graph, std::unordered_set{host_graph.sink_node()}); // Postprocess edges between different outside compilations. TF_RETURN_IF_ERROR(PostprocessEdgesBetweenOutsideCompilations( - host_graph->get(), outside_compilation_attr_name)); + &host_graph, outside_compilation_attr_name)); if (VLOG_IS_ON(4)) { dump_graph::DumpGraphToFile( absl::StrCat("extract_outside_compilation_host_graph_for_", xla_cluster_name), - **host_graph, fld); + host_graph, fld); + } + + FunctionDef host_graph_fdef; + TF_RETURN_IF_ERROR( + GraphToFunctionDef(host_graph, host_graph_func_name, &host_graph_fdef)); + if (fld->Find(host_graph_func_name)) { + TF_RETURN_IF_ERROR( + fld->ReplaceFunction(host_graph_func_name, host_graph_fdef)); + } else { + TF_RETURN_IF_ERROR(fld->AddFunctionDef(host_graph_fdef)); } return Status::OK(); @@ -492,8 +602,28 @@ Status ConstructHostGraph( // Expand XLA computation's outside compilation host side graph into main graph. // Add a control edge between sequencer node and the XLA computation node. -Status ExpandHostGraphIntoMainGraph(Graph* main_graph, Graph* host_graph, +Status ExpandHostGraphIntoMainGraph(Graph* main_graph, + FunctionLibraryDefinition* fld, + const string& host_graph_func_name, Node* xla_computation_node) { + // Temporarily use "0" as "device_ordinal". It will be rewritten with the + // correct value in a later pass. 
We cannot just use placeholder value here + // because FunctionDef instantiation does not allow placeholder value for + // attributes. + AttrValue device_ordinal_attr; + device_ordinal_attr.set_i(0); + protobuf::Map attrs; + attrs["device_ordinal"] = device_ordinal_attr; + FunctionBody* fbody = nullptr; + TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( + *fld->Find(host_graph_func_name), AttrSlice(&attrs), fld, + [&](const string& op, const OpDef** sig) { + return fld->LookUpOpDef(op, sig); + }, + &fbody)); + std::unique_ptr fbody_deleter(fbody); + Graph* host_graph = fbody->graph; + // We use ReverseDFS() to copy nodes. Make sure all nodes are reverse // reachable from sink node so all nodes will be copied. // TODO(b/77601805): consolidate copy graph functions. @@ -545,23 +675,25 @@ Status ExpandHostGraphIntoMainGraph(Graph* main_graph, Graph* host_graph, return s; } -// Rewrites shape inference graph for outside compilation. -// 1. If the outside compilation is a "top-level" one (not in a function of any -// If/While/etc.), this shape inference graph might have host computation to -// outside compilation placeholder nodes, which will cause shape inference to -// fail. However, those nodes are not in `host_graph` any more (because we -// have executed `PostprocessForEncapsultion`). In this case, we clear the -// graph, and copy SendFromHost with all its predecessors from `host_graph`. -// This case is detected by whether the SendFromHost node exists in -// `host_graph` as well. -// 2. Remove control edges, and prune nodes that are not useful for shape -// inference. +// Rewrites shape inference graph for outside compilation: +// 1) If XlaSendFromHost also exists in `host_graph`, copy nodes from +// `host_graph`. Because we might still have outside compilation to outside +// compilation placeholder nodes in shape inference graph, which will prevent +// us from inferring XlaSendFromHost shape. But in `host_graph`, we already +// removed those placeholder nodes. 
+// 2) Remove control edges. +// 3) Prune nodes that are not useful for shape inference. Status RewriteShapeInferenceGraph(const string& shape_inference_graph_name, Graph* host_graph, FunctionLibraryDefinition* fld) { + // Use "0" as "device_ordinal". It does not matter for shape inference. + AttrValue device_ordinal_attr; + device_ordinal_attr.set_i(0); + protobuf::Map attrs; + attrs["device_ordinal"] = device_ordinal_attr; FunctionBody* fbody = nullptr; TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( - *fld->Find(shape_inference_graph_name), AttrSlice(), fld, + *fld->Find(shape_inference_graph_name), AttrSlice(&attrs), fld, [&](const string& op, const OpDef** sig) { return fld->LookUpOpDef(op, sig); }, @@ -650,6 +782,7 @@ Status RewriteShapeInferenceGraph(const string& shape_inference_graph_name, g->RemoveEdge(e); } } + // Nodes that are not reverse reachable from SendFromHost are not useful for // shape inference. Prune them. PruneForReverseReachability(g, @@ -669,6 +802,681 @@ Status RewriteShapeInferenceGraph(const string& shape_inference_graph_name, return Status::OK(); } +// Builds XlaSendToHost node which sends cond predicate to host. +xla::StatusOr BuildSendIfPredNode(const string& name, + const string& host_transfer_key, + Node* pred_node, Graph* g) { + NodeDefBuilder send_pred_builder(name, "XlaSendToHost"); + send_pred_builder.Attr("Tinput", DT_BOOL); + send_pred_builder.Attr("key", absl::StrCat(host_transfer_key, "_dtoh_0")); + send_pred_builder.Attr(kXlaTokenInputNodesAttrName, + std::vector{kXlaTokenArgNodeName}); + send_pred_builder.Input(pred_node->name(), 0, DT_BOOL); + NodeDef send_pred_def; + TF_RETURN_IF_ERROR(send_pred_builder.Finalize(&send_pred_def)); + Status s; + Node* send_pred_node = g->AddNode(send_pred_def, &s); + TF_RETURN_IF_ERROR(s); + g->AddEdge(pred_node, 0, send_pred_node, 0); + return send_pred_node; +} + +// Replaces key placeholder node with an _Arg node. 
+Status ReplaceKeyPlaceholderWithArgNode(const string& xla_cluster_name, + const string& func_name, + FunctionLibraryDefinition* fld) { + // Temporarily use "0" as "device_ordinal". It will be reset to placeholder + // value after rewriting. + AttrValue device_ordinal_attr; + device_ordinal_attr.set_i(0); + protobuf::Map attrs; + attrs["device_ordinal"] = device_ordinal_attr; + FunctionBody* fbody = nullptr; + TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( + *fld->Find(func_name), AttrSlice(&attrs), fld, + [&](const string& op, const OpDef** sig) { + return fld->LookUpOpDef(op, sig); + }, + &fbody)); + std::unique_ptr fbody_deleter(fbody); + Graph* g = fbody->graph; + + // Find or create the key placeholder node. + Node* key_placeholder = nullptr; + for (Node* n : g->nodes()) { + if (IsKeyPlaceholderNode(*n)) { + key_placeholder = n; + break; + } + } + if (!key_placeholder) { + TF_ASSIGN_OR_RETURN(key_placeholder, + AddHostComputeKeyPlaceholder(xla_cluster_name, g)); + } + + // Build the _Arg node, and replace key placeholder node with it. + NodeDefBuilder arg_builder("key_arg", FunctionLibraryDefinition::kArgOp); + arg_builder.Attr("T", DT_STRING); + arg_builder.Attr("index", 0); + NodeDef arg_def; + TF_RETURN_IF_ERROR(arg_builder.Finalize(&arg_def)); + TF_RETURN_IF_ERROR(ReplaceNode(g, key_placeholder, arg_def).status()); + + // Reset "device_ordinal" to placeholder value. + TF_RETURN_IF_ERROR(ResetDeviceOrdinalToPlaceholderValue(g)); + + FunctionDef replace_fdef; + TF_RETURN_IF_ERROR(GraphToFunctionDef(*g, func_name, &replace_fdef)); + TF_RETURN_IF_ERROR(fld->ReplaceFunction(func_name, replace_fdef)); + return Status::OK(); +} + +// Builds host side graph for If node. 
+//
+// The generated graph holds a key placeholder, an _XlaRecvAtHost node that
+// receives the predicate under `host_transfer_key`, and an "If" node that
+// dispatches to `then_branch_host_func_name` / `else_branch_host_func_name`
+// with the key placeholder as its only data input. The graph is stored in
+// `fld` as `host_graph_func_name`, replacing any function of that name.
+Status BuildHostGraphForIfNode(const string& xla_cluster_attr_name,
+                               const string& outside_compilation_attr_name,
+                               const string& xla_cluster_name,
+                               const string& if_node_name,
+                               const string& host_transfer_key,
+                               const string& host_graph_func_name,
+                               FunctionLibraryDefinition* fld,
+                               const string& then_branch_host_func_name,
+                               const string& else_branch_host_func_name) {
+  Graph host_graph(fld);
+  string outside_compilation_name = absl::StrCat("oc_if_", if_node_name);
+  AttrValue device_ordinal_value;
+  device_ordinal_value.set_placeholder("device_ordinal");
+
+  // Step 1: add key placeholder node.
+  TF_ASSIGN_OR_RETURN(
+      Node * key_placeholder,
+      AddHostComputeKeyPlaceholder(xla_cluster_name, &host_graph));
+
+  // Step 2: build XlaRecvAtHost node to recv predicate.
+  NodeDefBuilder recv_pred_builder(
+      absl::StrCat("recv_oc_if_pred_", if_node_name), "_XlaRecvAtHost");
+  recv_pred_builder.Attr("Toutputs", std::vector<DataType>{DT_BOOL});
+  recv_pred_builder.Attr("key", host_transfer_key);
+  recv_pred_builder.Attr("device_ordinal", device_ordinal_value);
+  recv_pred_builder.Attr(xla_cluster_attr_name, xla_cluster_name);
+  recv_pred_builder.Attr(outside_compilation_attr_name,
+                         outside_compilation_name);
+  recv_pred_builder.Attr(kXlaHasHostTransferAttrName, true);
+  recv_pred_builder.Input(key_placeholder->name(), 0, DT_STRING);
+  NodeDef recv_pred_def;
+  TF_RETURN_IF_ERROR(recv_pred_builder.Finalize(&recv_pred_def));
+  Status s;
+  Node* recv_pred_node = host_graph.AddNode(recv_pred_def, &s);
+  TF_RETURN_IF_ERROR(s);
+  host_graph.AddEdge(key_placeholder, 0, recv_pred_node, 0);
+
+  // Step 3: rewrite `{then, else}_branch_host_func_name`, replace key
+  // placeholder with an _Arg node.
+  TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode(
+      xla_cluster_name, then_branch_host_func_name, fld));
+  TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode(
+      xla_cluster_name, else_branch_host_func_name, fld));
+
+  // Step 4: build If node to choose between `{then, else}_branch_host_graph`.
+  NodeDefBuilder if_builder(absl::StrCat("oc_if_", if_node_name), "If");
+  if_builder.Attr("Tcond", DT_BOOL);
+  if_builder.Attr("Tin", std::vector<DataType>{DT_STRING});
+  if_builder.Attr("Tout", std::vector<DataType>{});
+  NameAttrList host_then_branch, host_else_branch;
+  host_then_branch.set_name(then_branch_host_func_name);
+  (*host_then_branch.mutable_attr())["device_ordinal"] = device_ordinal_value;
+  host_else_branch.set_name(else_branch_host_func_name);
+  (*host_else_branch.mutable_attr())["device_ordinal"] = device_ordinal_value;
+  if_builder.Attr("then_branch", host_then_branch);
+  if_builder.Attr("else_branch", host_else_branch);
+  if_builder.Attr(kXlaHasHostTransferAttrName, true);
+  if_builder.Attr(xla_cluster_attr_name, xla_cluster_name);
+  if_builder.Attr(outside_compilation_attr_name, outside_compilation_name);
+  if_builder.Input(recv_pred_node->name(), 0, DT_BOOL);
+  std::vector<NodeDefBuilder::NodeOut> if_inputs{
+      {key_placeholder->name(), 0, DT_STRING}};
+  if_builder.Input(if_inputs);
+  NodeDef if_def;
+  TF_RETURN_IF_ERROR(if_builder.Finalize(&if_def));
+  Node* if_node = host_graph.AddNode(if_def, &s);
+  TF_RETURN_IF_ERROR(s);
+  host_graph.AddEdge(recv_pred_node, 0, if_node, 0);
+  host_graph.AddEdge(key_placeholder, 0, if_node, 1);
+
+  // Convert `host_graph` to function, and add a "device_ordinal" attr.
+  FunctionDef oc_host_graph_fdef;
+  TF_RETURN_IF_ERROR(GraphToFunctionDef(host_graph, host_graph_func_name,
+                                        &oc_host_graph_fdef));
+  if (fld->Find(host_graph_func_name)) {
+    TF_RETURN_IF_ERROR(
+        fld->ReplaceFunction(host_graph_func_name, oc_host_graph_fdef));
+  } else {
+    TF_RETURN_IF_ERROR(fld->AddFunctionDef(oc_host_graph_fdef));
+  }
+
+  return Status::OK();
+}
+
+// Rewrites loop cond to add a node which sends loop cond to host.
+//
+// Instantiates `loop_cond_func` from `fld`, requires exactly one _Retval node,
+// attaches an XlaSendToHost node (key `host_transfer_key` + "_dtoh_0") to the
+// node feeding that _Retval, and replaces the function in `fld`.
+Status AddSendLoopPredToLoopCond(FunctionLibraryDefinition* fld,
+                                 const NameAttrList& loop_cond_func,
+                                 const string& while_node_name,
+                                 const string& host_transfer_key) {
+  // Fail with a Status instead of dereferencing a null FunctionDef below.
+  const FunctionDef* loop_cond_fdef = fld->Find(loop_cond_func.name());
+  if (!loop_cond_fdef) {
+    return errors::Internal("Cannot find function ", loop_cond_func.name());
+  }
+
+  // Instantiate the loop cond function.
+  FunctionBody* fbody = nullptr;
+  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
+      *loop_cond_fdef, AttrSlice(&loop_cond_func.attr()), fld,
+      [&](const string& op, const OpDef** sig) {
+        return fld->LookUpOpDef(op, sig);
+      },
+      &fbody));
+  std::unique_ptr<FunctionBody> fbody_deleter(fbody);
+  Graph* g = fbody->graph;
+
+  // Find the _Retval node and the loop cond node.
+  Node* ret_node = nullptr;
+  for (Node* n : g->nodes()) {
+    if (n->type_string() == "_Retval") {
+      if (ret_node) {
+        return errors::Internal("Multiple return node for loop cond function ",
+                                loop_cond_func.name(), ": ",
+                                ret_node->DebugString(), " and ",
+                                n->DebugString());
+      } else {
+        ret_node = n;
+      }
+    }
+  }
+  if (!ret_node) {
+    return errors::Internal("No _Retval node for loop cond function ",
+                            loop_cond_func.name());
+  }
+  Node* loop_cond = nullptr;
+  TF_RETURN_IF_ERROR(ret_node->input_node(0, &loop_cond));
+
+  // Build the XlaSendToHost node.
+  NodeDefBuilder send_loop_cond_builder(
+      absl::StrCat("send_oc_while_cond_", while_node_name), "XlaSendToHost");
+  send_loop_cond_builder.Attr("Tinput", DT_BOOL);
+  send_loop_cond_builder.Attr("key",
+                              absl::StrCat(host_transfer_key, "_dtoh_0"));
+  send_loop_cond_builder.Attr(kXlaTokenInputNodesAttrName,
+                              std::vector<string>{kXlaTokenArgNodeName});
+  send_loop_cond_builder.Input(loop_cond->name(), 0, DT_BOOL);
+  NodeDef send_loop_cond_def;
+  TF_RETURN_IF_ERROR(send_loop_cond_builder.Finalize(&send_loop_cond_def));
+  Status s;
+  Node* send_loop_cond_node = g->AddNode(send_loop_cond_def, &s);
+  TF_RETURN_IF_ERROR(s);
+  g->AddEdge(loop_cond, 0, send_loop_cond_node, 0);
+
+  // Replace original function.
+  FunctionDef replace_fdef;
+  TF_RETURN_IF_ERROR(
+      GraphToFunctionDef(*g, loop_cond_func.name(), &replace_fdef));
+  TF_RETURN_IF_ERROR(fld->ReplaceFunction(loop_cond_func.name(), replace_fdef));
+
+  return Status::OK();
+}
+
+// Rewrites while loop cond function for host.
+//
+// After the key placeholder has been turned into an _Arg node, adds an
+// _XlaRecvAtHost node fed by that _Arg and a DT_BOOL _Retval (index 0) fed by
+// the recv node, then replaces `cond_host_func_name` in `fld`.
+Status RewriteHostWhileLoopCond(
+    const string& cond_host_func_name, const string& while_node_name,
+    const string& host_transfer_key, const string& xla_cluster_attr_name,
+    const string& xla_cluster_name, const string& outside_compilation_attr_name,
+    const string& outside_compilation_name, FunctionLibraryDefinition* fld) {
+  // Replace key placeholder node with _Arg node.
+  TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode(
+      xla_cluster_name, cond_host_func_name, fld));
+
+  // Instantiate cond function.
+  AttrValue device_ordinal_temp_value;
+  device_ordinal_temp_value.set_i(0);
+  protobuf::Map<string, AttrValue> attrs;
+  attrs["device_ordinal"] = device_ordinal_temp_value;
+  FunctionBody* cond_fbody = nullptr;
+  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
+      *fld->Find(cond_host_func_name), AttrSlice(&attrs), fld,
+      [&](const string& op, const OpDef** sig) {
+        return fld->LookUpOpDef(op, sig);
+      },
+      &cond_fbody));
+  std::unique_ptr<FunctionBody> cond_fbody_deleter(cond_fbody);
+  Graph* cond_graph = cond_fbody->graph;
+  Node* key_arg = nullptr;
+  for (Node* n : cond_graph->nodes()) {
+    if (n->type_string() == "_Arg") {
+      key_arg = n;
+    }
+  }
+  if (!key_arg) {
+    return errors::Internal(
+        "No _Arg node found for host compute key in function ",
+        cond_host_func_name);
+  }
+
+  // Add an XlaRecvAtHost node to use as cond function return value.
+  // We don't need to set kXlaHasHostTransferAttrName for this node, because
+  // it's already added for the "While" node on the host.
+  NodeDefBuilder recv_pred_builder(
+      absl::StrCat("recv_oc_while_cond_", while_node_name), "_XlaRecvAtHost");
+  recv_pred_builder.Attr("Toutputs", std::vector<DataType>{DT_BOOL});
+  recv_pred_builder.Attr("key", host_transfer_key);
+  AttrValue device_ordinal_value;
+  device_ordinal_value.set_placeholder("device_ordinal");
+  recv_pred_builder.Attr("device_ordinal", device_ordinal_value);
+  recv_pred_builder.Attr(xla_cluster_attr_name, xla_cluster_name);
+  recv_pred_builder.Attr(outside_compilation_attr_name,
+                         outside_compilation_name);
+  recv_pred_builder.Input(key_arg->name(), 0, DT_STRING);
+  NodeDef recv_pred_def;
+  TF_RETURN_IF_ERROR(recv_pred_builder.Finalize(&recv_pred_def));
+  Status s;
+  Node* recv_pred_node = cond_graph->AddNode(recv_pred_def, &s);
+  TF_RETURN_IF_ERROR(s);
+  cond_graph->AddEdge(key_arg, 0, recv_pred_node, 0);
+  NodeDefBuilder ret_builder(
+      absl::StrCat("recv_oc_while_cond_ret_", while_node_name), "_Retval");
+  ret_builder.Attr("T", DT_BOOL);
+  ret_builder.Attr("index", 0);
+  ret_builder.Input(recv_pred_node->name(), 0, DT_BOOL);
+  NodeDef ret_def;
+  TF_RETURN_IF_ERROR(ret_builder.Finalize(&ret_def));
+  Node* ret_node = cond_graph->AddNode(ret_def, &s);
+  TF_RETURN_IF_ERROR(s);
+  cond_graph->AddEdge(recv_pred_node, 0, ret_node, 0);
+
+  // Reset device_ordinal to placeholder value.
+  TF_RETURN_IF_ERROR(ResetDeviceOrdinalToPlaceholderValue(cond_graph));
+
+  // Replace original function.
+  FunctionDef cond_replace_fdef;
+  TF_RETURN_IF_ERROR(
+      GraphToFunctionDef(*cond_graph, cond_host_func_name, &cond_replace_fdef));
+  TF_RETURN_IF_ERROR(
+      fld->ReplaceFunction(cond_host_func_name, cond_replace_fdef));
+
+  return Status::OK();
+}
+
+// Rewrites while loop body function for host.
+//
+// After the key placeholder has been turned into an _Arg node, adds a
+// DT_STRING _Retval (index 0) that returns that _Arg unchanged, then replaces
+// `body_host_func_name` in `fld`.
+Status RewriteHostWhileLoopBody(
+    const string& body_host_func_name, const string& while_node_name,
+    const string& host_transfer_key, const string& xla_cluster_attr_name,
+    const string& xla_cluster_name, const string& outside_compilation_attr_name,
+    const string& outside_compilation_name, FunctionLibraryDefinition* fld) {
+  // Replace key placeholder node with _Arg node.
+  TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode(
+      xla_cluster_name, body_host_func_name, fld));
+
+  // Instantiate body function.
+  AttrValue device_ordinal_temp_value;
+  device_ordinal_temp_value.set_i(0);
+  protobuf::Map<string, AttrValue> attrs;
+  attrs["device_ordinal"] = device_ordinal_temp_value;
+  FunctionBody* body_fbody = nullptr;
+  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
+      *fld->Find(body_host_func_name), AttrSlice(&attrs), fld,
+      [&](const string& op, const OpDef** sig) {
+        return fld->LookUpOpDef(op, sig);
+      },
+      &body_fbody));
+  std::unique_ptr<FunctionBody> body_fbody_deleter(body_fbody);
+  Graph* body_graph = body_fbody->graph;
+  Node* key_arg = nullptr;
+  for (Node* n : body_graph->nodes()) {
+    if (n->type_string() == "_Arg") {
+      key_arg = n;
+    }
+  }
+  if (!key_arg) {
+    return errors::Internal(
+        "No _Arg node found for host compute key in function ",
+        body_host_func_name);
+  }
+
+  // Add a _Retval node to loop body.
+  NodeDefBuilder ret_builder(
+      absl::StrCat("recv_oc_while_body_ret_", while_node_name), "_Retval");
+  ret_builder.Attr("T", DT_STRING);
+  ret_builder.Attr("index", 0);
+  ret_builder.Input(key_arg->name(), 0, DT_STRING);
+  NodeDef ret_def;
+  TF_RETURN_IF_ERROR(ret_builder.Finalize(&ret_def));
+  Status s;
+  Node* ret_node = body_graph->AddNode(ret_def, &s);
+  TF_RETURN_IF_ERROR(s);
+  body_graph->AddEdge(key_arg, 0, ret_node, 0);
+
+  // Reset device_ordinal to placeholder value.
+  TF_RETURN_IF_ERROR(ResetDeviceOrdinalToPlaceholderValue(body_graph));
+
+  // Replace original function.
+  FunctionDef body_replace_fdef;
+  TF_RETURN_IF_ERROR(
+      GraphToFunctionDef(*body_graph, body_host_func_name, &body_replace_fdef));
+  TF_RETURN_IF_ERROR(
+      fld->ReplaceFunction(body_host_func_name, body_replace_fdef));
+
+  return Status::OK();
+}
+
+// Builds host side graph for while node.
+//
+// Rewrites `cond_host_func_name` and `body_host_func_name` for the host, then
+// builds a graph holding a key placeholder and a "While" node over those two
+// functions, and stores it in `fld` as `host_graph_func_name` (replacing any
+// function of that name).
+Status BuildHostGraphForWhileNode(
+    const string& xla_cluster_attr_name,
+    const string& outside_compilation_attr_name, const string& xla_cluster_name,
+    const string& while_node_name, const string& host_transfer_key,
+    const string& host_graph_func_name, FunctionLibraryDefinition* fld,
+    const string& cond_host_func_name, const string& body_host_func_name) {
+  Graph host_graph(fld);
+  string outside_compilation_name = absl::StrCat("oc_while_", while_node_name);
+
+  // Step 1: add key placeholder node.
+  TF_ASSIGN_OR_RETURN(
+      Node * key_placeholder,
+      AddHostComputeKeyPlaceholder(xla_cluster_name, &host_graph));
+
+  // Step 2: rewrite cond function.
+  TF_RETURN_IF_ERROR(RewriteHostWhileLoopCond(
+      cond_host_func_name, while_node_name, host_transfer_key,
+      xla_cluster_attr_name, xla_cluster_name, outside_compilation_attr_name,
+      outside_compilation_name, fld));
+
+  // Step 3: rewrite body function.
+  TF_RETURN_IF_ERROR(RewriteHostWhileLoopBody(
+      body_host_func_name, while_node_name, host_transfer_key,
+      xla_cluster_attr_name, xla_cluster_name, outside_compilation_attr_name,
+      outside_compilation_name, fld));
+
+  // Step 4: build While node.
+  NodeDefBuilder while_builder(absl::StrCat("oc_while_", while_node_name),
+                               "While");
+  while_builder.Attr("T", std::vector<DataType>{DT_STRING});
+  NameAttrList func;
+  AttrValue device_ordinal_value;
+  device_ordinal_value.set_placeholder("device_ordinal");
+  (*func.mutable_attr())["device_ordinal"] = device_ordinal_value;
+  func.set_name(cond_host_func_name);
+  while_builder.Attr("cond", func);
+  func.set_name(body_host_func_name);
+  while_builder.Attr("body", func);
+  while_builder.Attr(kXlaHasHostTransferAttrName, true);
+  while_builder.Attr(xla_cluster_attr_name, xla_cluster_name);
+  while_builder.Attr(outside_compilation_attr_name, outside_compilation_name);
+  std::vector<NodeDefBuilder::NodeOut> while_inputs{
+      {key_placeholder->name(), 0, DT_STRING}};
+  while_builder.Input(while_inputs);
+  NodeDef while_def;
+  TF_RETURN_IF_ERROR(while_builder.Finalize(&while_def));
+  Status s;
+  Node* while_node = host_graph.AddNode(while_def, &s);
+  TF_RETURN_IF_ERROR(s);
+  host_graph.AddEdge(key_placeholder, 0, while_node, 0);
+
+  // Convert `host_graph` to function.
+  FunctionDef oc_host_graph_fdef;
+  TF_RETURN_IF_ERROR(GraphToFunctionDef(host_graph, host_graph_func_name,
+                                        &oc_host_graph_fdef));
+  if (fld->Find(host_graph_func_name)) {
+    TF_RETURN_IF_ERROR(
+        fld->ReplaceFunction(host_graph_func_name, oc_host_graph_fdef));
+  } else {
+    TF_RETURN_IF_ERROR(fld->AddFunctionDef(oc_host_graph_fdef));
+  }
+
+  return Status::OK();
+}
+
+// Builds host graph for func call nodes.
+//
+// The generated graph holds a key placeholder and a single call to
+// `func_call_host_func_name` that takes the placeholder as input. It is stored
+// in `fld` as `host_graph_func_name` (replacing any function of that name).
+Status BuildHostGraphForFuncCallNode(const string& func_call_node_name,
+                                     const string& xla_cluster_name,
+                                     const string& func_call_host_func_name,
+                                     const string& host_graph_func_name,
+                                     FunctionLibraryDefinition* fld) {
+  Graph host_graph(fld);
+  AttrValue device_ordinal_value;
+  device_ordinal_value.set_placeholder("device_ordinal");
+
+  // Step 1: add key placeholder node.
+  TF_ASSIGN_OR_RETURN(
+      Node * key_placeholder,
+      AddHostComputeKeyPlaceholder(xla_cluster_name, &host_graph));
+
+  // Step 2: rewrite `host_func_name`, replace key placeholder with an _Arg
+  // node.
+  TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode(
+      xla_cluster_name, func_call_host_func_name, fld));
+
+  // Step 3: build a function call node with `host_func_name`, with
+  // `key_placeholder` as input.
+  NodeDefBuilder call_builder(absl::StrCat("oc_call_", func_call_node_name),
+                              func_call_host_func_name, fld);
+  call_builder.Input(key_placeholder->name(), 0, DT_STRING);
+  call_builder.Attr("device_ordinal", device_ordinal_value);
+  call_builder.Attr(kXlaHasHostTransferAttrName, true);
+  NodeDef call_def;
+  TF_RETURN_IF_ERROR(call_builder.Finalize(&call_def));
+  Status s;
+  Node* call_node = host_graph.AddNode(call_def, &s);
+  TF_RETURN_IF_ERROR(s);
+  host_graph.AddEdge(key_placeholder, 0, call_node, 0);
+
+  // Convert `host_graph` to function, and add a "device_ordinal" attr.
+  FunctionDef oc_host_graph_fdef;
+  TF_RETURN_IF_ERROR(GraphToFunctionDef(host_graph, host_graph_func_name,
+                                        &oc_host_graph_fdef));
+  if (fld->Find(host_graph_func_name)) {
+    TF_RETURN_IF_ERROR(
+        fld->ReplaceFunction(host_graph_func_name, oc_host_graph_fdef));
+  } else {
+    TF_RETURN_IF_ERROR(fld->AddFunctionDef(oc_host_graph_fdef));
+  }
+
+  return Status::OK();
+}
+
+// Extracts outside compilation from the functions associated with function
+// call, "If" and "While" nodes in `g`.
+//
+// For every such node whose associated function(s) contain outside
+// compilation, the node is pointed at the rewritten "_oc" function(s), a host
+// side graph function is built in `fld`, its name is appended to
+// `host_graphs`, and `*has_outside_compilation` is set to true. The flag is
+// never reset to false here.
+Status ExtractOutsideCompilationForNodesWithAssociatedFunctions(
+    Graph* g, const string& xla_cluster_attr_name,
+    const string& outside_compilation_attr_name, const string& xla_cluster_name,
+    const std::map<string, int>& host_compute_core, FunctionLibraryRuntime* flr,
+    FunctionLibraryDefinition* fld, std::vector<string>* host_graphs,
+    std::vector<string>* shape_inference_graphs,
+    bool* has_outside_compilation) {
+  // Collect the nodes first; the loops below modify `g`.
+  std::vector<Node*> if_nodes, while_nodes, func_call_nodes;
+  for (Node* n : g->nodes()) {
+    if (n->type_string() == "If") {
+      if_nodes.push_back(n);
+    } else if (n->type_string() == "While") {
+      while_nodes.push_back(n);
+    } else if (fld->Contains(n->type_string())) {
+      func_call_nodes.push_back(n);
+    } else if (n->type_string() == FunctionLibraryDefinition::kGradientOp) {
+      // Only gradient for user-defined function should be considered as
+      // function call node.
+      NameAttrList original_func;
+      TF_RETURN_IF_ERROR(GetNodeAttr(
+          n->def(), FunctionLibraryDefinition::kFuncAttr, &original_func));
+      if (fld->Contains(original_func.name())) {
+        func_call_nodes.push_back(n);
+      }
+    }
+  }
+
+  for (Node* n : func_call_nodes) {
+    // Extract outside compilation for the function call.
+    bool func_has_outside_compilation = false;
+    NameAttrList func;
+    func.set_name(n->type_string());
+    typedef protobuf::Map<string, AttrValue> AttrMap;
+    *func.mutable_attr() = AttrMap(n->attrs().begin(), n->attrs().end());
+    string new_func_name = absl::StrCat(n->name(), "_oc");
+    string host_func_name = absl::StrCat("oc_func_call_host_", n->name());
+    TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
+        xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
+        func, new_func_name, host_func_name, host_compute_core, flr, fld,
+        shape_inference_graphs, &func_has_outside_compilation));
+
+    // If the function call does not have outside compilation, nothing to do.
+    if (!func_has_outside_compilation) {
+      continue;
+    }
+
+    *has_outside_compilation = true;
+
+    // Change `n` to call the new function directly.
+    NodeDefBuilder replace_builder(n->name(), new_func_name, fld);
+    for (const Edge* e : n->in_edges()) {
+      if (e->IsControlEdge()) {
+        continue;
+      }
+      replace_builder.Input(e->src()->name(), e->src_output(),
+                            e->src()->output_type(e->src_output()));
+    }
+    for (const auto& attr : n->attrs()) {
+      replace_builder.Attr(attr.first, attr.second);
+    }
+    NodeDef replace_def;
+    TF_RETURN_IF_ERROR(replace_builder.Finalize(&replace_def));
+    TF_ASSIGN_OR_RETURN(Node * replace, ReplaceNode(g, n, replace_def));
+    replace->AddAttr(kXlaTokenInputNodesAttrName,
+                     std::vector<string>{kXlaTokenArgNodeName});
+
+    // Build host side graph for the function call.
+    string oc_host_graph_name =
+        absl::StrCat("oc_func_host_graph_", replace->name());
+    TF_RETURN_IF_ERROR(
+        BuildHostGraphForFuncCallNode(replace->name(), xla_cluster_name,
+                                      host_func_name, oc_host_graph_name, fld));
+
+    // Record the host graph.
+    host_graphs->push_back(oc_host_graph_name);
+  }
+
+  for (Node* n : if_nodes) {
+    // Instantiate "then_branch" and "else_branch".
+    NameAttrList then_branch, else_branch;
+    TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "then_branch", &then_branch));
+    TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "else_branch", &else_branch));
+
+    // Extract outside compilation for then_branch and else_branch.
+    bool then_branch_has_outside_compilation = false;
+    bool else_branch_has_outside_compilation = false;
+    string then_branch_host_func_name =
+               absl::StrCat("oc_then_branch_host_if_", n->name()),
+           else_branch_host_func_name =
+               absl::StrCat("oc_else_branch_host_if_", n->name());
+    string then_branch_xla_func_name = absl::StrCat(then_branch.name(), "_oc"),
+           else_branch_xla_func_name = absl::StrCat(else_branch.name(), "_oc");
+    TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
+        xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
+        then_branch, then_branch_xla_func_name, then_branch_host_func_name,
+        host_compute_core, flr, fld, shape_inference_graphs,
+        &then_branch_has_outside_compilation));
+    TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
+        xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
+        else_branch, else_branch_xla_func_name, else_branch_host_func_name,
+        host_compute_core, flr, fld, shape_inference_graphs,
+        &else_branch_has_outside_compilation));
+
+    // If then/else branch do not have outside compilation, nothing to do.
+    if (!then_branch_has_outside_compilation &&
+        !else_branch_has_outside_compilation) {
+      continue;
+    }
+
+    *has_outside_compilation = true;
+
+    // Change If node to call the new functions.
+    then_branch.set_name(then_branch_xla_func_name);
+    n->ClearAttr("then_branch");
+    n->AddAttr("then_branch", then_branch);
+    else_branch.set_name(else_branch_xla_func_name);
+    n->ClearAttr("else_branch");
+    n->AddAttr("else_branch", else_branch);
+
+    string host_transfer_key = absl::StrCat("oc_if_pred_", n->name());
+
+    // XLA computation: add a SendToHost node to send cond predicate.
+    Node* pred_node;
+    TF_RETURN_IF_ERROR(n->input_node(0, &pred_node));
+    TF_ASSIGN_OR_RETURN(
+        Node * send_pred_node,
+        BuildSendIfPredNode(absl::StrCat("send_oc_if_pred_", n->name()),
+                            host_transfer_key, pred_node, g));
+    n->AddAttr(kXlaTokenInputNodesAttrName,
+               std::vector<string>{send_pred_node->name()});
+
+    // Add a control edge from `send_pred_node` to If node, so XlaCompiler will
+    // visit If node after `send_pred_node`, thus the token output for
+    // `send_pred_node` has been generated.
+    g->AddControlEdge(send_pred_node, n);
+
+    // Build host side graph for the "If" node.
+    string oc_host_graph_name = absl::StrCat("oc_if_host_graph_", n->name());
+    TF_RETURN_IF_ERROR(BuildHostGraphForIfNode(
+        xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
+        n->name(), host_transfer_key, oc_host_graph_name, fld,
+        then_branch_host_func_name, else_branch_host_func_name));
+    host_graphs->push_back(oc_host_graph_name);
+  }
+
+  for (Node* n : while_nodes) {
+    // Instantiate "cond" and "body".
+    NameAttrList cond, body;
+    TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "cond", &cond));
+    TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "body", &body));
+
+    // Extract outside compilation for cond and body.
+    bool cond_has_outside_compilation = false;
+    bool body_has_outside_compilation = false;
+    string cond_host_func_name = absl::StrCat("oc_cond_host_while_", n->name()),
+           body_host_func_name = absl::StrCat("oc_body_host_while_", n->name());
+    string cond_xla_func_name = absl::StrCat(cond.name(), "_oc"),
+           body_xla_func_name = absl::StrCat(body.name(), "_oc");
+    TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
+        xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
+        cond, cond_xla_func_name, cond_host_func_name, host_compute_core, flr,
+        fld, shape_inference_graphs, &cond_has_outside_compilation));
+    TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
+        xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
+        body, body_xla_func_name, body_host_func_name, host_compute_core, flr,
+        fld, shape_inference_graphs, &body_has_outside_compilation));
+
+    // If cond/body do not have outside compilation, nothing to do.
+    if (!cond_has_outside_compilation && !body_has_outside_compilation) {
+      continue;
+    }
+
+    *has_outside_compilation = true;
+
+    // Change While node to call the new functions.
+    cond.set_name(cond_xla_func_name);
+    n->ClearAttr("cond");
+    n->AddAttr("cond", cond);
+    body.set_name(body_xla_func_name);
+    n->ClearAttr("body");
+    n->AddAttr("body", body);
+
+    string host_transfer_key = absl::StrCat("oc_while_pred_", n->name());
+
+    // XLA computation: rewrite cond function to add a SendToHost node to send
+    // loop predicate.
+    TF_RETURN_IF_ERROR(
+        AddSendLoopPredToLoopCond(fld, cond, n->name(), host_transfer_key));
+    n->AddAttr(kXlaTokenInputNodesAttrName,
+               std::vector<string>{kXlaTokenArgNodeName});
+
+    // Build host side graph for the "While" node.
+    string oc_host_graph_name = absl::StrCat("oc_while_host_graph_", n->name());
+    TF_RETURN_IF_ERROR(BuildHostGraphForWhileNode(
+        xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
+        n->name(), host_transfer_key, oc_host_graph_name, fld,
+        cond_host_func_name, body_host_func_name));
+    host_graphs->push_back(oc_host_graph_name);
+  }
+
+  return Status::OK();
+}
+
 } // namespace Status RewriteOutsideCompilationSubgraphFn::operator()( @@ -755,12 +1563,15 @@ Status RewriteOutsideCompilationSubgraphFn::operator()( // it with HostCompute node later. AddNodeAttr("_outside_compilation_subgraph", old_name, node_def); if (shapes) { - AddNodeAttr("shape_inference_graph", "", node_def); + NameAttrList shape_inference_graph; + AddNodeAttr("shape_inference_graph", shape_inference_graph, node_def); AddNodeAttr("shapes", *shapes, node_def); } else { string shape_inference_func_name = absl::StrCat("_outside_compilation_shape_inference_", new_name); - AddNodeAttr("shape_inference_graph", shape_inference_func_name, node_def); + NameAttrList shape_inference_graph; + shape_inference_graph.set_name(shape_inference_func_name); + AddNodeAttr("shape_inference_graph", shape_inference_graph, node_def); AddNodeAttr("shapes", std::vector{}, node_def); } AddNodeAttr("ancestors", std::vector{}, node_def); @@ -775,36 +1586,34 @@ Status ExtractOutsideCompilationForFunction( const string& xla_cluster_attr_name, const string& outside_compilation_attr_name, const string& xla_cluster_name, const NameAttrList& func_name_attrs, const string& new_func_name, - const std::map& host_compute_core, - FunctionLibraryDefinition* fld, std::unique_ptr* host_graph, - std::vector* shape_inference_graphs, + const string& host_graph_func_name, + const std::map& host_compute_core, FunctionLibraryRuntime* flr, + FunctionLibraryDefinition* fld, std::vector* shape_inference_graphs, bool* has_outside_compilation) { - // Early return if function does not have any outside compilation nodes.
+ // Convert the function to graph. const string& func_name = func_name_attrs.name(); - const FunctionDef* fdef = fld->Find(func_name); - if (!fdef) { - return errors::Internal("Cannot find function ", func_name); - } + FunctionLibraryRuntime::Handle handle; + TF_RETURN_IF_ERROR( + flr->Instantiate(func_name, AttrSlice(&func_name_attrs.attr()), &handle)); + Status ret_status = Status::OK(); + auto cleanup_handle = gtl::MakeCleanup([&]() { + auto s = flr->ReleaseHandle(handle); + if (!s.ok()) { + ret_status.Update(s); + } + }); + const FunctionBody* fbody = flr->GetFunctionBody(handle); + + // Check if we have outside compilation nodes. *has_outside_compilation = false; - for (auto& node_def : fdef->node_def()) { - if (HasNodeAttr(node_def, outside_compilation_attr_name)) { + for (Node* n : fbody->graph->nodes()) { + if (HasNodeAttr(n->def(), outside_compilation_attr_name)) { *has_outside_compilation = true; break; } } - if (!has_outside_compilation) { - return Status::OK(); - } - - // Convert the function to graph. - FunctionBody* fbody = nullptr; - TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( - *fld->Find(func_name), AttrSlice(&func_name_attrs.attr()), fld, - [&](const string& op, const OpDef** sig) { - return fld->LookUpOpDef(op, sig); - }, - &fbody)); - std::unique_ptr fbody_deleter(fbody); + // We cannot early return here, because we might have outside compilation in + // If/While function body. // Preprocess edges between different outside compilations. They will be // restored in `ConstructHostGraph()`. @@ -835,11 +1644,11 @@ Status ExtractOutsideCompilationForFunction( // If we could not infer shapes for XlaSendFromHost inputs statically, we // will set the "shape_inference_graph" attribute. In that case, copy // outside compilation subgraph as shape inference graph in `fld`. 
- string shape_inference_graph; + NameAttrList shape_inference_graph; TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "shape_inference_graph", &shape_inference_graph)); - if (!shape_inference_graph.empty()) { - shape_inference_graphs->push_back(shape_inference_graph); + if (!shape_inference_graph.name().empty()) { + shape_inference_graphs->push_back(shape_inference_graph.name()); const FunctionDef* xla_fdef = fld->Find(n->name()); if (!xla_fdef) { @@ -847,9 +1656,9 @@ Status ExtractOutsideCompilationForFunction( } FunctionDef shape_inference_fdef = *xla_fdef; shape_inference_fdef.mutable_signature()->set_name( - shape_inference_graph); - if (fld->Find(shape_inference_graph)) { - TF_RETURN_IF_ERROR(fld->ReplaceFunction(shape_inference_graph, + shape_inference_graph.name()); + if (fld->Find(shape_inference_graph.name())) { + TF_RETURN_IF_ERROR(fld->ReplaceFunction(shape_inference_graph.name(), shape_inference_fdef)); } else { TF_RETURN_IF_ERROR(fld->AddFunctionDef(shape_inference_fdef)); @@ -858,21 +1667,22 @@ Status ExtractOutsideCompilationForFunction( } } for (Node* n : outside_compilation_nodes) { + TF_RETURN_IF_ERROR(ValidateOutsideCompilationCallNode(n)); TF_RETURN_IF_ERROR(ReplaceOrRemoveOutsideCompilationCallNode( graph_out.get(), n, host_compute_core)); } - if (VLOG_IS_ON(4)) { - dump_graph::DumpGraphToFile( - absl::StrCat("extract_outside_compilation_for_func_after_", func_name), - *graph_out, fld); - } + + // Handle nodes with associated functions. + TF_RETURN_IF_ERROR(ExtractOutsideCompilationForNodesWithAssociatedFunctions( + graph_out.get(), xla_cluster_attr_name, outside_compilation_attr_name, + xla_cluster_name, host_compute_core, flr, fld, + &outside_compilation_host_graphs, shape_inference_graphs, + has_outside_compilation)); // Construct host graph. 
- if (!outside_compilation_host_graphs.empty()) { - TF_RETURN_IF_ERROR( - ConstructHostGraph(xla_cluster_name, outside_compilation_attr_name, - outside_compilation_host_graphs, fld, host_graph)); - } + TF_RETURN_IF_ERROR(ConstructHostGraph( + xla_cluster_name, outside_compilation_attr_name, + outside_compilation_host_graphs, fld, host_graph_func_name)); // Remove the outside compilation graphs from function library. for (const string& func : outside_compilation_host_graphs) { @@ -883,20 +1693,31 @@ Status ExtractOutsideCompilationForFunction( FunctionDef updated_fdef; TF_RETURN_IF_ERROR( GraphToFunctionDef(*graph_out, new_func_name, &updated_fdef)); + const FunctionDef* original_fdef = fld->Find(func_name); + if (original_fdef) { + for (const auto& attr : original_fdef->attr()) { + (*updated_fdef.mutable_attr())[attr.first] = attr.second; + } + } if (fld->Find(new_func_name)) { TF_RETURN_IF_ERROR(fld->ReplaceFunction(new_func_name, updated_fdef)); } else { TF_RETURN_IF_ERROR(fld->AddFunctionDef(updated_fdef)); } + if (VLOG_IS_ON(4)) { + dump_graph::DumpGraphToFile( + absl::StrCat("extract_outside_compilation_for_func_after_", func_name), + *graph_out, fld); + } - return Status::OK(); + return ret_status; } Status ExtractOutsideCompilation( const string& xla_cluster_attr_name, const string& outside_compilation_attr_name, const std::unordered_map& clusters, Graph* g, - FunctionLibraryDefinition* fld) { + FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld) { if (VLOG_IS_ON(4)) { dump_graph::DumpGraphToFile("extract_outside_compilation_before", *g, fld); } @@ -909,24 +1730,17 @@ Status ExtractOutsideCompilation( auto const& host_compute_core = iter.second.host_compute_core; bool has_outside_compilation; - std::unique_ptr host_graph; + string host_graph_func_name = absl::StrCat("oc_host_graph_", n->name()); TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction( xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, - func_name_attrs, 
func_name_attrs.name(), host_compute_core, fld, - &host_graph, &shape_inference_graphs, &has_outside_compilation)); - if (host_graph) { - TF_RETURN_IF_ERROR(ExpandHostGraphIntoMainGraph(g, host_graph.get(), n)); - } - } - - if (VLOG_IS_ON(4)) { - dump_graph::DumpGraphToFile("extract_outside_compilation_expanded", *g, - fld); + func_name_attrs, func_name_attrs.name(), host_graph_func_name, + host_compute_core, flr, fld, &shape_inference_graphs, + &has_outside_compilation)); + TF_RETURN_IF_ERROR( + ExpandHostGraphIntoMainGraph(g, fld, host_graph_func_name, n)); + TF_RETURN_IF_ERROR(fld->RemoveFunction(host_graph_func_name)); } - TF_RETURN_IF_ERROR(PostprocessForEncapsulation( - g, xla_cluster_attr_name, outside_compilation_attr_name, clusters)); - for (auto shape_inference_graph_name : shape_inference_graphs) { TF_RETURN_IF_ERROR( RewriteShapeInferenceGraph(shape_inference_graph_name, g, fld)); diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass.h b/tensorflow/compiler/jit/extract_outside_compilation_pass.h index 2a4f07cca213d999202024294f5d8f94527059c3..d64cc2a103ed040cbf413ac736f97f84459e869b 100644 --- a/tensorflow/compiler/jit/extract_outside_compilation_pass.h +++ b/tensorflow/compiler/jit/extract_outside_compilation_pass.h @@ -88,9 +88,10 @@ Status ExtractOutsideCompilationForFunction( const string& xla_cluster_attr_name, const string& outside_compilation_attr_name, const string& xla_cluster_name, const NameAttrList& func_name_attrs, const string& new_func_name, - const std::map& host_compute_core, - FunctionLibraryDefinition* fld, std::unique_ptr* host_graph, - std::vector* shape_inference_graphs, bool* has_outside_compilation); + const string& host_graph_func_name, + const std::map& host_compute_core, FunctionLibraryRuntime* flr, + FunctionLibraryDefinition* fld, std::vector* shape_inference_graphs, + bool* has_outside_compilation); // Rewrites XLA computation in `clusters` to replace outside compilation nodes // with XlaHostCompute, and 
moves those outside compilations into `g`. If shapes @@ -100,7 +101,7 @@ Status ExtractOutsideCompilation( const string& xla_cluster_attr_name, const string& outside_compilation_attr_name, const std::unordered_map& clusters, Graph* g, - FunctionLibraryDefinition* fld); + FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld); } // namespace tensorflow diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc b/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc index bff956100da661b679b4557fce53671e6cef88c5..7c3a24feff81b21a5d2347d21fb80988bc3e6065 100644 --- a/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc @@ -19,8 +19,11 @@ limitations under the License. #include "tensorflow/cc/framework/scope.h" #include "tensorflow/cc/ops/array_ops.h" #include "tensorflow/cc/ops/function_ops.h" +#include "tensorflow/cc/ops/functional_ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/compiler/jit/encapsulate_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/function.h" @@ -29,6 +32,8 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/public/session_options.h" +#include "tensorflow/core/public/version.h" namespace tensorflow { @@ -109,10 +114,10 @@ TEST(RewriteOutsideCompilationSubgraphFnTest, Basic) { } EXPECT_TRUE(has_control_edge_to_send_from_host); // Verify step 7: necessary attrs added to call_node_def. 
- string shape_inference_graph; + NameAttrList shape_inference_graph; TF_CHECK_OK(GetNodeAttr(AttrSlice(&call_node_def.attr()), "shape_inference_graph", &shape_inference_graph)); - EXPECT_EQ(shape_inference_graph, + EXPECT_EQ(shape_inference_graph.name(), "_outside_compilation_shape_inference_cluster_0"); } @@ -220,7 +225,42 @@ TEST(RewriteOutsideCompilationSubgraphFnTest, ShapesInferred) { EXPECT_EQ(shapes[0].dim_size(), 1); } -TEST(ExtractOutsideCompilationForFunctionTest, Basic) { +class ExtractOutsideCompilationForFunctionTest : public ::testing::Test { + public: + void SetUp() override { + SessionOptions session_options; + std::vector> devices; + TF_CHECK_OK(DeviceFactory::AddDevices( + session_options, "/job:localhost/replica:0/task:0", &devices)); + device_mgr_ = absl::make_unique(std::move(devices)); + } + + Status ExtractOutsideCompilationTest( + const string &xla_cluster_attr_name, + const string &outside_compilation_attr_name, + const string &xla_cluster_name, const NameAttrList &func_name_attrs, + const string &new_func_name, const string &host_graph_func_name, + const std::map &host_compute_core, + FunctionLibraryDefinition *fld, + std::vector *shape_inference_graphs, + bool *has_outside_compilation) { + OptimizerOptions opts; + pflr_ = absl::make_unique( + device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, fld, opts, + /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr); + auto flr = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0"); + return ExtractOutsideCompilationForFunction( + xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, + func_name_attrs, new_func_name, host_graph_func_name, host_compute_core, + flr, fld, shape_inference_graphs, has_outside_compilation); + } + + private: + std::unique_ptr device_mgr_; + std::unique_ptr pflr_; +}; + +TEST_F(ExtractOutsideCompilationForFunctionTest, Basic) { // Build the XLA computation func. 
// "const0" // "identity0" = "const0" (outside compilation cluster "0") @@ -249,27 +289,26 @@ TEST(ExtractOutsideCompilationForFunctionTest, Basic) { protobuf::Map attrs; std::map host_compute_core = {{"0", 1}, {"1", 0}}; - std::unique_ptr host_graph; std::vector shape_inference_graphs; bool has_outside_compilation; NameAttrList name_attrs; name_attrs.set_name("cluster"); *name_attrs.mutable_attr() = attrs; - TF_CHECK_OK(ExtractOutsideCompilationForFunction( - "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", - host_compute_core, &fld, &host_graph, &shape_inference_graphs, + TF_CHECK_OK(ExtractOutsideCompilationTest( + "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph", + host_compute_core, &fld, &shape_inference_graphs, &has_outside_compilation)); // Get rewritten XLA computation function. - FunctionBody *fbody = nullptr; - TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("cluster_rewritten"), - AttrSlice(), &fld, - [&](const string &op, const OpDef **sig) { - return fld.LookUpOpDef(op, sig); - }, - &fbody)); - std::unique_ptr fbody_deleter(fbody); - auto node_name_index = fbody->graph->BuildNodeNameIndex(); + FunctionBody *xla_fbody = nullptr; + TF_CHECK_OK(FunctionDefToBodyHelper( + *fld.Find("cluster_rewritten"), AttrSlice(), &fld, + [&](const string &op, const OpDef **sig) { + return fld.LookUpOpDef(op, sig); + }, + &xla_fbody)); + std::unique_ptr xla_fbody_deleter(xla_fbody); + auto node_name_index = xla_fbody->graph->BuildNodeNameIndex(); // Check XlaHostCompute nodes. Node *host_compute_0 = node_name_index["outside_compilation_0_host_compute"]; @@ -292,18 +331,31 @@ TEST(ExtractOutsideCompilationForFunctionTest, Basic) { EXPECT_EQ(shapes[0].dim_size(), 1); // Check XlaHostCompute nodes' "shape_inference_graph" attr. Both should have // empty values. 
- string shape_inference_graph; + NameAttrList shape_inference_graph; TF_CHECK_OK(GetNodeAttr(host_compute_0->attrs(), "shape_inference_graph", &shape_inference_graph)); - EXPECT_EQ(shape_inference_graph, ""); + EXPECT_EQ(shape_inference_graph.name(), ""); TF_CHECK_OK(GetNodeAttr(host_compute_1->attrs(), "shape_inference_graph", &shape_inference_graph)); - EXPECT_EQ(shape_inference_graph, ""); + EXPECT_EQ(shape_inference_graph.name(), ""); // Check `shape_inference_graphs`. EXPECT_EQ(shape_inference_graphs.size(), 0); - // Check `host_graph`: verify we have key placeholder and sequencer. + // Check host graph: verify we have key placeholder and sequencer. + FunctionBody *host_fbody = nullptr; + AttrValue device_ordinal_temp_value; + device_ordinal_temp_value.set_i(0); + protobuf::Map host_func_attrs; + host_func_attrs["device_ordinal"] = device_ordinal_temp_value; + TF_CHECK_OK(FunctionDefToBodyHelper( + *fld.Find("host_graph"), AttrSlice(&host_func_attrs), &fld, + [&](const string &op, const OpDef **sig) { + return fld.LookUpOpDef(op, sig); + }, + &host_fbody)); + std::unique_ptr host_fbody_deleter(host_fbody); + Graph *host_graph = host_fbody->graph; Node *key_placeholder = nullptr, *sequencer = nullptr; for (Node *n : host_graph->nodes()) { if (n->type_string() == "Placeholder" && @@ -348,7 +400,7 @@ TEST(ExtractOutsideCompilationForFunctionTest, Basic) { } } -TEST(ExtractOutsideCompilationForFunctionTest, NoHostGraph) { +TEST_F(ExtractOutsideCompilationForFunctionTest, NoHostGraph) { // Build the XLA computation func. 
// "const0" FunctionDefLibrary fdl; @@ -365,25 +417,37 @@ TEST(ExtractOutsideCompilationForFunctionTest, NoHostGraph) { protobuf::Map attrs; std::map host_compute_core = {{"0", 1}, {"1", 0}}; - std::unique_ptr host_graph; std::vector shape_inference_graphs; bool has_outside_compilation; NameAttrList name_attrs; name_attrs.set_name("cluster"); *name_attrs.mutable_attr() = attrs; - TF_CHECK_OK(ExtractOutsideCompilationForFunction( - "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", - host_compute_core, &fld, &host_graph, &shape_inference_graphs, + TF_CHECK_OK(ExtractOutsideCompilationTest( + "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph", + host_compute_core, &fld, &shape_inference_graphs, &has_outside_compilation)); - // Check `host_graph` is empty. - EXPECT_FALSE(host_graph); + // Check host graph is empty. + FunctionBody *host_fbody = nullptr; + AttrValue device_ordinal_temp_value; + device_ordinal_temp_value.set_i(0); + protobuf::Map host_func_attrs; + host_func_attrs["device_ordinal"] = device_ordinal_temp_value; + TF_CHECK_OK(FunctionDefToBodyHelper( + *fld.Find("host_graph"), AttrSlice(&host_func_attrs), &fld, + [&](const string &op, const OpDef **sig) { + return fld.LookUpOpDef(op, sig); + }, + &host_fbody)); + std::unique_ptr host_fbody_deleter(host_fbody); + Graph *host_graph = host_fbody->graph; + EXPECT_EQ(host_graph->num_nodes(), 2); } -TEST(ExtractOutsideCompilationForFunctionTest, XlaHostComputeRemoved) { +TEST_F(ExtractOutsideCompilationForFunctionTest, XlaHostComputeRemoved) { // Build the XLA computation func. 
// "const0" - // "const1" (outside compilation clsuter "0") + // "const1" (outside compilation cluster "0") FunctionDefLibrary fdl; { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); @@ -401,31 +465,43 @@ TEST(ExtractOutsideCompilationForFunctionTest, XlaHostComputeRemoved) { protobuf::Map attrs; std::map host_compute_core = {{"0", 1}, {"1", 0}}; - std::unique_ptr host_graph; std::vector shape_inference_graphs; bool has_outside_compilation; NameAttrList name_attrs; name_attrs.set_name("cluster"); *name_attrs.mutable_attr() = attrs; - TF_CHECK_OK(ExtractOutsideCompilationForFunction( - "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", - host_compute_core, &fld, &host_graph, &shape_inference_graphs, + TF_CHECK_OK(ExtractOutsideCompilationTest( + "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph", + host_compute_core, &fld, &shape_inference_graphs, &has_outside_compilation)); // Check rewritten XLA graph: verify that we have no XlaHostCompute. - FunctionBody *fbody = nullptr; - TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("cluster_rewritten"), - AttrSlice(), &fld, - [&](const string &op, const OpDef **sig) { - return fld.LookUpOpDef(op, sig); - }, - &fbody)); - std::unique_ptr fbody_deleter(fbody); - for (Node *n : fbody->graph->nodes()) { + FunctionBody *xla_fbody = nullptr; + TF_CHECK_OK(FunctionDefToBodyHelper( + *fld.Find("cluster_rewritten"), AttrSlice(), &fld, + [&](const string &op, const OpDef **sig) { + return fld.LookUpOpDef(op, sig); + }, + &xla_fbody)); + std::unique_ptr xla_fbody_deleter(xla_fbody); + for (Node *n : xla_fbody->graph->nodes()) { EXPECT_NE(n->type_string(), "XlaHostCompute"); } - // Check `host_graph`: verify we have no placeholder, but we have "const1". + // Check host graph: verify we have no placeholder, but we have "const1". 
+ FunctionBody *host_fbody = nullptr; + AttrValue device_ordinal_temp_value; + device_ordinal_temp_value.set_i(0); + protobuf::Map host_func_attrs; + host_func_attrs["device_ordinal"] = device_ordinal_temp_value; + TF_CHECK_OK(FunctionDefToBodyHelper( + *fld.Find("host_graph"), AttrSlice(&host_func_attrs), &fld, + [&](const string &op, const OpDef **sig) { + return fld.LookUpOpDef(op, sig); + }, + &host_fbody)); + std::unique_ptr host_fbody_deleter(host_fbody); + Graph *host_graph = host_fbody->graph; int num_key_placeholders = 0; for (Node *n : host_graph->nodes()) { if (n->type_string() == "Placeholder" && @@ -438,4 +514,468 @@ TEST(ExtractOutsideCompilationForFunctionTest, XlaHostComputeRemoved) { EXPECT_NE(node_name_index.find("const1"), node_name_index.end()); } +REGISTER_OP("XlaSendToHost") + .Input("input: Tinput") + .Attr("Tinput: type") + .Attr("key: string") + .SetIsStateful(); + +REGISTER_OP("XlaRecvFromHost") + .Output("output: Toutput") + .Attr("Toutput: type") + .Attr("shape: shape") + .Attr("key: string") + .SetIsStateful(); + +TEST_F(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInIf) { + // Build the XLA computation func. 
+ // "const0" (bool) + // "const1" (int32) + // "if0" (pred = "const0", input = "const1", then_branch = "true_fn", + // else_branch = "false_fn") + FunctionDefLibrary fdl; + { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output arg = ops::_Arg(s.WithOpName("arg"), DT_INT32, 0); + Output identity = ops::Identity(s.WithOpName("identity_true_fn"), arg); + ops::_Retval retval(s.WithOpName("retval"), identity, 0); + std::unique_ptr g(new Graph(OpRegistry::Global())); + TF_CHECK_OK(s.ToGraph(g.get())); + auto node_name_image = g->BuildNodeNameIndex(); + node_name_image["identity_true_fn"]->AddAttr("_oc", "0"); + PartialTensorShape shape({2}); + node_name_image["identity_true_fn"]->AddAttr( + kXlaInferredShapesAttrName, std::vector{shape}); + + FunctionDef *true_fn_fdef = fdl.add_function(); + TF_CHECK_OK(GraphToFunctionDef(*g, "true_fn", true_fn_fdef)); + } + { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output arg = ops::_Arg(s.WithOpName("arg"), DT_INT32, 0); + Output identity = ops::Identity(s.WithOpName("identity_false_fn"), arg); + ops::_Retval retval(s.WithOpName("retval"), identity, 0); + std::unique_ptr g(new Graph(OpRegistry::Global())); + TF_CHECK_OK(s.ToGraph(g.get())); + auto node_name_image = g->BuildNodeNameIndex(); + node_name_image["identity_false_fn"]->AddAttr("_oc", "0"); + PartialTensorShape shape({2}); + node_name_image["identity_false_fn"]->AddAttr( + kXlaInferredShapesAttrName, std::vector{shape}); + + FunctionDef *false_fn_fdef = fdl.add_function(); + TF_CHECK_OK(GraphToFunctionDef(*g, "false_fn", false_fn_fdef)); + } + { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output cond = ops::Const(s.WithOpName("const0"), true, {2}); + Output input = ops::Const(s.WithOpName("const1"), 1, {2}); + NameAttrList true_fn; + true_fn.set_name("true_fn"); + NameAttrList false_fn; + false_fn.set_name("false_fn"); + auto if_op = ops::If(s.WithOpName("if"), cond, + std::initializer_list{cond, input}, {DT_INT32}, + 
true_fn, false_fn); + ops::_Retval retval(s.WithOpName("retval"), if_op.output[0], 0); + std::unique_ptr g(new Graph(OpRegistry::Global())); + TF_CHECK_OK(s.ToGraph(g.get())); + + FunctionDef *xla_fdef = fdl.add_function(); + TF_CHECK_OK(GraphToFunctionDef(*g, "cluster", xla_fdef)); + } + FunctionLibraryDefinition fld(OpRegistry::Global(), fdl); + + protobuf::Map attrs; + std::map host_compute_core; + std::vector shape_inference_graphs; + bool has_outside_compilation; + NameAttrList name_attrs; + name_attrs.set_name("cluster"); + *name_attrs.mutable_attr() = attrs; + TF_CHECK_OK(ExtractOutsideCompilationTest( + "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph", + host_compute_core, &fld, &shape_inference_graphs, + &has_outside_compilation)); + + // Check host graph. + { + FunctionBody *host_fbody = nullptr; + AttrValue device_ordinal_temp_value; + device_ordinal_temp_value.set_i(0); + protobuf::Map host_func_attrs; + host_func_attrs["device_ordinal"] = device_ordinal_temp_value; + TF_CHECK_OK(FunctionDefToBodyHelper( + *fld.Find("host_graph"), AttrSlice(&host_func_attrs), &fld, + [&](const string &op, const OpDef **sig) { + return fld.LookUpOpDef(op, sig); + }, + &host_fbody)); + std::unique_ptr host_fbody_deleter(host_fbody); + Graph *host_graph = host_fbody->graph; + auto node_name_index = host_graph->BuildNodeNameIndex(); + + // Verify we have XlaRecvAtHost to receive "If" predicate. + Node *recv_if_pred_node = node_name_index["recv_oc_if_pred_if"]; + EXPECT_NE(recv_if_pred_node, nullptr); + + // Verify we have an "If" to choose outside compilation between then_branch + // and else_branch, and it has `recv_if_pred_node` as cond input. 
+ Node *if_oc_node = node_name_index["oc_if_if"]; + EXPECT_NE(if_oc_node, nullptr); + Node *if_oc_node_cond_input; + TF_CHECK_OK(if_oc_node->input_node(0, &if_oc_node_cond_input)); + EXPECT_EQ(if_oc_node_cond_input, recv_if_pred_node); + + // Check that then_branch outside compilation has node "identity_true_fn". + const FunctionDef *true_def = fld.Find("oc_then_branch_host_if_if"); + EXPECT_NE(true_def, nullptr); + bool has_identity_true_fn_node = false; + for (const auto &node_def : true_def->node_def()) { + if (node_def.name() == "identity_true_fn") { + has_identity_true_fn_node = true; + break; + } + } + EXPECT_TRUE(has_identity_true_fn_node); + + // Check that else_branch outside compilation has node "identity_false_fn". + const FunctionDef *false_def = fld.Find("oc_else_branch_host_if_if"); + EXPECT_NE(false_def, nullptr); + bool has_identity_false_fn_node = false; + for (const auto &node_def : false_def->node_def()) { + if (node_def.name() == "identity_false_fn") { + has_identity_false_fn_node = true; + break; + } + } + EXPECT_TRUE(has_identity_false_fn_node); + } + + // Check XLA graph. + { + FunctionBody *xla_fbody = nullptr; + TF_CHECK_OK(FunctionDefToBodyHelper( + *fld.Find("cluster_rewritten"), AttrSlice(), &fld, + [&](const string &op, const OpDef **sig) { + return fld.LookUpOpDef(op, sig); + }, + &xla_fbody)); + std::unique_ptr xla_fbody_deleter(xla_fbody); + Graph *xla_graph = xla_fbody->graph; + auto node_name_index = xla_graph->BuildNodeNameIndex(); + + // Check that we have XlaSendToHost to send cond predicate to host, and + // there is a control edge to If node. 
+ Node *send_if_pred_node = node_name_index["send_oc_if_pred_if"]; + EXPECT_NE(send_if_pred_node, nullptr); + bool has_control_edge_to_if = false; + for (const Edge *e : send_if_pred_node->out_edges()) { + if (e->IsControlEdge() && e->dst()->name() == "if") { + has_control_edge_to_if = true; + break; + } + } + EXPECT_TRUE(has_control_edge_to_if); + + // Check that the "If" node now has `send_if_pred_node` as attribute + // _xla_token_input_nodes. + Node *if_node = node_name_index["if"]; + EXPECT_NE(if_node, nullptr); + std::vector token_inputs; + TF_CHECK_OK( + GetNodeAttr(if_node->def(), "_xla_token_input_nodes", &token_inputs)); + EXPECT_THAT(token_inputs, ::testing::ElementsAre("send_oc_if_pred_if")); + } +} + +TEST_F(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInWhile) { + // Build the XLA computation func. + // "const0" (bool) + // "while0" (input = "const0", cond = "cond_fn", body = "body_fn") + FunctionDefLibrary fdl; + { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output arg = ops::_Arg(s.WithOpName("arg"), DT_BOOL, 0); + Output identity = ops::Identity(s.WithOpName("identity_cond_fn"), arg); + ops::_Retval retval(s.WithOpName("retval"), identity, 0); + std::unique_ptr g(new Graph(OpRegistry::Global())); + TF_CHECK_OK(s.ToGraph(g.get())); + auto node_name_image = g->BuildNodeNameIndex(); + node_name_image["identity_cond_fn"]->AddAttr("_oc", "0"); + PartialTensorShape shape({2}); + node_name_image["identity_cond_fn"]->AddAttr( + kXlaInferredShapesAttrName, std::vector{shape}); + + FunctionDef *cond_fn_fdef = fdl.add_function(); + TF_CHECK_OK(GraphToFunctionDef(*g, "cond_fn", cond_fn_fdef)); + } + { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output arg = ops::_Arg(s.WithOpName("arg"), DT_BOOL, 0); + Output identity = ops::Identity(s.WithOpName("identity_body_fn"), arg); + ops::_Retval retval(s.WithOpName("retval"), identity, 0); + std::unique_ptr g(new Graph(OpRegistry::Global())); + 
TF_CHECK_OK(s.ToGraph(g.get())); + auto node_name_image = g->BuildNodeNameIndex(); + node_name_image["identity_body_fn"]->AddAttr("_oc", "0"); + PartialTensorShape shape({2}); + node_name_image["identity_body_fn"]->AddAttr( + kXlaInferredShapesAttrName, std::vector{shape}); + + FunctionDef *body_fn_fdef = fdl.add_function(); + TF_CHECK_OK(GraphToFunctionDef(*g, "body_fn", body_fn_fdef)); + } + { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output input = ops::Const(s.WithOpName("const0"), true, {2}); + NameAttrList cond_fn; + cond_fn.set_name("cond_fn"); + NameAttrList body_fn; + body_fn.set_name("body_fn"); + auto while_op = + ops::While(s.WithOpName("while"), std::initializer_list{input}, + cond_fn, body_fn); + ops::_Retval retval(s.WithOpName("retval"), while_op.output[0], 0); + std::unique_ptr g(new Graph(OpRegistry::Global())); + TF_CHECK_OK(s.ToGraph(g.get())); + + FunctionDef *xla_fdef = fdl.add_function(); + TF_CHECK_OK(GraphToFunctionDef(*g, "cluster", xla_fdef)); + } + FunctionLibraryDefinition fld(OpRegistry::Global(), fdl); + + protobuf::Map attrs; + std::map host_compute_core; + std::vector shape_inference_graphs; + bool has_outside_compilation; + NameAttrList name_attrs; + name_attrs.set_name("cluster"); + *name_attrs.mutable_attr() = attrs; + TF_CHECK_OK(ExtractOutsideCompilationTest( + "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph", + host_compute_core, &fld, &shape_inference_graphs, + &has_outside_compilation)); + + // Check host graph. 
+ { + FunctionBody *host_fbody = nullptr; + AttrValue device_ordinal_temp_value; + device_ordinal_temp_value.set_i(0); + protobuf::Map host_func_attrs; + host_func_attrs["device_ordinal"] = device_ordinal_temp_value; + TF_CHECK_OK(FunctionDefToBodyHelper( + *fld.Find("host_graph"), AttrSlice(&host_func_attrs), &fld, + [&](const string &op, const OpDef **sig) { + return fld.LookUpOpDef(op, sig); + }, + &host_fbody)); + std::unique_ptr host_fbody_deleter(host_fbody); + Graph *host_graph = host_fbody->graph; + auto node_name_index = host_graph->BuildNodeNameIndex(); + + // Verify we have a "While" to execute outside compilation. + Node *while_oc_node = node_name_index["oc_while_while"]; + EXPECT_NE(while_oc_node, nullptr); + + // Check that cond outside compilation has node "identity_cond_fn". + const FunctionDef *cond_def = fld.Find("oc_cond_host_while_while"); + EXPECT_NE(cond_def, nullptr); + bool has_identity_cond_fn_node = false; + for (const auto &node_def : cond_def->node_def()) { + if (node_def.name() == "identity_cond_fn") { + has_identity_cond_fn_node = true; + break; + } + } + EXPECT_TRUE(has_identity_cond_fn_node); + + // Check that body outside compilation has node "identity_body_fn". + const FunctionDef *body_def = fld.Find("oc_body_host_while_while"); + EXPECT_NE(body_def, nullptr); + bool has_identity_body_fn_node = false; + for (const auto &node_def : body_def->node_def()) { + if (node_def.name() == "identity_body_fn") { + has_identity_body_fn_node = true; + break; + } + } + EXPECT_TRUE(has_identity_body_fn_node); + } + + // Check XLA graph. + { + // Verify that rewritten cond fn has XlaSendToHost to send loop predicate to + // host. 
+ const FunctionDef *cond_def = fld.Find("cond_fn_oc"); + EXPECT_NE(cond_def, nullptr); + bool has_send_oc_while_cond_node = false; + for (const auto &node_def : cond_def->node_def()) { + if (node_def.name() == "send_oc_while_cond_while") { + has_send_oc_while_cond_node = true; + break; + } + } + EXPECT_TRUE(has_send_oc_while_cond_node); + } +} + +TEST_F(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInFunction) { + // Build the XLA computation func. + // "const0" (int32) + // "fn" (input = "const0") + FunctionDefLibrary fdl; + { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output arg = ops::_Arg(s.WithOpName("arg"), DT_INT32, 0); + Output identity = ops::Identity(s.WithOpName("identity"), arg); + ops::_Retval retval(s.WithOpName("retval"), identity, 0); + std::unique_ptr g(new Graph(OpRegistry::Global())); + TF_CHECK_OK(s.ToGraph(g.get())); + auto node_name_image = g->BuildNodeNameIndex(); + node_name_image["identity"]->AddAttr("_oc", "0"); + PartialTensorShape shape({2}); + node_name_image["identity"]->AddAttr( + kXlaInferredShapesAttrName, std::vector{shape}); + + FunctionDef *true_fn_fdef = fdl.add_function(); + TF_CHECK_OK(GraphToFunctionDef(*g, "fn", true_fn_fdef)); + } + FunctionLibraryDefinition fld(OpRegistry::Global(), fdl); + { + std::unique_ptr g(new Graph(&fld)); + + tensorflow::TensorProto tensor_proto; + tensor_proto.set_dtype(tensorflow::DT_INT32); + tensorflow::TensorShapeProto shape; + shape.add_dim()->set_size(2); + *tensor_proto.mutable_tensor_shape() = shape; + for (int i = 0; i < 2; ++i) { + tensor_proto.add_int_val(1); + } + NodeDef const_def; + TF_CHECK_OK(NodeDefBuilder("const", "Const") + .Attr("dtype", DT_INT32) + .Attr("value", tensor_proto) + .Finalize(&const_def)); + Status s; + Node *const_node = g->AddNode(const_def, &s); + TF_CHECK_OK(s); + + NodeDef fn_def; + TF_CHECK_OK(NodeDefBuilder("fn", "fn", &fld) + .Input("const", 0, DT_INT32) + .Finalize(&fn_def)); + Node *fn_node = g->AddNode(fn_def, &s); + 
TF_CHECK_OK(s); + g->AddEdge(const_node, 0, fn_node, 0); + + NodeDef ret_def; + TF_CHECK_OK(NodeDefBuilder("ret", "_Retval") + .Attr("index", 0) + .Attr("T", DT_INT32) + .Input("fn", 0, DT_INT32) + .Finalize(&ret_def)); + Node *ret_node = g->AddNode(ret_def, &s); + TF_CHECK_OK(s); + g->AddEdge(fn_node, 0, ret_node, 0); + + FunctionDef *xla_fdef = fdl.add_function(); + TF_CHECK_OK(GraphToFunctionDef(*g, "cluster", xla_fdef)); + TF_CHECK_OK(fld.AddFunctionDef(*xla_fdef)); + } + + protobuf::Map attrs; + std::map host_compute_core; + std::vector shape_inference_graphs; + bool has_outside_compilation; + NameAttrList name_attrs; + name_attrs.set_name("cluster"); + *name_attrs.mutable_attr() = attrs; + TF_CHECK_OK(ExtractOutsideCompilationTest( + "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten", "host_graph", + host_compute_core, &fld, &shape_inference_graphs, + &has_outside_compilation)); + + // Check host graph. + { + FunctionBody *host_fbody = nullptr; + AttrValue device_ordinal_temp_value; + device_ordinal_temp_value.set_i(0); + protobuf::Map host_func_attrs; + host_func_attrs["device_ordinal"] = device_ordinal_temp_value; + TF_CHECK_OK(FunctionDefToBodyHelper( + *fld.Find("host_graph"), AttrSlice(&host_func_attrs), &fld, + [&](const string &op, const OpDef **sig) { + return fld.LookUpOpDef(op, sig); + }, + &host_fbody)); + std::unique_ptr host_fbody_deleter(host_fbody); + Graph *host_graph = host_fbody->graph; + auto node_name_index = host_graph->BuildNodeNameIndex(); + + // Verify we have call node for outside compilation in `fn`. 
+ Node *call_node = node_name_index["oc_call_fn"]; + EXPECT_NE(call_node, nullptr); + + FunctionBody *call_fbody = nullptr; + TF_CHECK_OK(FunctionDefToBodyHelper( + *fld.Find("oc_func_call_host_fn"), AttrSlice(&host_func_attrs), &fld, + [&](const string &op, const OpDef **sig) { + return fld.LookUpOpDef(op, sig); + }, + &call_fbody)); + std::unique_ptr call_fbody_deleter(call_fbody); + + // Verify we have _XlaRecvAtHost and _XlaSendFromHost nodes. + bool has_recv = false, has_send = false; + for (Node *n : call_fbody->graph->nodes()) { + if (n->type_string() == "_XlaRecvAtHost") { + has_recv = true; + } else if (n->type_string() == "_XlaSendFromHost") { + has_send = true; + } + } + EXPECT_TRUE(has_recv); + EXPECT_TRUE(has_send); + } + + // Check XLA graph. + { + FunctionBody *xla_fbody = nullptr; + TF_CHECK_OK(FunctionDefToBodyHelper( + *fld.Find("cluster_rewritten"), AttrSlice(), &fld, + [&](const string &op, const OpDef **sig) { + return fld.LookUpOpDef(op, sig); + }, + &xla_fbody)); + std::unique_ptr xla_fbody_deleter(xla_fbody); + Graph *xla_graph = xla_fbody->graph; + auto node_name_index = xla_graph->BuildNodeNameIndex(); + + // Check that we have call node. + Node *fn_node = node_name_index["fn"]; + EXPECT_NE(fn_node, nullptr); + EXPECT_EQ(fn_node->type_string(), "fn_oc"); + + FunctionBody *call_fbody = nullptr; + TF_CHECK_OK(FunctionDefToBodyHelper( + *fld.Find("fn_oc"), AttrSlice(), &fld, + [&](const string &op, const OpDef **sig) { + return fld.LookUpOpDef(op, sig); + }, + &call_fbody)); + std::unique_ptr call_fbody_deleter(call_fbody); + + // Verify we have XlaHostCompute nodes. 
+ bool has_hc = false; + for (Node *n : call_fbody->graph->nodes()) { + if (n->type_string() == "XlaHostCompute") { + has_hc = true; + } + } + EXPECT_TRUE(has_hc); + } +} + } // namespace tensorflow diff --git a/tensorflow/compiler/jit/flags.cc b/tensorflow/compiler/jit/flags.cc index 98e344b3a080aa8aab27cd41564a90427bac151e..fba69dfccc31e01e73d8f86006b41ce5e3283f15 100644 --- a/tensorflow/compiler/jit/flags.cc +++ b/tensorflow/compiler/jit/flags.cc @@ -68,7 +68,12 @@ void AppendMarkForCompilationPassFlagsInternal(std::vector* flag_list) { Flag("tf_xla_fusion_only", &mark_for_compilation_flags->tf_xla_fusion_only, "enable fusion of element-wise operations only using XLA when " - "global_jit_level is ON*.")}; + "global_jit_level is ON*."), + Flag("tf_xla_disable_deadness_safety_checks_for_debugging", + &mark_for_compilation_flags + ->tf_xla_disable_deadness_safety_checks_for_debugging, + "Disable deadness related safety checks when clustering (this is " + "unsound).")}; flag_list->insert(flag_list->end(), new_flags.begin(), new_flags.end()); } @@ -89,6 +94,8 @@ void AllocateAndParseFlags() { mark_for_compilation_flags->tf_xla_clustering_fuel = std::numeric_limits::max(); mark_for_compilation_flags->tf_xla_fusion_only = false; + mark_for_compilation_flags + ->tf_xla_disable_deadness_safety_checks_for_debugging = false; device_flags = new XlaDeviceFlags; device_flags->tf_xla_compile_on_demand = false; diff --git a/tensorflow/compiler/jit/flags.h b/tensorflow/compiler/jit/flags.h index 5ddea588eef5270880d91623dc05893da265960a..ed7810fcfd85c17db70d42e691446b60dc696939 100644 --- a/tensorflow/compiler/jit/flags.h +++ b/tensorflow/compiler/jit/flags.h @@ -25,27 +25,39 @@ namespace tensorflow { // Flags associated with the XLA bridge's mark_for_compilation_pass module. struct MarkForCompilationPassFlags { - int32 tf_xla_auto_jit; // Control compilation of operators into XLA - // computations on CPU and GPU devices. 
0 = use - // ConfigProto setting; -1 = off; 1 = on for things - // very likely to be improved; 2 = on for everything. - // Experimental. - int32 tf_xla_min_cluster_size; // Minimum number of operators in an XLA - // compilation. Ignored for operators placed - // on an XLA device or operators explicitly - // marked for compilation. - int32 tf_xla_max_cluster_size; // Maximum number of operators in an XLA - // compilation. - bool tf_xla_clustering_debug; // Dump graphs during XLA compilation. - bool tf_xla_cpu_global_jit; // Enables global JIT compilation for CPU - // via SessionOptions. - int64 tf_xla_clustering_fuel; // "Compiler fuel" for clustering. Only this - // many ops will be marked as eligible for - // clustering. - bool tf_xla_fusion_only; // This flag is effective only when global_jit_level - // is set to ON* and overrides its behavior. If - // true, enable fusion of element-wise operations - // only using XLA. + // Control compilation of operators into XLA computations on CPU and GPU + // devices. 0 = use ConfigProto setting; -1 = off; 1 = on for things very + // likely to be improved; 2 = on for everything. + // + // Experimental. + int32 tf_xla_auto_jit; + + // Minimum number of operators in an XLA compilation. Ignored for operators + // placed on an XLA device or operators explicitly marked for compilation. + int32 tf_xla_min_cluster_size; + + // Maximum number of operators in an XLA compilation. + int32 tf_xla_max_cluster_size; + + // Dump graphs during XLA compilation. + bool tf_xla_clustering_debug; + + // Enables global JIT compilation for CPU via SessionOptions. + bool tf_xla_cpu_global_jit; + + // "Compiler fuel" for clustering. Only this many ops will be marked as + // eligible for clustering. + int64 tf_xla_clustering_fuel; + + // tf_xla_fusion_only is effective only when global_jit_level is set to ON* + // and overrides its behavior. If true, enable fusion of element-wise + // operations only using XLA. 
+ bool tf_xla_fusion_only; + + // If tf_xla_disable_deadness_safety_checks_for_debugging is set to true then + // we do not do deadness related safety checks. This is unsound in general, + // but can be used as a debugging aid. + bool tf_xla_disable_deadness_safety_checks_for_debugging; }; // Flags associated with the XLA bridge's xla_device module. diff --git a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc index ce53f70b79d97ab087fefe542920b33f883632a2..ebfffc3267e5acdf593bea2517c447083133e39c 100644 --- a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc +++ b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.h" +#include #include "absl/algorithm/container.h" #include "absl/container/inlined_vector.h" #include "absl/strings/str_cat.h" @@ -144,7 +145,9 @@ SliceInputs MakeSliceIndexAndSizeInt64(const Scope& host_scope, // same constant value. This helps make the generated GraphDef more readable. 
class ConstantCache { public: - explicit ConstantCache(const Scope& s) : scope_(s) {} + explicit ConstantCache(const Scope& s, + const std::vector& control_deps) + : scope_(s), control_deps_(control_deps) {} Output Get1DHostConstant(int64 constant) { auto it = cache_.find(constant); @@ -152,6 +155,9 @@ class ConstantCache { Output new_const = ops::Const(scope_.WithOpName("const_", constant), {constant}); it = cache_.insert({constant, new_const}).first; + for (const Edge* e : control_deps_) { + scope_.graph()->AddControlEdge(e->src(), new_const.node()); + } } return it->second; } @@ -159,11 +165,13 @@ class ConstantCache { private: Scope scope_; std::unordered_map cache_; + std::vector control_deps_; }; // Returns a node computing the size of the Slice op with inputs `slice_inputs`. Status ComputeSliceSize(const Scope& host_scope, - const SliceInputs& slice_inputs, Output* size) { + const SliceInputs& slice_inputs, + std::vector control_deps, Output* size) { // If slice_size[i] >= 0 then slice_size[i] = slice_size[i]. // // If slice_size[i] == -1 then slice_size[i] = input_size[i] - @@ -183,7 +191,7 @@ Status ComputeSliceSize(const Scope& host_scope, ops::Shape(host_scope.WithOpName("input_shape"), slice_inputs.input, ops::Shape::OutType(DT_INT64)); - ConstantCache constant_pool(host_scope); + ConstantCache constant_pool(host_scope, control_deps); std::vector slice_size; for (int i = 0; i < slice_inputs.size_as_vector.size(); i++) { @@ -209,11 +217,16 @@ Status ComputeSliceSize(const Scope& host_scope, } // Trivial ConcatV2 nodes (with exactly one input) are disallowed. - *size = - slice_size.size() == 1 - ? 
slice_size[0] - : ops::Concat(host_scope.WithOpName("slice_size"), slice_size, - ops::Const(host_scope.WithOpName("concat_axis"), 0)); + if (slice_size.size() == 1) { + *size = slice_size[0]; + } else { + auto concat_axis = ops::Const(host_scope.WithOpName("concat_axis"), 0); + for (const Edge* e : control_deps) { + host_scope.graph()->AddControlEdge(e->src(), concat_axis.node()); + } + *size = ops::Concat(host_scope.WithOpName("slice_size"), slice_size, + concat_axis); + } return Status::OK(); } @@ -237,9 +250,17 @@ Status ConvertTensorFlowSliceToStaticShapedSlice( SliceInputs slice_inputs_int64 = MakeSliceIndexAndSizeInt64(host_scope, slice_inputs); + // Create a list of all control dependencies to be copied when possibly + // replacing nodes related to slice_size. + Node* old_size; + std::vector old_size_ctrl_deps; + TF_RETURN_IF_ERROR(slice->input_node(2, &old_size)); + absl::c_copy_if(old_size->in_edges(), std::back_inserter(old_size_ctrl_deps), + [](const Edge* e) { return e->IsControlEdge(); }); + Output slice_size; - TF_RETURN_IF_ERROR( - ComputeSliceSize(host_scope, slice_inputs_int64, &slice_size)); + TF_RETURN_IF_ERROR(ComputeSliceSize(host_scope, slice_inputs_int64, + old_size_ctrl_deps, &slice_size)); *result = ops::Slice(main_scope.WithAssignedDevice(slice->assigned_device_name()) diff --git a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass_test.cc b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass_test.cc index a2f1b831ad7605237e23c15cc43b337e06265553..32e30216de565b4c1918903bf6c70c321c38cbb3 100644 --- a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass_test.cc +++ b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass_test.cc @@ -401,5 +401,36 @@ TEST(SliceToDynamicSliceRewriteTest, SliceWithSliceBegin) { Name("begin/static_shaped_slice/static_shaped_slice"))), _))); } + +// New constants being created need to have control dependencies copied to +// ensure correct control flow analysis in TF V2. 
+TEST(SliceToDynamicSliceRewriteTest, WithControlDepsToConstant) { + Scope root = Scope::NewRootScope() + .ExitOnError() + .WithAssignedDevice(kDeviceName) + .WithXlaCluster("cluster_0"); + + Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT); + Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT32); + Output size = ops::Const(root.WithOpName("size"), {-1}); + Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size); + + // Add an additional dependency that should still exist with the new size + // variables. + Output dependency = ops::Placeholder(root.WithOpName("dependency"), DT_BOOL); + root.graph()->AddControlEdge(dependency.node(), size.node()); + + std::unique_ptr result; + TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result)); + + // Check that the new constants have control dependencies. + Node* const_0 = testing::FindNodeByName(result.get(), + "slice/static_shaped_slice/const_0"); + EXPECT_NE(const_0, nullptr); + EXPECT_THAT(const_0, + NodeWith(Op("Const"), CtrlDeps(NodeWith(Op("Placeholder"), + Name("dependency"))))); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD index 0583774714c6db7a2fa515fc8a0d304e1898db97..bab824c15f8f27f5325e79cd92d50cdaad850233 100644 --- a/tensorflow/compiler/jit/kernels/BUILD +++ b/tensorflow/compiler/jit/kernels/BUILD @@ -25,6 +25,7 @@ cc_library( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:state_ops_op_lib", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core/kernels:variable_ops", "@com_google_absl//absl/container:flat_hash_map", diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 25796435a5c87af5e252981abf96833f4cda9a5e..20c2cd7e0561f92a01486102c4d2c572fd80c957 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ 
b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -34,6 +34,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/graph_def_util.h" #include "tensorflow/core/framework/memory_types.h" #include "tensorflow/core/framework/node_def.pb.h" @@ -41,7 +42,7 @@ limitations under the License. #include "tensorflow/core/framework/types.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/control_flow.h" -#include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/public/version.h" @@ -86,7 +87,7 @@ bool IsDummyImplOp(absl::string_view op_name) { bool IsStatefulRandomOp(absl::string_view op_name) { return op_name == "RandomUniform" || op_name == "RandomShuffle" || op_name == "RandomUniformInt" || op_name == "RandomStandardNormal" || - op_name == "TruncatedNormal"; + op_name == "TruncatedNormal" || op_name == "Multinomial"; } bool OpProducesOrConsumesVariant(const Node& node) { @@ -677,12 +678,28 @@ Status MarkForCompilationPass::Run( VLOG(1) << "flags->tf_xla_auto_jit = " << flags->tf_xla_auto_jit; const FunctionLibraryDefinition* fld = options.flib_def; + // Deadness analysis expects a graph with source and sink edges properly + // connected but sometimes the incoming graph does not follow this invariant. + // So fix up the source and sink edges before calling into deadness analysis. 
+ FixupSourceAndSinkEdges(options.graph->get()); + std::unique_ptr deadness; { XLA_SCOPED_LOGGING_TIMER_LEVEL("DeadnessAnalysis", 1); TF_RETURN_IF_ERROR(DeadnessAnalysis::Run(**options.graph, &deadness)); } + bool deadness_analysis_disabled = + GetMarkForCompilationPassFlags() + ->tf_xla_disable_deadness_safety_checks_for_debugging; + + if (deadness_analysis_disabled) { + LOG(WARNING) << "Deadness analysis was manually disabled via " + "--tf_xla_disable_deadness_safety_checks_for_debugging; " + "auto-clustering " + "is unsound!"; + } + auto is_compilable = [&](const Node* node, const DeviceType& device_type) { const XlaOpRegistry::DeviceRegistration* registration; if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), @@ -715,9 +732,12 @@ Status MarkForCompilationPass::Run( // and some are dead) then don't compile it. XLA cannot represent the // deadness semantics of these nodes correctly and auto-clustering these // nodes can cause deadness to propagate to nodes that should be live. - if (node->IsMerge() || deadness->HasInputsWithMismatchingDeadness(*node)) { - VLOG(2) << "Rejecting " << node->name() << ": mismatching deadness."; - return false; + if (!deadness_analysis_disabled) { + if (node->IsMerge() || + deadness->HasInputsWithMismatchingDeadness(*node)) { + VLOG(2) << "Rejecting " << node->name() << ": mismatching deadness."; + return false; + } } // Check for fusable ops only if requested. @@ -1145,6 +1165,27 @@ Status MarkForCompilationPass::RunImpl( if (flags->tf_xla_clustering_debug) { dump_graph::DumpGraphToFile("mark_for_compilation", **options.graph, options.flib_def); + + // We also dump out an annotated version of the TF graph where the node + // names are prefixed with the cluster names. This can help visualizing the + // clustering decisions on TensorBoard. 
+ Graph new_graph((*options.graph)->op_registry()); + CopyGraph(**options.graph, &new_graph); + + for (Node* n : new_graph.nodes()) { + if (absl::optional cluster_name = + GetXlaClusterForNode(*n)) { + n->set_name(absl::StrCat(*cluster_name, "/", n->name())); + } else { + // There is room for improvement here. In particular, it may help to + // split these unclustered nodes into classes where every node in a + // specific class has edges to and from the same set of clusters. + n->set_name(absl::StrCat("unclustered/", n->name())); + } + } + + dump_graph::DumpGraphToFile("mark_for_compilation_annotated", new_graph, + options.flib_def); } VLogClusteringSummary(*graph); diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index bf2c5508ea9e987e80093f4c2e15d3ff5191126f..c2b6250f738fafa35b2c5f79e97cf1281b50a316 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -151,7 +151,7 @@ TEST(XlaCompilationTest, CompilableCycles) { EXPECT_EQ(clusters["A"], clusters["C"]); } -TEST(XlaCompilationTest, Complex128Unsupported) { +TEST(XlaCompilationTest, StringUnsupported) { std::unique_ptr graph(new Graph(OpRegistry::Global())); GraphDef graphdef; { @@ -159,10 +159,10 @@ TEST(XlaCompilationTest, Complex128Unsupported) { Node* a = ops::SourceOp( "Const", builder.opts() .WithName("A") - .WithAttr("dtype", DT_COMPLEX128) - .WithAttr("value", Tensor(DT_COMPLEX128, TensorShape()))); - Node* b = ops::UnaryOp("Neg", a, builder.opts().WithName("B")); - ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C")); + .WithAttr("dtype", DT_STRING) + .WithAttr("value", Tensor(DT_STRING, TensorShape()))); + Node* b = ops::UnaryOp("EncodeBase64", a, builder.opts().WithName("B")); + ops::BinaryOp("StringSplit", a, b, builder.opts().WithName("C")); TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } diff --git 
a/tensorflow/compiler/jit/partially_decluster_pass.cc b/tensorflow/compiler/jit/partially_decluster_pass.cc index 42ea3926e16ae791dbe1bede3b8742383db7667c..e1fd2aaee2822daeffb415d053c9c4f56002a856 100644 --- a/tensorflow/compiler/jit/partially_decluster_pass.cc +++ b/tensorflow/compiler/jit/partially_decluster_pass.cc @@ -120,6 +120,7 @@ Status PartiallyDeclusterNode(Graph* graph, Node* n) { NodeDef ndef = n->def(); ndef.set_name(absl::StrCat(n->name(), "/declustered")); + MergeDebugInfo(NodeDebugInfo(n->def()), &ndef); RemoveFromXlaCluster(&ndef); Status s; Node* cloned_node = graph->AddNode(ndef, &s); diff --git a/tensorflow/compiler/jit/partially_decluster_pass_test.cc b/tensorflow/compiler/jit/partially_decluster_pass_test.cc index 38a54cc5efae35ad77b6dc8039c653e920cfc071..1d81a8f4fcbf050663626b1f7660afd71f4027bc 100644 --- a/tensorflow/compiler/jit/partially_decluster_pass_test.cc +++ b/tensorflow/compiler/jit/partially_decluster_pass_test.cc @@ -33,7 +33,6 @@ limitations under the License. #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/graph_def_builder_util.h" -#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/compiler/jit/shape_inference.cc b/tensorflow/compiler/jit/shape_inference.cc index 80c691fe490c1092315708a2da754d367d585300..a27e0d9f2a6ecddfdbdb29be673084d77a178d8a 100644 --- a/tensorflow/compiler/jit/shape_inference.cc +++ b/tensorflow/compiler/jit/shape_inference.cc @@ -53,7 +53,15 @@ Status PropagateShapes(const Graph& graph, // shapes, even if no shape function is registered for a node. 
Status status = shape_refiner->AddNode(n); if (!status.ok()) { - VLOG(1) << "Shape inference failed for node: " << status; + VLOG(1) << "Shape inference failed for node " << n->name() << ": " + << status; + } else { + shape_inference::InferenceContext* context = shape_refiner->GetContext(n); + for (int i = 0; i < n->num_outputs(); i++) { + shape_inference::ShapeHandle handle = context->output(i); + VLOG(4) << "Output " << i << " for node " << n->name() << ": " + << context->DebugString(handle); + } } if (n->type_string() == "_Arg") { diff --git a/tensorflow/compiler/jit/xla_cluster_util.cc b/tensorflow/compiler/jit/xla_cluster_util.cc index fef28fc810cb4e544fe3f271f0b96cebd8a96779..80993861abba050fa3d6a133023d3c99f41f73e3 100644 --- a/tensorflow/compiler/jit/xla_cluster_util.cc +++ b/tensorflow/compiler/jit/xla_cluster_util.cc @@ -19,9 +19,9 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "tensorflow/compiler/jit/resource_operation_safety_analysis.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/graph/control_flow.h" -#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/util/device_name_utils.h" namespace tensorflow { diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index 3df5479a55e841380ca7b8cdd0add9fd17487091..611515cf33bc1abe21e06eb7f1513800276e095b 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" @@ -38,6 +39,8 @@ limitations under the License. 
namespace tensorflow { +constexpr int64 XlaCompilationCache::kDefaultCompilationThreshold; + XlaCompilationCache::XlaCompilationCache(xla::LocalClient* client, DeviceType device_type) : client_(client), device_type_(std::move(device_type)) {} @@ -60,7 +63,7 @@ XlaCompilationCache::~XlaCompilationCache() { // about? } -string XlaCompilationCache::DebugString() { +string XlaCompilationCache::DebugString() const { return "XLA JIT compilation cache"; } @@ -68,9 +71,9 @@ string XlaCompilationCache::DebugString() { // arguments in the supplied list. string XlaCompilationCache::Signature::HumanString() const { string result = name; - for (const auto& a : arg_types) { - absl::StrAppend(&result, ",", DataTypeString(a.first), - a.second.DebugString()); + for (const auto& a : arg_shapes) { + absl::StrAppend(&result, ",", DataTypeString(a.first)); + absl::StrAppend(&result, " [", absl::StrJoin(a.second, ","), "]"); } for (const auto& v : arg_values) { @@ -81,7 +84,7 @@ string XlaCompilationCache::Signature::HumanString() const { bool XlaCompilationCache::Signature::operator==(const Signature& other) const { if (name != other.name) return false; - if (arg_types != other.arg_types) return false; + if (arg_shapes != other.arg_shapes) return false; if (arg_values.size() != other.arg_values.size()) return false; for (int i = 0; i < arg_values.size(); ++i) { @@ -97,10 +100,10 @@ bool XlaCompilationCache::Signature::operator==(const Signature& other) const { uint64 XlaCompilationCache::Signature::Hash::operator()( const XlaCompilationCache::Signature& signature) const { uint64 h = std::hash()(signature.name); - for (const auto& arg : signature.arg_types) { + for (const auto& arg : signature.arg_shapes) { h = Hash64Combine(h, std::hash()(static_cast(arg.first))); - h = Hash64Combine(h, std::hash()(arg.second.dims())); - for (int dim : arg.second.dim_sizes()) { + h = Hash64Combine(h, std::hash()(arg.second.size())); + for (int dim : arg.second) { h = Hash64Combine(h, std::hash()(dim)); 
} } @@ -124,7 +127,7 @@ XlaCompilationCache::BuildSignature( break; case XlaCompiler::Argument::kParameter: case XlaCompiler::Argument::kResource: - signature.arg_types.emplace_back(arg.type, arg.shape); + signature.arg_shapes.emplace_back(arg.type, arg.DimensionSizes()); break; default: return errors::InvalidArgument( diff --git a/tensorflow/compiler/jit/xla_compilation_cache.h b/tensorflow/compiler/jit/xla_compilation_cache.h index 846d0c963dbfdf55f51120f2f138d12f5f63839b..7748b4700f39da4f952278ca6c6d2cadff4d3fb8 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.h +++ b/tensorflow/compiler/jit/xla_compilation_cache.h @@ -88,14 +88,16 @@ class XlaCompilationCache : public ResourceBase { xla::LocalClient* client() const { return client_; } const DeviceType& device_type() const { return device_type_; } - string DebugString() override; + string DebugString() const override; // Describes the types, shapes and any compile-time constant arguments // to a kernel. Key that uniquely identifies a compilation output. struct Signature { string name; - std::vector> arg_types; + // List of Tensor types & shapes for compile-time constant arguments to the + // compilation, ordered by argument number. + std::vector>> arg_shapes; // List of Tensor values for compile-time constant arguments to the // compilation, ordered by argument number. Tensors must be in host memory. 
diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc index 1fe612d43d10030675cf307b109e4dcc89cb2d79..c7e8d61d280a33a83c3386d8ef801018634d31ec 100644 --- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc @@ -142,11 +142,22 @@ Status XlaCompileOnDemandOp::Compile( TF_RETURN_IF_ERROR(ctx->allocate_temp( device_tensor.dtype(), device_tensor.shape(), &host_tensor, attrs)); Notification n; + Status status; ctx->op_device_context()->CopyDeviceTensorToCPU( &device_tensor, "ConstantArgument", reinterpret_cast(ctx->device()), &host_tensor, - [&](Status status) { n.Notify(); }); + [&](Status s) { + status = s; + n.Notify(); + }); n.WaitForNotification(); + if (!status.ok()) { + LOG(ERROR) << "Copying tensor of shape " + << device_tensor.shape().DebugString() << " from " + << ctx->device()->name() << "to CPU failed with " + << status.ToString(); + return status; + } constant_arguments[i] = host_tensor; } } @@ -189,6 +200,7 @@ Status XlaCompileOnDemandOp::Compile( std::map variable_args = GetVariables(ctx); std::vector args; + TF_RETURN_IF_ERROR(XlaComputationLaunchContext::BuildXlaCompilerArguments( constant_arguments, variable_args, ctx, &args)); diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc index 7df898ad12a15345f45fc96e0ec3d42b6e51731b..94dc61d55fb047c0ea81d98fde24cb55387c27d7 100644 --- a/tensorflow/compiler/jit/xla_cpu_device.cc +++ b/tensorflow/compiler/jit/xla_cpu_device.cc @@ -63,7 +63,19 @@ Status XlaCpuDeviceFactory::CreateDevices( options.device_ordinal = 0; options.compilation_device_name = DEVICE_CPU_XLA_JIT; options.use_multiple_streams = false; - devices->push_back(absl::make_unique(session_options, options)); + auto device = absl::make_unique(session_options, options); + + // Setting GpuDeviceInfo because eager runtime relies on the device + // context in 
tensorflow_gpu_device_info(). Also, + // tensorflow_gpu_device_info() == nullptr is used as an IsCPU test. + // We need XlaCpuDevice to be treated not as CPU because it allocates + // XlaTensors, not regular Tensors. + Status status = device->UseGpuDeviceInfo(); + if (!status.ok()) { + errors::AppendToMessage(&status, "while setting up ", DEVICE_GPU_XLA_JIT); + return status; + } + devices->push_back(std::move(device)); return Status::OK(); } @@ -71,9 +83,9 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_CPU, XlaCpuDeviceFactory); // Kernel registrations -constexpr std::array kAllXlaCpuTypes = { +constexpr std::array kAllXlaCpuTypes = { {DT_UINT8, DT_QUINT8, DT_INT8, DT_QINT8, DT_INT32, DT_QINT32, DT_INT64, - DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}}; + DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128, DT_BOOL}}; REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_CPU, XlaLocalLaunchOp, kAllXlaCpuTypes); REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_CPU, XlaCompileOp, kAllXlaCpuTypes); diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index 4201ff91a89b1bee370e6a43337c51abe3bf974a..c67b4f11b030f22c123336327ff9fa67b1211d7a 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -201,7 +201,8 @@ XlaDevice::XlaDevice(const SessionOptions& session_options, jit_device_name_(options.compilation_device_name), platform_(options.platform), use_multiple_streams_(options.use_multiple_streams), - shape_representation_fn_(options.shape_representation_fn) { + shape_representation_fn_(options.shape_representation_fn), + allowed_devices_(options.allowed_devices) { VLOG(1) << "Created XLA device " << options.compilation_device_name << " " << this; thread_pool_.reset(new thread::ThreadPool(session_options.env, "xla_device", @@ -218,9 +219,6 @@ XlaDevice::XlaDevice(const SessionOptions& session_options, XlaDevice::~XlaDevice() { VLOG(1) << "Destroying XLA device " << jit_device_name_ << " " << this; 
mutex_lock lock(mu_); - while (outstanding_asynchronous_operations_ > 0) { - outstanding_asynchronous_operations_cv_.wait(lock); - } if (device_context_) { device_context_->Unref(); } @@ -234,7 +232,8 @@ xla::LocalClient* XlaDevice::client() const { // TODO(b/78468222): This can fail, at least when the backend is GPU and // there is no GPU on the host. - return xla::ClientLibrary::GetOrCreateLocalClient(platform_).ValueOrDie(); + return xla::ClientLibrary::GetOrCreateLocalClient(platform_, allowed_devices_) + .ValueOrDie(); } Allocator* XlaDevice::GetAllocator(AllocatorAttributes attr) { @@ -396,12 +395,6 @@ Status XlaDevice::Sync() { if (!stream) return Status::OK(); Status status = stream->BlockHostUntilDone(); - { - mutex_lock lock(mu_); - while (outstanding_asynchronous_operations_ > 0) { - outstanding_asynchronous_operations_cv_.wait(lock); - } - } TF_RETURN_IF_ERROR(status); if (!stream->ok()) { return errors::Internal("XlaDevice::Sync() failed."); @@ -410,6 +403,8 @@ Status XlaDevice::Sync() { return Status::OK(); } +// TODO(b/112409994): This is no longer necessary. Consolidate it with the +// synchronous version. void XlaDevice::Sync(const DoneCallback& done) { VLOG(1) << "XlaDevice::Sync (asynchronous)"; std::shared_ptr stream; @@ -422,14 +417,20 @@ void XlaDevice::Sync(const DoneCallback& done) { return; } + // The call to ThenEnqueueOnBackgroundThread below enqueues a host callback at + // the end of the stream, after everything that has already been enqueued + // there at this moment. When the host callback is called, everything before + // it must have already finished, and the host callback will then place the + // task below onto a background thread. (See the implementation of + // ThenEnqueueOnBackgroundThread for details.) 
Therefore, when the done + // callback is finally called from that background thread, we know for sure + // that everything enqueued onto the stream (i.e., the device) at this very + // moment--when ThenEnqueueOnBackgroundThread is called--will have finished. + // This achieves a device-wide sync. stream->ThenEnqueueOnBackgroundThread( [this, stream, done](se::StreamExecutor*) { tracing::ScopedActivity activity("XlaDevice::Sync::Callback", /*is_expensive=*/true); - mutex_lock lock(mu_); - while (outstanding_asynchronous_operations_ > 0) { - outstanding_asynchronous_operations_cv_.wait(lock); - } done(stream->ok() ? Status::OK() : errors::Internal("XlaDevice::Sync() failed.")); }); @@ -468,57 +469,27 @@ Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto, return status; } -void XlaDevice::SetRequiresSyncOnCompletion(bool sync_on_completion) { +void XlaDevice::SetAllowsSyncOnCompletion(bool sync_on_completion) { mutex_lock lock(mu_); sync_on_completion_ = sync_on_completion; } -bool XlaDevice::RequiresSyncOnCompletion() const { +bool XlaDevice::AllowsSyncOnCompletion() const { mutex_lock lock(mu_); return sync_on_completion_; } -XlaDevice::AsynchronousOperationHandle::AsynchronousOperationHandle( - XlaDevice* device) - : device_(device) { - mutex_lock lock(device_->mu_); - ++device_->outstanding_asynchronous_operations_; -} - -XlaDevice::AsynchronousOperationHandle::~AsynchronousOperationHandle() { - if (device_) { - mutex_lock lock(device_->mu_); - --device_->outstanding_asynchronous_operations_; - device_->outstanding_asynchronous_operations_cv_.notify_all(); +Status XlaDevice::RefreshStatus() { + std::shared_ptr stream; + { + mutex_lock lock(mu_); + stream = stream_; } -} - -XlaDevice::AsynchronousOperationHandle::AsynchronousOperationHandle( - const XlaDevice::AsynchronousOperationHandle& other) - : device_(other.device_) { - mutex_lock lock(device_->mu_); - ++device_->outstanding_asynchronous_operations_; -} - 
-XlaDevice::AsynchronousOperationHandle::AsynchronousOperationHandle( - XlaDevice::AsynchronousOperationHandle&& other) - : device_(other.device_) { - other.device_ = nullptr; -} - -XlaDevice::AsynchronousOperationHandle& XlaDevice::AsynchronousOperationHandle:: -operator=(const XlaDevice::AsynchronousOperationHandle& other) { - device_ = other.device_; - mutex_lock lock(device_->mu_); - ++device_->outstanding_asynchronous_operations_; - return *this; -} - -XlaDevice::AsynchronousOperationHandle& XlaDevice::AsynchronousOperationHandle:: -operator=(XlaDevice::AsynchronousOperationHandle&& other) { - device_ = other.device_; - other.device_ = nullptr; - return *this; + if (!stream) { + return Status::OK(); + } + // Stream status is XlaDevice status, no extra operations needed. + return stream->RefreshStatus(); } XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device, diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h index c8bb276cdb9673fdcba4cc15a9f33ecd3ae96dbb..5fe1290fa03f2b1f9d90e36dbc5769b3c2728c8d 100644 --- a/tensorflow/compiler/jit/xla_device.h +++ b/tensorflow/compiler/jit/xla_device.h @@ -24,7 +24,9 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_ #define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_ +#include +#include "absl/types/optional.h" #include "tensorflow/compiler/jit/xla_device_context.h" #include "tensorflow/compiler/jit/xla_tensor.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" @@ -123,6 +125,11 @@ class XlaDevice : public LocalDevice { // If padded_shape_fn is empty, a default implementation that returns // the logical on-device shape without padding is used. PaddedShapeFn padded_shape_fn; + + // Set of devices to use. This controls which of the devices on the given + // platform will have resources allocated. For GPUs this will be + // filled from visible_gpu_devices list from session configuration. 
+ absl::optional> allowed_devices; }; // Creates a new XLA Device. @@ -160,35 +167,13 @@ class XlaDevice : public LocalDevice { Status UseGpuDeviceInfo() LOCKS_EXCLUDED(mu_); // Instructs this XlaDevice to return 'sync_on_completion' for - // RequiresSyncOnCompletion(). - void SetRequiresSyncOnCompletion(bool sync_on_completion) LOCKS_EXCLUDED(mu_); - - bool RequiresSyncOnCompletion() const override LOCKS_EXCLUDED(mu_); - - // A simple RAII handle. On construction the device's - // outstanding_asynchronous_operations_ field is incremented; on destruction - // it is decremented. - class AsynchronousOperationHandle { - public: - AsynchronousOperationHandle(XlaDevice* device); - ~AsynchronousOperationHandle(); - AsynchronousOperationHandle(const AsynchronousOperationHandle& other); - AsynchronousOperationHandle(AsynchronousOperationHandle&& other); - AsynchronousOperationHandle& operator=( - const AsynchronousOperationHandle& other); - AsynchronousOperationHandle& operator=(AsynchronousOperationHandle&& other); + // AllowsSyncOnCompletion(). + void SetAllowsSyncOnCompletion(bool sync_on_completion) LOCKS_EXCLUDED(mu_); + bool AllowsSyncOnCompletion() const override LOCKS_EXCLUDED(mu_); - private: - XlaDevice* device_ = nullptr; - }; - - AsynchronousOperationHandle CreateAsynchronousOperationHandle() { - return AsynchronousOperationHandle(this); - } + Status RefreshStatus() override LOCKS_EXCLUDED(mu_); private: - friend class AsynchronousOperationHandle; - xla::LocalClient* client() const; Allocator* GetAllocatorLocked(AllocatorAttributes attr) EXCLUSIVE_LOCKS_REQUIRED(mu_); @@ -248,14 +233,14 @@ class XlaDevice : public LocalDevice { // Thread pool used for running closures std::unique_ptr thread_pool_; - // True if the device requires XlaDevice::Sync to be called on completion + // True if the device allows XlaDevice::Sync to be called on completion // regardless of status. 
- bool sync_on_completion_ GUARDED_BY(mu_) = false; + bool sync_on_completion_ GUARDED_BY(mu_) = true; - // Count of outstanding asynchronous operations which must be zero on Sync() - // completion. - int64 outstanding_asynchronous_operations_ GUARDED_BY(mu_) = 0; - condition_variable outstanding_asynchronous_operations_cv_; + // Set of devices to use. This controls which of the devices on the given + // platform will have resources allocated. For GPUs this will be + // filled from visible_gpu_devices list from session configuration. + absl::optional> allowed_devices_; }; // Builds OpKernel registrations on 'device' for the JIT operators diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index 6e6532731e64bd42ee56aa719748988f321e0f17..28681bb8b03dbf97e8145972f9a04b5855fafdae 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -79,6 +79,13 @@ XlaDeviceContext::XlaDeviceContext( } } +void XlaDeviceContext::CopyTensorInSameDevice(const Tensor* input_tensor, + Device* device, + Tensor* output_tensor, + StatusCallback done) const { + done(errors::Unimplemented("XLA->XLA same-device copies not implemented.")); +} + void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, Tensor* device_tensor, @@ -124,7 +131,7 @@ void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, xla::ShapeUtil::MakeShape(shape.element_type(), xla::AsInt64Slice(shape.dimensions()))); - VLOG(1) << "Transfer to device as literal: " << literal.ToString() << " " + VLOG(2) << "Transfer to device as literal: " << literal.ToString() << " " << xla_tensor->shaped_buffer().ToString(); if (UseMultipleStreams() && !transfer_manager_->CanShapedBufferBeAccessedNow( @@ -207,7 +214,7 @@ void XlaDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor, device_to_host_stream_.get(), xla_tensor->shaped_buffer(), literal, [ref, xla_tensor, 
done](xla::Status status) { done([&]() -> Status { - VLOG(1) << "Transfer from device as literal: " + VLOG(2) << "Transfer from device as literal: " << xla_tensor->shaped_buffer().ToString(); return status; }()); diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h index 1e18df197a2dd65590c5181b4dae4481dca36641..e45db989fac720df6c3458c93a6b8dbb0919f930 100644 --- a/tensorflow/compiler/jit/xla_device_context.h +++ b/tensorflow/compiler/jit/xla_device_context.h @@ -62,6 +62,9 @@ class XlaDeviceContext : public DeviceContext { void CopyDeviceTensorToCPU(const Tensor* device_tensor, absl::string_view tensor_name, Device* device, Tensor* cpu_tensor, StatusCallback done) override; + void CopyTensorInSameDevice(const Tensor* input_tensor, Device* device, + Tensor* output_tensor, + StatusCallback done) const override; xla::LocalClient* client() const { return client_; } se::Stream* stream() const { return stream_.get(); } diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h index adf0f994b84d9fbf918a5b2478aa7d106853e038..927f983ba9ef23c8509523f42366c0c89c29db9f 100644 --- a/tensorflow/compiler/jit/xla_device_ops.h +++ b/tensorflow/compiler/jit/xla_device_ops.h @@ -203,6 +203,8 @@ class XlaAssignVariableOp : public OpKernel { .HostMemory("output") \ .TypeConstraint("T"), \ ArgOp); \ + REGISTER_KERNEL_BUILDER( \ + Name(kArgOp).Device(DEVICE).TypeConstraint("T"), ArgOp); \ \ REGISTER_KERNEL_BUILDER(Name(kRetOp) \ .Device(DEVICE) \ diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc index 944f732b99c0924a08932eda0aedd8c815cc51d0..b29f6a009b9e9fdba76ac55386a4bec2f339cc0e 100644 --- a/tensorflow/compiler/jit/xla_gpu_device.cc +++ b/tensorflow/compiler/jit/xla_gpu_device.cc @@ -16,7 +16,10 @@ limitations under the License. 
// Registers the XLA_GPU device, which is an XlaDevice instantiation that runs // operators using XLA via the XLA "CUDA" (GPU) backend. +#include #include "absl/memory/memory.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_split.h" #include "tensorflow/compiler/jit/kernels/xla_ops.h" #include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/jit/xla_device_ops.h" @@ -26,6 +29,30 @@ limitations under the License. namespace tensorflow { +// Returns a set containing the device ids contained in visible_device_list or +// nullopt if it is empty. It returns error in case of malformed configuration +// string. +static xla::StatusOr>> ParseVisibleDeviceList( + const string& visible_device_list) { + std::set gpu_ids; + if (visible_device_list.empty()) { + return {{absl::nullopt}}; + } + const std::vector visible_devices = + absl::StrSplit(visible_device_list, ','); + for (const string& platform_gpu_id_str : visible_devices) { + int32 platform_gpu_id; + if (!absl::SimpleAtoi(platform_gpu_id_str, &platform_gpu_id)) { + return errors::InvalidArgument( + "Could not parse entry in 'visible_device_list': '", + platform_gpu_id_str, + "'. visible_device_list = ", visible_device_list); + } + gpu_ids.insert(platform_gpu_id); + } + return {{gpu_ids}}; +} + class XlaGpuDeviceFactory : public DeviceFactory { public: Status CreateDevices(const SessionOptions& options, const string& name_prefix, @@ -52,8 +79,18 @@ Status XlaGpuDeviceFactory::CreateDevices( VLOG(1) << "Failed to create XLA_GPU device: " << platform.status(); return Status::OK(); } - - for (int i = 0; i < platform.ValueOrDie()->VisibleDeviceCount(); ++i) { + string allowed_gpus = + session_options.config.gpu_options().visible_device_list(); + absl::optional> gpu_ids = + ParseVisibleDeviceList(allowed_gpus).ValueOrDie(); + if (!gpu_ids) { + gpu_ids.emplace(); + // Fill the gpu_ids set with all devices if config string is empty. 
+ for (int i = 0; i < platform.ValueOrDie()->VisibleDeviceCount(); ++i) { + gpu_ids->insert(i); + } + } + for (int i : *gpu_ids) { XlaDevice::Options options; options.platform = platform.ValueOrDie(); options.device_name_prefix = name_prefix; @@ -61,6 +98,7 @@ Status XlaGpuDeviceFactory::CreateDevices( options.device_ordinal = i; options.compilation_device_name = DEVICE_GPU_XLA_JIT; options.use_multiple_streams = true; + options.allowed_devices = gpu_ids; auto device = absl::make_unique(session_options, options); Status status = device->UseGpuDeviceInfo(); diff --git a/tensorflow/compiler/jit/xla_interpreter_device.cc b/tensorflow/compiler/jit/xla_interpreter_device.cc index 4007309ed1c57b663dca5bac0df11260bf1327f3..e1a582406153d2af447fa9d4ebcaf0bf0842b132 100644 --- a/tensorflow/compiler/jit/xla_interpreter_device.cc +++ b/tensorflow/compiler/jit/xla_interpreter_device.cc @@ -26,9 +26,9 @@ namespace tensorflow { const char* const DEVICE_XLA_INTERPRETER = "XLA_INTERPRETER"; const char* const DEVICE_INTERPRETER_XLA_JIT = "XLA_INTERPRETER_JIT"; -constexpr std::array kExecAllTypes = { +constexpr std::array kExecAllTypes = { {DT_INT8, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, - DT_BOOL, DT_BFLOAT16}}; + DT_COMPLEX128, DT_BOOL, DT_BFLOAT16}}; class XlaInterpreterDeviceFactory : public DeviceFactory { public: diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index 3b0bda4caa161a7561a3098b89420329998ff8a7..c64981053fad2dbf1e8bcd623a940ded8b4d9150 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -237,7 +237,7 @@ void XlaComputationLaunchContext::PopulateInputs( const xla::Shape on_device_shape = client_->backend().transfer_manager()->HostShapeToDeviceShape(shape); - if (xla::ShapeUtil::IsTuple(on_device_shape)) { + if (on_device_shape.IsTuple()) { const XlaTensor* xla_tensor = XlaTensor::FromTensor(t); CHECK(xla_tensor && 
xla_tensor->has_shaped_buffer()); arg_ptrs_[i] = const_cast(&xla_tensor->shaped_buffer()); @@ -274,7 +274,7 @@ Status XlaComputationLaunchContext::PopulateOutputs( // If the on-host-shape isn't a tuple, create a new single-element tuple // buffer with a nullptr root index table. This allows the code below to treat // output as a tuple unconditionally. - if (!xla::ShapeUtil::IsTuple(output.on_host_shape())) { + if (!output.on_host_shape().IsTuple()) { ShapedBuffer nontuple_buffer = output.release(); ShapedBuffer buffer( xla::ShapeUtil::MakeTupleShape({nontuple_buffer.on_host_shape()}), @@ -377,7 +377,7 @@ Status XlaComputationLaunchContext::PopulateOutputs( } if (VLOG_IS_ON(3)) { - VLOG(3) << ctx->mutable_output(i)->DebugString(); + VLOG(3) << ctx->mutable_output(i)->DeviceSafeDebugString(); } } diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h index 437db019a0eabe66417725148d8b121842e90479..554227f09de0ab4d9e07f199b957657f3121ff06 100644 --- a/tensorflow/compiler/jit/xla_launch_util.h +++ b/tensorflow/compiler/jit/xla_launch_util.h @@ -199,19 +199,17 @@ class XlaTensorBuffer : public TensorBuffer { public: XlaTensorBuffer(const void* ptr, size_t expected_size, size_t actual_size, Allocator* allocator) - : expected_size_(expected_size), + : TensorBuffer(const_cast(ptr)), + expected_size_(expected_size), actual_size_(actual_size), - allocator_(allocator) { - data_ = const_cast(ptr); - } + allocator_(allocator) {} ~XlaTensorBuffer() override { - if (data_) { - allocator_->DeallocateRaw(data_); + if (data()) { + allocator_->DeallocateRaw(data()); } } - void* data() const override { return data_; } size_t size() const override { return expected_size_; } TensorBuffer* root_buffer() override { return this; } @@ -231,7 +229,6 @@ class XlaTensorBuffer : public TensorBuffer { } private: - void* data_; size_t expected_size_; size_t actual_size_; Allocator* allocator_; diff --git a/tensorflow/compiler/tests/BUILD 
b/tensorflow/compiler/tests/BUILD index bc3d60b90e58b4018f1c52b09941dedba7ef348a..9b6ca4092c3177ac26503add13bce25d2c0bb820 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -72,7 +72,7 @@ py_test( tf_xla_py_test( name = "adadelta_test", - size = "large", + size = "medium", srcs = ["adadelta_test.py"], deps = [ ":xla_test", @@ -230,6 +230,7 @@ tf_xla_py_test( "//tensorflow/python:framework", "//tensorflow/python:platform_test", "//tensorflow/python:random_ops", + "//tensorflow/python:standard_ops", ], ) @@ -242,6 +243,7 @@ tf_xla_py_test( ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:framework", + "//tensorflow/python:map_fn", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", @@ -277,10 +279,9 @@ tf_xla_py_test( ], ) -# This test is large because occasionally the cpu test is long for testConcatLargeNumberOfTensors tf_xla_py_test( name = "concat_ops_test", - size = "large", + size = "medium", srcs = ["concat_ops_test.py"], deps = [ ":xla_test", @@ -406,15 +407,8 @@ tf_xla_py_test( tf_xla_py_test( name = "eager_test", - size = "large", + size = "medium", srcs = ["eager_test.py"], - disabled_backends = [ - # TODO(b/78199195) Support XLA CPU devices in eager runtime - "cpu", - "cpu_ondemand", - # TODO(b/78468222) Enable GPU backend - "gpu", - ], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -684,6 +678,7 @@ tf_xla_py_test( "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:random_ops", + "//tensorflow/python:standard_ops", ], ) @@ -833,6 +828,7 @@ tf_xla_py_test( ":xla_test", "//tensorflow/python:framework", "//tensorflow/python:platform_test", + "//tensorflow/python:standard_ops", "//tensorflow/python:stateless_random_ops", ], ) @@ -1195,11 +1191,18 @@ tf_xla_py_test( tf_xla_py_test( name = "quantized_ops_test", - size = "small", + size = "medium", srcs = ["quantized_ops_test.py"], + disabled_backends = [ + 
"cpu", + "cpu_ondemand", + ], deps = [ ":xla_test", + "//tensorflow/compiler/tf2xla/python:xla", "//tensorflow/python:array_ops", + "//tensorflow/python:bitwise_ops", + "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", diff --git a/tensorflow/compiler/tests/adadelta_test.py b/tensorflow/compiler/tests/adadelta_test.py index b7b7fda293b69d6f0cec61d0d234277636a3670d..6cf16cc07ff503c4f3e008cfb720224abe5e9166 100644 --- a/tensorflow/compiler/tests/adadelta_test.py +++ b/tensorflow/compiler/tests/adadelta_test.py @@ -32,10 +32,18 @@ class AdadeltaOptimizerTest(xla_test.XLATestCase): def testBasic(self): num_updates = 4 # number of ADADELTA steps to perform + if "CPU" in self.device: + # To avoid timeout on CPU. + all_grad = [0.2, 0.01] + all_lr = [1.0, 0.1] + else: + all_grad = [0.2, 0.1, 0.01] + all_lr = [1.0, 0.5, 0.1] + for dtype in self.float_types: with self.cached_session(), self.test_scope(): - for grad in [0.2, 0.1, 0.01]: - for lr in [1.0, 0.5, 0.1]: + for grad in all_grad: + for lr in all_lr: var0_init = [1.0, 2.0] var1_init = [3.0, 4.0] var0 = resource_variable_ops.ResourceVariable( diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index 9a5423c1b2a5df7880453cbb328f6a8174066255..c829c50b5518b29c96c0b0117a6cd143911bd1fc 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -311,6 +311,30 @@ class BinaryOpsTest(xla_test.XLATestCase): dtype(7), expected=np.array([[-6], [-5]], dtype=dtype)) + if dtype in [np.float32, np.float64]: + x = np.array([ + -0.0, 0.0, -0.0, +0.0, np.inf, np.inf, -np.inf, -np.inf, 2.0, 2.0, + 1.0 + ], + dtype=dtype) + y = np.array( + [-0.0, 0.0, +0.0, -0.0, 1.0, -1.0, 1.0, -1.0, 2.0, 1.0, 2.0], + dtype=dtype) + expected = np.nextafter(x, y) + + # We use assertAllEqual to expose any bugs hidden by relative or + # absolute error tolerances. 
+ def NextAfterEqualityTest(result, expected, rtol): + del rtol + return self.assertAllEqual(result, expected) + + self._testBinary( + math_ops.nextafter, + x, + y, + expected=expected, + equality_test=NextAfterEqualityTest) + # min/max not supported for complex if dtype not in self.complex_types | {np.uint8, np.int8}: self._testBinary( @@ -400,7 +424,7 @@ class BinaryOpsTest(xla_test.XLATestCase): def testComplexOps(self): for dtype in self.complex_types: - ctypes = {np.complex64: np.float32} + ctypes = {np.complex64: np.float32, np.complex128: np.float64} self._testBinary( math_ops.complex, np.array([[[[-1, 2], [2, 0]]]], dtype=ctypes[dtype]), diff --git a/tensorflow/compiler/tests/build_defs.bzl b/tensorflow/compiler/tests/build_defs.bzl index 447a7de2cb6526a5dcf7789d4f2bffb5e733e8c0..ed580f95b6c2f57dfdf46cfcd64cabb452980c5d 100644 --- a/tensorflow/compiler/tests/build_defs.bzl +++ b/tensorflow/compiler/tests/build_defs.bzl @@ -5,6 +5,7 @@ load("//tensorflow/compiler/tests:plugin.bzl", "plugins") load( "//tensorflow/core:platform/default/build_config_root.bzl", "tf_cuda_tests_tags", + "tf_exec_compatible_with", ) def all_backends(): @@ -64,7 +65,7 @@ def tf_xla_py_test( if backend == "cpu": backend_args += [ "--test_device=XLA_CPU", - "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_UINT8,DT_QUINT8,DT_INT8,DT_QINT8,DT_INT32,DT_QINT32,DT_INT64,DT_BOOL,DT_COMPLEX64", + "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_UINT8,DT_QUINT8,DT_INT8,DT_QINT8,DT_INT32,DT_QINT32,DT_INT64,DT_BOOL,DT_COMPLEX64,DT_COMPLEX128", ] elif backend == "gpu": backend_args += [ @@ -84,6 +85,7 @@ def tf_xla_py_test( else: fail("Unknown backend {}".format(backend)) + test_tags = tags + backend_tags native.py_test( name = test_name, srcs = srcs, @@ -92,7 +94,8 @@ def tf_xla_py_test( main = "{}.py".format(name) if main == None else main, data = data + backend_data, deps = deps + backend_deps, - tags = tags + backend_tags, + tags = test_tags, + exec_compatible_with = tf_exec_compatible_with({"tags": 
test_tags}), **kwargs ) test_names.append(test_name) diff --git a/tensorflow/compiler/tests/concat_ops_test.py b/tensorflow/compiler/tests/concat_ops_test.py index 2187f57960f80300d631bdc7eb8fe5e9c8dddeea..76750decd2963ea12680a46d7340f48e8b011fa9 100644 --- a/tensorflow/compiler/tests/concat_ops_test.py +++ b/tensorflow/compiler/tests/concat_ops_test.py @@ -294,6 +294,9 @@ class ConcatTest(xla_test.XLATestCase): # The purpose of this is to ensure that XLA on GPU will not run out of memory # with too many arguments. def testConcatLargeNumberOfTensors(self): + if "CPU" in self.device: + self.skipTest("This test can time out on CPU, so we will just allow " + "other backends to catch this specific error.") with self.cached_session(): with self.test_scope(): for concat_dim in range(2): diff --git a/tensorflow/compiler/tests/dense_layer_test.py b/tensorflow/compiler/tests/dense_layer_test.py index bf5ea7b1fb6fb3c774c4db20d059f131990d20d3..b7d08df9f7d144b71fd0b09535e10b8f596ea6ca 100644 --- a/tensorflow/compiler/tests/dense_layer_test.py +++ b/tensorflow/compiler/tests/dense_layer_test.py @@ -72,7 +72,7 @@ class DenseLayerTest(test.TestCase): x = array_ops.placeholder(shape=[None, None, 3], dtype=np.float32) y = layers.dense(x, 3) - self.evaluate(variables.initialize_all_variables()) + self.evaluate(variables.global_variables_initializer()) run_metadata = config_pb2.RunMetadata() test_utils.RunWithWarmup( sess, @@ -97,7 +97,7 @@ class DenseLayerTest(test.TestCase): with jit_scope(): y = layers.dense(x, 3) - self.evaluate(variables.initialize_all_variables()) + self.evaluate(variables.global_variables_initializer()) run_metadata = config_pb2.RunMetadata() test_utils.RunWithWarmup( sess, @@ -126,7 +126,7 @@ class DenseLayerTest(test.TestCase): with jit_scope(): y = layers.dense(x, 3) - self.evaluate(variables.initialize_all_variables()) + self.evaluate(variables.global_variables_initializer()) run_metadata = config_pb2.RunMetadata() test_utils.RunWithWarmup( sess, diff 
--git a/tensorflow/compiler/tests/depthwise_conv_op_test.py b/tensorflow/compiler/tests/depthwise_conv_op_test.py index 174bfa9efbcd7dcb4f895237eb01c17bc4a3a6b4..90146e6b27ca31304a2549ec247412341efe390c 100644 --- a/tensorflow/compiler/tests/depthwise_conv_op_test.py +++ b/tensorflow/compiler/tests/depthwise_conv_op_test.py @@ -350,8 +350,13 @@ class DepthwiseConv2DTest(xla_test.XLATestCase): self._CompareBackpropInput(input_size, filter_size, output_size, stride, padding) - def _CompareBackpropFilter(self, input_sizes, filter_sizes, output_sizes, - stride, padding): + def _CompareBackpropFilter(self, + input_sizes, + filter_sizes, + output_sizes, + stride, + padding, + data_format="NHWC"): x0 = np.random.rand(*input_sizes).astype(np.float32) x2 = np.random.rand(*output_sizes).astype(np.float32) @@ -360,13 +365,30 @@ class DepthwiseConv2DTest(xla_test.XLATestCase): t0 = array_ops.placeholder(np.float32, shape=input_sizes) t1 = constant_op.constant(filter_sizes, shape=[len(filter_sizes)]) t2 = array_ops.placeholder(np.float32, shape=output_sizes) + native_t0 = t0 + native_t2 = t2 + strides = [1, stride, stride, 1] + if use_xla: + if data_format == "NCHW": + # Transpose from NWHC input to NCHW + # Ex. [4, 5, 5, 48] to [4, 48, 5, 5] + native_t0 = array_ops.transpose(t0, [0, 3, 1, 2]) + native_t2 = array_ops.transpose(t2, [0, 3, 1, 2]) + strides = [1, 1, stride, stride] with self.test_scope(): backprop = nn_ops.depthwise_conv2d_native_backprop_filter( - t0, t1, t2, strides=[1, stride, stride, 1], padding=padding) + native_t0, + t1, + native_t2, + strides=strides, + padding=padding, + data_format=data_format) else: + # For CPU, the format NCHW is not supported. Therefore we always use + # NHWC here. 
backprop = nn_ops.depthwise_conv2d_native_backprop_filter( - t0, t1, t2, strides=[1, stride, stride, 1], padding=padding) + native_t0, t1, native_t2, strides=strides, padding=padding) ret = backprop.eval({t0: x0, t2: x2}) self.assertShapeEqual(ret, backprop) return ret @@ -379,11 +401,24 @@ class DepthwiseConv2DTest(xla_test.XLATestCase): for index, (input_size, filter_size, output_size, stride, padding) in enumerate(ConfigsToTest()): print("Testing DepthwiseConv2DFilterGradCompare,", index, "th config:", - input_size, "*", filter_size, "stride:", stride, "padding:", - padding) + input_size, "*", filter_size, "producing output", output_size, + "stride:", stride, "padding:", padding) self._CompareBackpropFilter(input_size, filter_size, output_size, stride, padding) + def testDepthwiseConv2DFilterGradFormatNCHWCompare(self): + for index, (input_size, filter_size, output_size, stride, + padding) in enumerate(ConfigsToTest()): + print("Testing DepthwiseConv2DFilterGradFormatNCHWCompare,", index, + "th config:", input_size, "*", filter_size, "producing output", + output_size, "stride:", stride, "padding:", padding) + self._CompareBackpropFilter( + input_size, + filter_size, + output_size, + stride, + padding, + data_format="NCHW") if __name__ == "__main__": test.main() diff --git a/tensorflow/compiler/tests/eager_test.py b/tensorflow/compiler/tests/eager_test.py index 2af32b537ba53723370faf81aebf308a465718c7..632eccbb097b4e84f10f926e89d7fa439c8a38cd 100644 --- a/tensorflow/compiler/tests/eager_test.py +++ b/tensorflow/compiler/tests/eager_test.py @@ -24,6 +24,7 @@ from tensorflow.compiler.tests import xla_test from tensorflow.core.protobuf import config_pb2 from tensorflow.python.eager import backprop from tensorflow.python.eager import context +from tensorflow.python.eager import def_function from tensorflow.python.eager import function from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -31,7 +32,9 @@ from 
tensorflow.python.framework import ops from tensorflow.python.layers import convolutional from tensorflow.python.layers import pooling from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import embedding_ops +from tensorflow.python.ops import functional_ops from tensorflow.python.ops import gen_random_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops @@ -463,7 +466,7 @@ class EagerFunctionTest(xla_test.XLATestCase): def f(x, y): return x[0::2, y:, ...] - x = array_ops.ones([2, 3, 4]) + x = array_ops.ones([2, 3, 4], dtype=dtypes.float32) y = array_ops.ones([], dtype=dtypes.int32) with backprop.GradientTape() as tape: tape.watch(x) @@ -479,15 +482,15 @@ class EagerFunctionTest(xla_test.XLATestCase): @function.defun def times_two(x): - return 2 * x + return 2. * x @function.defun def two_x_plus_1(x): - return times_two(x) + 1 + return times_two(x) + 1. - x = constant_op.constant([2, 3, 4]) + x = constant_op.constant([2., 3., 4.]) y = two_x_plus_1(x) - self.assertAllEqual([5, 7, 9], y.numpy()) + self.assertAllEqual([5., 7., 9.], y.numpy()) def testNestedDefunWithVariable(self): with self.test_scope(): @@ -506,7 +509,7 @@ class EagerFunctionTest(xla_test.XLATestCase): x = constant_op.constant(3.0) y = f(x) - self.assertEqual(75, y.numpy()) + self.assertEqual(75.0, y.numpy()) def testNestedDefunInGradientTape(self): with self.test_scope(): @@ -555,6 +558,71 @@ class EagerFunctionTest(xla_test.XLATestCase): self.assertEqual(9, dy_v0.numpy()) self.assertEqual(15, dy_v1.numpy()) + def testWhileInDefun(self): + with self.test_scope(): + @def_function.function + def f(start): + c = lambda x: math_ops.less(x, 13.0) + b = lambda x: math_ops.add(x, 1.0) + return control_flow_ops.while_loop(c, b, [start]) + + y = f(constant_op.constant(3.0)) + self.assertEqual(13.0, y.numpy()) + + def testAutoGraphWhileInDefun(self): + with self.test_scope(): + @def_function.function + 
def f(start): + x = start + while x < 13.0: + x += 1.0 + return x + + y = f(constant_op.constant(3.0)) + self.assertEqual(13.0, y.numpy()) + + def testCondInDefun(self): + with self.test_scope(): + @def_function.function + def f(pred, value): + fn1 = lambda: math_ops.add(value, 1.0) + fn2 = lambda: math_ops.subtract(value, 1.0) + return control_flow_ops.cond(pred, fn1, fn2) + + plus_one = f(constant_op.constant(True), constant_op.constant(10.0)) + minus_one = f(constant_op.constant(False), constant_op.constant(10.0)) + self.assertEqual(11.0, plus_one.numpy()) + self.assertEqual(9.0, minus_one.numpy()) + + def testAutoGraphCondInDefun(self): + with self.test_scope(): + @def_function.function + def f(pred, value): + if pred: + return value + 1.0 + else: + return value - 1.0 + + plus_one = f(constant_op.constant(True), constant_op.constant(10.0)) + minus_one = f(constant_op.constant(False), constant_op.constant(10.0)) + self.assertEqual(11.0, plus_one.numpy()) + self.assertEqual(9.0, minus_one.numpy()) + + def testScanInDefun(self): + with self.test_scope(): + elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name='data') + v = constant_op.constant(2.0, name='v') + + @def_function.function + def f(y): + # pylint: disable=unnecessary-lambda + return functional_ops.scan( + lambda a, x: math_ops.multiply(a, x), y, initializer=v) + # pylint: enable=unnecessary-lambda + + r = f(elems) + self.assertAllEqual([2., 4., 12., 48., 240., 1440.], self.evaluate(r)) + class ExcessivePaddingTest(xla_test.XLATestCase): """Test that eager execution works with TPU flattened tensors. 
diff --git a/tensorflow/compiler/tests/fused_batchnorm_test.py b/tensorflow/compiler/tests/fused_batchnorm_test.py index 374942a0b339b816944ea5529e4f84134b60017b..56a8e1b1667f154f6cec475ee0f4f8b308121c09 100644 --- a/tensorflow/compiler/tests/fused_batchnorm_test.py +++ b/tensorflow/compiler/tests/fused_batchnorm_test.py @@ -191,6 +191,20 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase): mean_val = np.random.random_sample(scale_shape).astype(np.float32) var_val = np.random.random_sample(scale_shape).astype(np.float32) epsilon = 0.001 + + # The TensorFlow FusedBatchNormGrad training operation takes two inputs with + # implementation defined values. In theory the only correct value these + # inputs are the corresponding reserve_space_{1|2} outputs from the + # FusedBatchNorm training operation. However, in practice, we rely on the + # first one being mean on {C|G}PU, and the second one being variance on CPU + # and inverse(sqrt(variance + epsilon)) on GPU (we test this assumption + # separately). 
+ reserve_space_1_val = mean_val + if self.device == "XLA_GPU": + reserve_space_2_val = np.reciprocal(np.sqrt(var_val + epsilon)) + else: + reserve_space_2_val = var_val + data_format_src = "NHWC" grad_x_ref, grad_scale_ref, grad_offset_ref = self._reference_grad( x_val, grad_val, scale_val, mean_val, var_val, epsilon, data_format_src) @@ -207,18 +221,26 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase): np.float32, shape=x_val_converted.shape, name="grad") x = array_ops.placeholder( np.float32, shape=x_val_converted.shape, name="x") - mean = array_ops.placeholder(np.float32, shape=scale_shape, name="mean") - var = array_ops.placeholder(np.float32, shape=scale_shape, name="var") + reserve_space_1 = array_ops.placeholder( + np.float32, shape=scale_shape, name="reserve_space_1") + reserve_space_2 = array_ops.placeholder( + np.float32, shape=scale_shape, name="reserve_space_2") scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale") grad_x, grad_scale, grad_offset, _, _ = gen_nn_ops.fused_batch_norm_grad( - grad, x, scale, mean, var, data_format=data_format, is_training=True) + grad, + x, + scale, + reserve_space_1, + reserve_space_2, + data_format=data_format, + is_training=True) grad_x_val, grad_scale_val, grad_offset_val = sess.run( [grad_x, grad_scale, grad_offset], { grad: grad_val_converted, x: x_val_converted, - mean: mean_val, - var: var_val, + reserve_space_1: reserve_space_1_val, + reserve_space_2: reserve_space_2_val, scale: scale_val }) diff --git a/tensorflow/compiler/tests/image_ops_test.py b/tensorflow/compiler/tests/image_ops_test.py index 0e2d840418156d825e2d141018e49f42374c8fee..42e688174fce9e939feb09e1767ebab31e30a6ee 100644 --- a/tensorflow/compiler/tests/image_ops_test.py +++ b/tensorflow/compiler/tests/image_ops_test.py @@ -403,6 +403,117 @@ class AdjustSaturationTest(xla_test.XLATestCase): self.assertAllClose(y_fused, y_baseline, rtol=2e-5, atol=1e-5) +class 
ResizeNearestNeighborTest(xla_test.XLATestCase): + # TODO(ilch): Wrap each test with `for dtype in self.float_types:` + # Some work to understand how that should be done was presented here: + # cl/227850213 + + def _assertForwardOpMatchesExpected(self, + image_np, + target_shape, + expected=None, + large_tolerance=False, + align_corners=True): + if expected is None: + self.fail("expected must be specified") + with self.cached_session() as sess, self.test_scope(): + image = array_ops.placeholder(image_np.dtype) + resized = gen_image_ops.resize_nearest_neighbor( + image, target_shape, align_corners=align_corners) + out = sess.run(resized, {image: image_np[np.newaxis, :, :, np.newaxis]}) + if large_tolerance: + self.assertAllClose( + expected[np.newaxis, :, :, np.newaxis], out, rtol=2e-4, atol=2e-4) + else: + self.assertAllClose(expected[np.newaxis, :, :, np.newaxis], out) + + def testAlignCorners2x2To1x1(self): + self._assertForwardOpMatchesExpected( + np.array([[1, 2], [3, 4]], dtype=np.float32), [1, 1], + expected=np.array([[1]], dtype=np.float32)) + + def testAlignCorners1x1To2x2(self): + self._assertForwardOpMatchesExpected( + np.array([[1]], dtype=np.float32), [2, 2], + expected=np.array([[1, 1], [1, 1]], dtype=np.float32)) + + def testAlignCorners1x1To3x3(self): + self._assertForwardOpMatchesExpected( + np.array([[1]], dtype=np.float32), [3, 3], + expected=np.array([[1, 1, 1], [1, 1, 1], [1, 1, 1]], dtype=np.float32)) + + def testAlignCorners2x2To3x3(self): + self._assertForwardOpMatchesExpected( + np.array([[1, 2], [3, 4]], dtype=np.float32), [3, 3], + expected=np.array([[1, 2, 2], [3, 4, 4], [3, 4, 4]], dtype=np.float32)) + + def testAlignCorners2x2To4x4(self): + self._assertForwardOpMatchesExpected( + np.array([[1, 2], [3, 4]], dtype=np.float32), [4, 4], + expected=np.array( + [[1, 1, 2, 2], [1, 1, 2, 2], [3, 3, 4, 4], [3, 3, 4, 4]], + dtype=np.float32), large_tolerance=True) + + def testAlignCorners3x3To2x2(self): + self._assertForwardOpMatchesExpected( + 
np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32), [2, 2], + expected=np.array([[1, 3], [7, 9]], dtype=np.float32)) + + def testAlignCorners4x4To3x3(self): + self._assertForwardOpMatchesExpected( + np.array( + [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], + dtype=np.float32), [3, 3], + expected=np.array([[1, 3, 4], [9, 11, 12], [13, 15, 16]], + dtype=np.float32)) + + def testAlignCorners3x3To4x4(self): + self._assertForwardOpMatchesExpected( + np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32), [4, 4], + expected=np.array( + [[1, 2, 2, 3], [4, 5, 5, 6], [4, 5, 5, 6], [7, 8, 8, 9]], + dtype=np.float32)) + + def testAlignCorners3x3To6x6(self): + self._assertForwardOpMatchesExpected( + np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32), [6, 6], + expected=np.array( + [[1, 1, 2, 2, 3, 3], [1, 1, 2, 2, 3, 3], [4, 4, 5, 5, 6, 6], + [4, 4, 5, 5, 6, 6], [7, 7, 8, 8, 9, 9], [7, 7, 8, 8, 9, 9]], + dtype=np.float32)) + + def testAlignCorners3x3To9x9(self): + # The expected matrix might look uneven in terms of how many of each number + # there is, but this is an artifact of doing the dilation and convolution + # iteratively. The behavior is less esoteric in the 3x3To12x12 case below. 
+ self._assertForwardOpMatchesExpected( + np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32), [9, 9], + expected=np.array( + [[1, 2, 2, 2, 2, 3, 3, 3, 3], [4, 5, 5, 5, 5, 6, 6, 6, 6], + [4, 5, 5, 5, 5, 6, 6, 6, 6], [4, 5, 5, 5, 5, 6, 6, 6, 6], + [4, 5, 5, 5, 5, 6, 6, 6, 6], [7, 8, 8, 8, 8, 9, 9, 9, 9], + [7, 8, 8, 8, 8, 9, 9, 9, 9], [7, 8, 8, 8, 8, 9, 9, 9, 9], + [7, 8, 8, 8, 8, 9, 9, 9, 9]], + dtype=np.float32)) + + def testAlignCorners3x3To12x12(self): + self._assertForwardOpMatchesExpected( + np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32), [12, 12], + expected=np.array([[1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3], + [1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3], + [1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3], + [4, 4, 4, 5, 5, 5, 5, 5, 5, 6, 6, 6], + [4, 4, 4, 5, 5, 5, 5, 5, 5, 6, 6, 6], + [4, 4, 4, 5, 5, 5, 5, 5, 5, 6, 6, 6], + [4, 4, 4, 5, 5, 5, 5, 5, 5, 6, 6, 6], + [4, 4, 4, 5, 5, 5, 5, 5, 5, 6, 6, 6], + [4, 4, 4, 5, 5, 5, 5, 5, 5, 6, 6, 6], + [7, 7, 7, 8, 8, 8, 8, 8, 8, 9, 9, 9], + [7, 7, 7, 8, 8, 8, 8, 8, 8, 9, 9, 9], + [7, 7, 7, 8, 8, 8, 8, 8, 8, 9, 9, 9]], + dtype=np.float32)) + + class ResizeBilinearTest(xla_test.XLATestCase): def _assertForwardOpMatchesExpected(self, @@ -444,14 +555,14 @@ class ResizeBilinearTest(xla_test.XLATestCase): self.assertAllCloseAccordingToType(expected[np.newaxis, :, :, np.newaxis], out) - def testAlignCorners1x2To3x2(self): + def testAlignCorners1x2To3x3(self): for dtype in self.float_types: self._assertForwardOpMatchesExpected( np.array([[1, 2]], dtype=dtype), [3, 3], expected=np.array([[1, 1.5, 2], [1, 1.5, 2], [1, 1.5, 2]], dtype=np.float32)) - def testAlignCorners1x2To3x2Grad(self): + def testAlignCorners1x2To3x3Grad(self): for dtype in self.float_types: self._assertBackwardOpMatchesExpected( np.array([[1, 2], [3, 4], [5, 6]], dtype=np.float32), diff --git a/tensorflow/compiler/tests/quantized_ops_test.py b/tensorflow/compiler/tests/quantized_ops_test.py index 
80c338513bc9ff6b8e56c5ad6b904af9e06a3715..cd9b728ab314d29e4eb585e00a9131024ea3a207 100644 --- a/tensorflow/compiler/tests/quantized_ops_test.py +++ b/tensorflow/compiler/tests/quantized_ops_test.py @@ -18,11 +18,16 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import math import numpy as np from tensorflow.compiler.tests import xla_test +from tensorflow.compiler.tf2xla.python import xla +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import bitwise_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import googletest @@ -44,5 +49,55 @@ class QuantizedOpsTest(xla_test.XLATestCase): self.assertAllEqual(value, expected) +class DeuantizedOpsTest(xla_test.XLATestCase): + + def pack_uint8_r2_to_uint32(self, test_input): + num_rows, num_columns = test_input.get_shape().as_list() + num_output_columns = int(math.ceil(num_columns / 4.0)) + padding_input = array_ops.pad( + math_ops.cast(test_input, dtype=dtypes.uint8), + constant_op.constant([[ + 0, + 0, + ], [0, num_output_columns * 4 - num_columns]])) + output = array_ops.zeros([num_rows, num_output_columns], + dtype=dtypes.uint32) + num_elements_per_pack = 4 + shift_bits = 8 + + iota_r1 = math_ops.range(num_output_columns * num_elements_per_pack) + + for p in range(num_elements_per_pack): + selected_index = math_ops.equal( + math_ops.mod(iota_r1, num_elements_per_pack), p) + gather_index = array_ops.boolean_mask(iota_r1, selected_index) + gathered_input = array_ops.gather(padding_input, gather_index, axis=1) + total_shift_bits = shift_bits * (num_elements_per_pack - p - 1) + left_shift_input = bitwise_ops.left_shift( + math_ops.cast(gathered_input, dtype=dtypes.uint32), total_shift_bits) + output = bitwise_ops.bitwise_or(output, left_shift_input) + return output + 
+ def testDequantizeQuint8(self): + num_rows = 100 + num_columns = 3547 + random_input = np.random.normal(128.0, 10.0, [num_rows, num_columns]) + with self.cached_session() as session: + with ops.device("CPU"): + test_input = ops.convert_to_tensor(random_input, dtype=dtypes.float32) + transposed_input = array_ops.transpose(test_input, [1, 0]) + quantized_input = array_ops.quantize(transposed_input, 0.0, 255.0, + dtypes.quint8) + packed_input = self.pack_uint8_r2_to_uint32(quantized_input.output) + with self.test_scope(): + transposed_quantized_output = xla.dequantize(packed_input, 0.0, 255.0, + "MIN_COMBINED", True) + quantized_output = array_ops.slice(transposed_quantized_output, [0, 0], + [num_rows, num_columns]) + + value = session.run(quantized_output) + self.assertAllClose(value, random_input, 1.0) + + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc index d23fd125163d1afe8c7fd5e008d4b617ff4b2874..1521cc760b85b176acb27c1489640e92ef90e247 100644 --- a/tensorflow/compiler/tests/randomized_tests.cc +++ b/tensorflow/compiler/tests/randomized_tests.cc @@ -63,6 +63,7 @@ limitations under the License. #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/bfloat16/bfloat16.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" @@ -80,6 +81,7 @@ int64 tf_xla_random_seed = 0; int32 tf_xla_test_repetitions = 20; int64 tf_xla_max_tensor_size = 10000LL; string* tf_xla_test_device_ptr; // initial value set in main() +string* tf_xla_reference_device_ptr; // initial value set in main() bool tf_xla_test_use_jit = true; string LocalDeviceToFullDeviceName(const string& device) { @@ -321,6 +323,9 @@ class OpTest : public ::testing::Test { // for use as reduction indices. 
Tensor RandomReductionIndices(int rank); + // Returns a random bit. + bool RandomBool(); + struct WindowedSpatialDims { Padding padding; std::vector kernel_dims; @@ -453,6 +458,11 @@ std::vector OpTest::RandomDims(int min_rank, int max_rank, return dims; } +bool OpTest::RandomBool() { + std::bernoulli_distribution d(0.5); + return d(generator()); +} + Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values, absl::Span shape) { Tensor tensor(dtype, TensorShape(shape)); @@ -760,8 +770,22 @@ Status TensorsAreEqualImpl(const Tensor& x, const Tensor& y) { for (int i = 0; i < Tx.size(); ++i) { if (Tx(i) != Ty(i)) { return errors::InvalidArgument(absl::StrCat( - i, "-th tensor element isn't equal: ", Tx(i), " vs. ", Ty(i), - ". x = ", x.DebugString(), "y = ", y.DebugString())); + i, "-th tensor element isn't equal: ", Str(Tx(i)), " vs. ", + Str(Ty(i)), ". x = ", x.DebugString(), "y = ", y.DebugString())); + } + } + return Status::OK(); +} + +Status TensorsAreEqualImplBfloat16(const Tensor& x, const Tensor& y) { + auto Tx = x.flat(); + auto Ty = y.flat(); + for (int i = 0; i < Tx.size(); ++i) { + if (Tx(i) != Ty(i)) { + return errors::InvalidArgument(absl::StrCat( + i, "-th tensor element isn't equal: ", static_cast(Tx(i)), + " vs. ", static_cast(Ty(i)), ". 
x = ", x.DebugString(), + "y = ", y.DebugString())); } } return Status::OK(); @@ -797,6 +821,8 @@ Status TensorsAreClose(const Tensor& a, const Tensor& b, double atol, return TensorsAreEqualImpl(a, b); case DT_BOOL: return TensorsAreEqualImpl(a, b); + case DT_BFLOAT16: + return TensorsAreEqualImplBfloat16(a, b); default: LOG(FATAL) << "Unexpected type : " << DataTypeString(a.dtype()); } @@ -829,8 +855,8 @@ OpTest::TestResult OpTest::ExpectTfAndXlaOutputsAreClose( VLOG(1) << "Input: " << input_tensors.back().DebugString(); } - string cpu_device = - LocalDeviceToFullDeviceName(absl::StrCat(DEVICE_CPU, ":0")); + string reference_device = + LocalDeviceToFullDeviceName(*tf_xla_reference_device_ptr); string test_device = LocalDeviceToFullDeviceName(*tf_xla_test_device_ptr); DeviceNameUtils::ParsedName parsed_name; @@ -845,9 +871,9 @@ OpTest::TestResult OpTest::ExpectTfAndXlaOutputsAreClose( std::vector expected_inputs, test_inputs; std::vector expected_fetches, test_fetches; Status status = builder.BuildGraph( - absl::StrCat("test", num_tests_, "_expected"), cpu_device, - /* use_jit= */ false, &graph, /* test_node_def= */ nullptr, - &expected_inputs, &expected_fetches); + absl::StrCat("test", num_tests_, "_expected"), reference_device, + /*use_jit=*/false, &graph, /*test_node_def=*/nullptr, &expected_inputs, + &expected_fetches); if (!status.ok()) { LOG(ERROR) << "Expected graph construction failed: " << status; return kFatalError; @@ -1371,6 +1397,19 @@ TEST_F(OpTest, Cast) { }); } +TEST_F(OpTest, CastBF16) { + Repeatedly([this]() { + DataType src_type, dst_type; + src_type = Choose({DT_FLOAT}); + dst_type = Choose({DT_BFLOAT16}); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Cast") + .RandomInput(src_type) + .Attr("SrcT", src_type) + .Attr("DstT", dst_type) + .Attr("Truncate", true)); + }); +} + TEST_F(OpTest, Ceil) { Repeatedly([this]() { return ExpectTfAndXlaOutputsAreClose( @@ -3346,11 +3385,41 @@ TEST_F(OpTest, ZerosLike) { }); } +// Example failing run: 
+// --tf_xla_reference_device=GPU:0 +// --tf_xla_test_use_jit=true --tf_xla_test_device=GPU:0 +// --tf_xla_test_repetitions=2 +// --gunit_filter='OpTest.FusedBatchNormTraining' +// --tf_xla_random_seed=2838146746 +TEST_F(OpTest, FusedBatchNormTraining) { + bool is_nhwc = RandomBool(); + std::vector x_dims = RandomDims(/*min_rank=*/4, /*max_rank=*/4, + /*min_size=*/5, /*max_size=*/20); + std::vector scale_dims = {x_dims[is_nhwc ? 3 : 1]}; + std::vector offset_dims = {x_dims[is_nhwc ? 3 : 1]}; + std::vector mean_dims = {0}; + std::vector variance_dims = {0}; + DataType type = DT_FLOAT; + Repeatedly([&] { + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("FusedBatchNorm") + .RandomInput(type, x_dims) + .RandomInput(type, scale_dims) + .RandomInput(type, offset_dims) + .RandomInput(type, mean_dims) + .RandomInput(type, variance_dims) + .Attr("T", type) + .Attr("data_format", is_nhwc ? "NHWC" : "NCHW") + .Attr("epsilon", static_cast(1.001e-05)) + .Attr("is_training", true)); + }); +} } // anonymous namespace } // namespace tensorflow int main(int argc, char** argv) { tensorflow::tf_xla_test_device_ptr = new tensorflow::string("GPU:0"); + tensorflow::tf_xla_reference_device_ptr = new tensorflow::string("CPU:0"); std::vector flag_list = { tensorflow::Flag( "tf_xla_random_seed", &tensorflow::tf_xla_random_seed, @@ -3366,6 +3435,9 @@ int main(int argc, char** argv) { "Maximum number of elements for random input tensors."), tensorflow::Flag("tf_xla_test_device", tensorflow::tf_xla_test_device_ptr, "Tensorflow device type to use for test"), + tensorflow::Flag("tf_xla_reference_device", + tensorflow::tf_xla_reference_device_ptr, + "Tensorflow device type to use for reference"), tensorflow::Flag("tf_xla_test_use_jit", &tensorflow::tf_xla_test_use_jit, "Use JIT compilation for the operator under test"), }; diff --git a/tensorflow/compiler/tests/tensor_list_ops_test.py b/tensorflow/compiler/tests/tensor_list_ops_test.py index 
5c079d595c440cac644f5461154509abe7b1d1ed..47e0f384a4f1e46ccc35584aaff3a0aceff8a985 100644 --- a/tensorflow/compiler/tests/tensor_list_ops_test.py +++ b/tensorflow/compiler/tests/tensor_list_ops_test.py @@ -23,24 +23,20 @@ from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors -from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import list_ops from tensorflow.python.platform import test -def scalar_shape(): - return ops.convert_to_tensor([], dtype=dtypes.int32) - - class ListOpsTest(xla_test.XLATestCase): def testElementShape(self): with self.cached_session() as sess, self.test_scope(): dim = array_ops.placeholder(dtypes.int32) - l = list_ops.tensor_list_reserve( - element_shape=(dim, 15), num_elements=20, - element_dtype=dtypes.float32) + l = list_ops.empty_tensor_list( + element_shape=(dim, 15), + element_dtype=dtypes.float32, + max_num_elements=20) e32 = list_ops.tensor_list_element_shape(l, shape_type=dtypes.int32) e64 = list_ops.tensor_list_element_shape(l, shape_type=dtypes.int64) self.assertAllEqual(sess.run(e32, {dim: 10}), (10, 15)) @@ -48,25 +44,44 @@ class ListOpsTest(xla_test.XLATestCase): def testPushPop(self): with self.cached_session() as sess, self.test_scope(): - num = array_ops.placeholder(dtypes.int32) - l = list_ops.tensor_list_reserve( - element_shape=(7, 15), num_elements=num, element_dtype=dtypes.float32) + l = list_ops.empty_tensor_list( + element_shape=(7, 15), + element_dtype=dtypes.float32, + max_num_elements=10) l = list_ops.tensor_list_push_back( l, constant_op.constant(1.0, shape=(7, 15))) l = list_ops.tensor_list_push_back( l, constant_op.constant(2.0, shape=(7, 15))) l, e2 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32) _, e1 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32) - 
self.assertAllEqual(sess.run(e2, {num: 10}), 2.0 * np.ones((7, 15))) - self.assertAllEqual(sess.run(e1, {num: 10}), 1.0 * np.ones((7, 15))) + self.assertAllEqual(sess.run(e2), 2.0 * np.ones((7, 15))) + self.assertAllEqual(sess.run(e1), 1.0 * np.ones((7, 15))) + + def testDoNotConstantFoldVariants(self): + with self.cached_session() as sess, self.test_scope(): + val = array_ops.placeholder(dtype=dtypes.float32) + l = list_ops.empty_tensor_list( + element_shape=(7, 15), + element_dtype=dtypes.float32, + max_num_elements=10) + # Note: Pushing a Placeholder will force the constant folding code + # to build a Const node with a DT_VARIANT output. This tests that XLA + # passes a cf_consider_fn which prevents folding such nodes. + l = list_ops.tensor_list_push_back( + l, array_ops.fill(value=val, dims=(7, 15))) + l = list_ops.tensor_list_push_back( + l, constant_op.constant(2.0, shape=(7, 15))) + l, e2 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32) + _, e1 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32) + self.assertAllEqual(sess.run(e2, {val: 1.0}), 2.0 * np.ones((7, 15))) + self.assertAllEqual(sess.run(e1, {val: 1.0}), 1.0 * np.ones((7, 15))) def testPushPopSeparateLists(self): with self.cached_session() as sess, self.test_scope(): - num = array_ops.placeholder(dtypes.int32) - l = list_ops.tensor_list_reserve( - element_shape=scalar_shape(), - num_elements=num, - element_dtype=dtypes.float32) + l = list_ops.empty_tensor_list( + element_shape=[], + element_dtype=dtypes.float32, + max_num_elements=20) l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0)) l2 = list_ops.tensor_list_push_back(l, constant_op.constant(2.0)) l3 = list_ops.tensor_list_push_back(l, constant_op.constant(3.0)) @@ -75,22 +90,95 @@ class ListOpsTest(xla_test.XLATestCase): l2, e22 = list_ops.tensor_list_pop_back(l2, element_dtype=dtypes.float32) l3, e31 = list_ops.tensor_list_pop_back(l3, element_dtype=dtypes.float32) l3, e32 =
list_ops.tensor_list_pop_back(l3, element_dtype=dtypes.float32) - result = sess.run([e11, [e21, e22], [e31, e32]], {num: 20}) + result = sess.run([e11, [e21, e22], [e31, e32]]) self.assertEqual(result, [1.0, [2.0, 1.0], [3.0, 1.0]]) - def testEmptyTensorList(self): - dim = 7 + def testEmptyTensorListNoMax(self): with self.cached_session() as sess, self.test_scope(): - p = array_ops.placeholder(dtypes.int32) l = list_ops.empty_tensor_list( - element_shape=(p, 15), element_dtype=dtypes.float32) + element_shape=(7, 15), element_dtype=dtypes.float32) l = list_ops.tensor_list_push_back( - l, constant_op.constant(1.0, shape=(dim, 15))) + l, constant_op.constant(1.0, shape=(7, 15))) _, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32) with self.assertRaisesRegexp(errors.InvalidArgumentError, - "Use TensorListReserve instead"): - self.assertEqual(sess.run(e, {p: dim}), 1.0 * np.ones((dim, 15))) + "Set the max number of elements"): + self.assertEqual(sess.run(e), 1.0 * np.ones((7, 15))) + def testEmptyTensorListMax(self): + with self.cached_session() as sess, self.test_scope(): + l = list_ops.empty_tensor_list( + element_shape=(10, 15), element_dtype=dtypes.float32, + max_num_elements=2) + l = list_ops.tensor_list_push_back( + l, array_ops.fill(value=3.0, dims=(10, 15))) + _, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32) + self.assertAllEqual(sess.run(e), 3.0 * np.ones((10, 15))) + + def testListFromTensor(self): + with self.cached_session(), self.test_scope(): + t = constant_op.constant([1.0, 2.0]) + l = list_ops.tensor_list_from_tensor(t, element_shape=[]) + e = list_ops.tensor_list_get_item(l, 0, element_dtype=dtypes.float32) + self.assertAllEqual(e, 1.0) + l, e0 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32) + self.assertAllEqual(e0, 2.0) + l, e1 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32) + self.assertAllEqual(e1, 1.0) + self.assertAllEqual(list_ops.tensor_list_length(l), 0) + + def 
testGetSet(self): + with self.cached_session(), self.test_scope(): + t = constant_op.constant([1.0, 2.0]) + l = list_ops.tensor_list_from_tensor(t, element_shape=[]) + e0 = list_ops.tensor_list_get_item(l, 0, element_dtype=dtypes.float32) + self.assertAllEqual(e0, 1.0) + l = list_ops.tensor_list_set_item(l, 0, 3.0) + t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32) + self.assertAllEqual(t, [3.0, 2.0]) + + def testGetSetReserved(self): + with self.cached_session(), self.test_scope(): + l = list_ops.tensor_list_reserve( + element_dtype=dtypes.float32, element_shape=[], num_elements=2) + e0 = list_ops.tensor_list_get_item(l, 0, element_dtype=dtypes.float32) + self.assertAllEqual(e0, 0.0) + l = list_ops.tensor_list_set_item(l, 0, 3.0) + t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32) + self.assertAllEqual(t, [3.0, 0.0]) + + def testGetSetReservedNonScalar(self): + with self.cached_session() as sess, self.test_scope(): + l = list_ops.tensor_list_reserve( + element_dtype=dtypes.float32, + element_shape=(7, 15), + num_elements=2) + l = list_ops.tensor_list_set_item( + l, 0, constant_op.constant(1.0, shape=(7, 15))) + e1 = list_ops.tensor_list_get_item(l, 0, element_dtype=dtypes.float32) + e2 = list_ops.tensor_list_get_item(l, 1, element_dtype=dtypes.float32) + self.assertAllEqual(sess.run(e1), np.ones((7, 15))) + self.assertAllEqual(sess.run(e2), np.zeros((7, 15))) + + def testStack(self): + with self.cached_session(), self.test_scope(): + l = list_ops.empty_tensor_list( + element_dtype=dtypes.float32, + element_shape=[], + max_num_elements=2) + l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0)) + e = list_ops.tensor_list_get_item(l, 0, element_dtype=dtypes.float32) + self.assertAllEqual(e, 1.0) + l = list_ops.tensor_list_push_back(l, constant_op.constant(2.0)) + t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32) + self.assertAllEqual(t.shape.as_list(), [None]) + self.assertAllEqual(t, [1.0, 2.0]) + + def 
testStackWithUninitializedTensors(self): + with self.cached_session(), self.test_scope(): + l = list_ops.tensor_list_reserve( + element_dtype=dtypes.float32, element_shape=[], num_elements=3) + t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32) + self.assertAllEqual(t, [0., 0., 0.]) if __name__ == "__main__": test.main() diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index 95c9e7ffd4651642781143c2c1940b0e51e1e470..083e2e58ae02b1aa383da76aebfca60fac59b84b 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -647,7 +647,7 @@ class UnaryOpsTest(xla_test.XLATestCase): np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype), expected=np.tan(np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype))) - ctypes = {np.complex64: np.float32} + ctypes = {np.complex64: np.float32, np.complex128: np.float64} self._assertOpOutputMatchesExpected( math_ops.abs, np.array([[3 - 4j, -1j, np.inf]], dtype=dtype), @@ -760,7 +760,7 @@ class UnaryOpsTest(xla_test.XLATestCase): lambda x: gen_nn_ops.bias_add_grad(x, data_format="NCHW"), np.array( [[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]]], dtype=np.float32), - expected=np.array([10., 26.], dtype=np.float32)) + expected=np.array([14., 22.], dtype=np.float32)) def testCast(self): shapes = [[], [4], [2, 3], [2, 0, 4]] diff --git a/tensorflow/compiler/tests/variable_ops_test.py b/tensorflow/compiler/tests/variable_ops_test.py index fcd7ac5ba1ca5049246e93e6f5f76746fb28c6b8..18c5870e0decb686f4df1c16bbb4a340c93ad21d 100644 --- a/tensorflow/compiler/tests/variable_ops_test.py +++ b/tensorflow/compiler/tests/variable_ops_test.py @@ -485,7 +485,7 @@ class SliceAssignTest(xla_test.XLATestCase): checker2[None] = [6] # new axis def testUninitialized(self): - with self.assertRaisesRegexp(errors.InvalidArgumentError, + with self.assertRaisesRegexp(errors.FailedPreconditionError, "uninitialized variable"): with self.test_session() as sess, 
self.test_scope(): v = resource_variable_ops.ResourceVariable([1, 2]) diff --git a/tensorflow/compiler/tests/xla_ops_test.py b/tensorflow/compiler/tests/xla_ops_test.py index 4cf88fc523735cc2d22e085afb83790c7ebb48e4..28274ff799de2c85e1e80512cadbe0206cb640a4 100644 --- a/tensorflow/compiler/tests/xla_ops_test.py +++ b/tensorflow/compiler/tests/xla_ops_test.py @@ -319,7 +319,7 @@ class XlaOpsTest(xla_test.XLATestCase, parameterized.TestCase): session.run(output) self.assertRegexpMatches( invalid_arg_error.exception.message, - (r'^start_indices must be a vector with length equal to input rank, ' + (r'start_indices must be a vector with length equal to input rank, ' r'but input rank is 3 and start_indices has shape \[2\].*')) def testDynamicSliceWithIncorrectSizeIndicesShape(self): @@ -332,7 +332,7 @@ class XlaOpsTest(xla_test.XLATestCase, parameterized.TestCase): session.run(output) self.assertRegexpMatches( invalid_arg_error.exception.message, - (r'^size_indices must be a vector with length equal to input rank, ' + (r'size_indices must be a vector with length equal to input rank, ' r'but input rank is 3 and size_indices has shape \[2\].*')) diff --git a/tensorflow/compiler/tf2tensorrt/BUILD b/tensorflow/compiler/tf2tensorrt/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..00d3c8cc5f610ea5c308fa7df49d963c78919d63 --- /dev/null +++ b/tensorflow/compiler/tf2tensorrt/BUILD @@ -0,0 +1,445 @@ +# Description: +# Wrap NVIDIA TensorRT (http://developer.nvidia.com/tensorrt) with tensorflow +# and provide TensorRT operators and converter package. +# APIs are meant to change over time. 
+ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load( + "//tensorflow:tensorflow.bzl", + "tf_cc_test", + "tf_copts", + "tf_cuda_library", + "tf_custom_op_library", + "tf_custom_op_library_additional_deps", + "tf_gen_op_libs", + "tf_gen_op_wrapper_py", +) +load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test") +load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") +load( + "@local_config_tensorrt//:build_defs.bzl", + "if_tensorrt", +) + +tf_cuda_cc_test( + name = "tensorrt_test_cc", + size = "small", + srcs = ["tensorrt_test.cc"], + tags = [ + "no_windows", + "nomac", + ], + deps = [ + "//tensorflow/core:gpu_init", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ] + if_tensorrt([ + "@local_config_cuda//cuda:cuda_headers", + "@local_config_tensorrt//:tensorrt", + ]), +) + +tf_custom_op_library( + name = "python/ops/_trt_ops.so", + srcs = [ + "ops/get_serialized_resource_op.cc", + "ops/trt_engine_op.cc", + ], + deps = [ + "//tensorflow/core:lib_proto_parsing", + ] + if_tensorrt([ + "@local_config_tensorrt//:tensorrt", + ]), +) + +cc_library( + name = "trt_op_kernels", + srcs = [ + "kernels/get_serialized_resource_op.cc", + "kernels/trt_engine_op.cc", + ], + copts = tf_copts(), + visibility = ["//visibility:public"], + deps = [ + ":test_utils", + ":trt_allocator", + ":trt_conversion", + ":trt_logging", + ":trt_plugins", + ":trt_resources", + ":utils", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "//tensorflow/core:gpu_headers_lib", + "//tensorflow/core:lib_proto_parsing", + "//tensorflow/core:stream_executor_headers_lib", + "//tensorflow/core/grappler/costs:graph_properties", + ] + if_tensorrt([ + "@local_config_tensorrt//:tensorrt", + ]) + tf_custom_op_library_additional_deps(), + alwayslink = 1, +) + +tf_cuda_cc_test( + name = "get_serialized_resource_op_test", + size = 
"small", + srcs = ["kernels/get_serialized_resource_op_test.cc"], + tags = [ + "no_cuda_on_cpu_tap", + "no_windows", + "nomac", + ], + deps = [ + # TODO(laigd): consider splitting get_serialized_resource_op out from + # TF-TRT. + ":trt_op_kernels", + ":trt_op_libs", + ":trt_resources", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/kernels:ops_testutil", + ], +) + +tf_gen_op_libs( + op_lib_names = [ + "trt_engine_op", + "get_serialized_resource_op", + ], +) + +cc_library( + name = "trt_op_libs", + deps = [ + ":get_serialized_resource_op_op_lib", + ":trt_engine_op_op_lib", + ], +) + +tf_cuda_library( + name = "trt_logging", + srcs = ["utils/trt_logger.cc"], + hdrs = ["utils/trt_logger.h"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:lib_proto_parsing", + ] + if_tensorrt([ + "@local_config_tensorrt//:tensorrt", + ]), +) + +tf_gen_op_wrapper_py( + name = "trt_ops", + deps = [ + ":trt_op_libs", + ], +) + +tf_custom_op_py_library( + name = "trt_ops_loader", + srcs = ["python/ops/trt_ops.py"], + dso = [ + "python/ops/_trt_ops.so", + ] + if_tensorrt([ + "@local_config_tensorrt//:tensorrt", + ]), + kernels = [ + ":trt_op_kernels", + ":trt_op_libs", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform", + "//tensorflow/python:resources", + ], +) + +tf_cuda_library( + name = "trt_resources", + srcs = [ + "utils/trt_int8_calibrator.cc", + "utils/trt_resource_manager.cc", + "utils/trt_resources.cc", + ], + hdrs = [ + "utils/trt_int8_calibrator.h", + "utils/trt_lru_cache.h", + "utils/trt_resource_manager.h", + "utils/trt_resources.h", + ], + deps = [ + ":trt_allocator", + ":trt_logging", + ":utils", + "//tensorflow/core:framework_headers_lib", + "//tensorflow/core:framework_lite", + 
"//tensorflow/core:lib_proto_parsing", + ] + if_tensorrt([ + "@local_config_tensorrt//:tensorrt", + ]), +) + +tf_cuda_library( + name = "trt_allocator", + srcs = ["utils/trt_allocator.cc"], + hdrs = ["utils/trt_allocator.h"], + deps = [ + "//tensorflow/core:framework_headers_lib", + "//tensorflow/core:framework_lite", + "//tensorflow/core:lib_proto_parsing", + ] + if_tensorrt([ + "@local_config_tensorrt//:tensorrt", + ]), +) + +tf_cc_test( + name = "trt_allocator_test", + size = "small", + srcs = ["utils/trt_allocator_test.cc"], + tags = [ + "no_windows", + "nomac", + ], + deps = [ + ":trt_allocator", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +tf_cc_test( + name = "trt_lru_cache_test", + size = "small", + srcs = ["utils/trt_lru_cache_test.cc"], + tags = [ + "no_windows", + "nomac", + ], + deps = [ + ":trt_resources", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +# Library for the node-level conversion portion of TensorRT operation creation +tf_cuda_library( + name = "trt_conversion", + srcs = [ + "convert/convert_graph.cc", + "convert/convert_nodes.cc", + "convert/trt_optimization_pass.cc", + ], + hdrs = [ + "convert/convert_graph.h", + "convert/convert_nodes.h", + "convert/trt_optimization_pass.h", + ], + deps = [ + ":segment", + ":test_utils", + ":trt_allocator", + ":trt_plugins", + ":trt_logging", + ":trt_resources", + ":utils", + "@com_google_absl//absl/strings", + "//tensorflow/core/grappler/clusters:cluster", + "//tensorflow/core/grappler/optimizers:custom_graph_optimizer", + "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:utils", + "//tensorflow/core:framework", + "//tensorflow/core:framework_lite", + "//tensorflow/core:gpu_runtime", + "//tensorflow/core:graph", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/grappler:devices", + 
"//tensorflow/core/grappler/clusters:virtual_cluster", + "//tensorflow/core/grappler/costs:graph_properties", + "//tensorflow/core/grappler/optimizers:meta_optimizer", + ] + if_tensorrt([ + "@local_config_tensorrt//:tensorrt", + ]) + tf_custom_op_library_additional_deps(), +) + +tf_cuda_cc_test( + name = "convert_graph_test", + size = "medium", + srcs = ["convert/convert_graph_test.cc"], + tags = [ + "no_cuda_on_cpu_tap", + "no_windows", + "nomac", + ], + deps = [ + ":trt_conversion", + "@com_google_googletest//:gtest", + "@com_google_absl//absl/strings", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:scope", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler/clusters:cluster", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_base", + "//tensorflow/core:direct_session", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ] + if_tensorrt([ + "@local_config_tensorrt//:tensorrt", + ]), +) + +tf_cuda_cc_test( + name = "convert_nodes_test", + size = "medium", + srcs = ["convert/convert_nodes_test.cc"], + tags = [ + "no_cuda_on_cpu_tap", + "no_windows", + "nomac", + ], + deps = [ + ":trt_logging", + ":trt_conversion", + ":trt_plugins", + "@com_google_googletest//:gtest", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:ops", + "//tensorflow/cc:scope", + "//tensorflow/core/grappler/costs:graph_properties", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_base", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:tensor_testutil", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ] + if_tensorrt([ + "@local_config_cuda//cuda:cuda_headers", + "@local_config_tensorrt//:tensorrt", + ]), +) + +# Library for the 
segmenting portion of TensorRT operation creation +cc_library( + name = "segment", + srcs = ["segment/segment.cc"], + hdrs = [ + "segment/segment.h", + "segment/union_find.h", + ], + deps = [ + "//tensorflow/core:graph", + "//tensorflow/core:lib_proto_parsing", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings", + "@protobuf_archive//:protobuf_headers", + ], +) + +tf_cc_test( + name = "segment_test", + size = "small", + srcs = ["segment/segment_test.cc"], + tags = [ + "no_windows", + "nomac", + ], + deps = [ + ":segment", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:scope", + "//tensorflow/core:core_cpu", + "//tensorflow/core:lib", + "//tensorflow/core:ops", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + +# Library for the plugin factory +tf_cuda_library( + name = "trt_plugins", + srcs = [ + "plugin/trt_plugin.cc", + "plugin/trt_plugin_factory.cc", + "plugin/trt_plugin_utils.cc", + ], + hdrs = [ + "plugin/trt_plugin.h", + "plugin/trt_plugin_factory.h", + "plugin/trt_plugin_utils.h", + ], + deps = [ + "//tensorflow/core:framework_lite", + "//tensorflow/core:lib_proto_parsing", + ] + if_tensorrt([ + "@local_config_tensorrt//:tensorrt", + ]), +) + +tf_cuda_cc_test( + name = "trt_plugin_factory_test", + size = "small", + srcs = ["plugin/trt_plugin_factory_test.cc"], + tags = [ + "no_cuda_on_cpu_tap", + "no_windows", + "nomac", + ], + deps = [ + ":trt_plugins", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ] + if_tensorrt([ + "@local_config_cuda//cuda:cuda_headers", + "@local_config_tensorrt//:tensorrt", + ]), +) + +cc_library( + name = "utils", + srcs = ["convert/utils.cc"], + hdrs = ["convert/utils.h"], + copts = tf_copts(), + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "test_utils", + srcs = ["utils/test_utils.cc"], + hdrs = ["utils/test_utils.h"], + deps 
= [ + "//tensorflow/core:lib", + "@com_googlesource_code_re2//:re2", + ], +) diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc similarity index 88% rename from tensorflow/contrib/tensorrt/convert/convert_graph.cc rename to tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc index 3b32f72bc1f220fd6730c71e3d2b3b6b806b748e..f3db42509ecf1d5176c8f56ef13d2c76d038ee7a 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/tensorrt/convert/convert_graph.h" +#include "tensorflow/compiler/tf2tensorrt/convert/convert_graph.h" #include #include @@ -24,13 +24,14 @@ limitations under the License. #include #include -#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h" -#include "tensorflow/contrib/tensorrt/convert/utils.h" -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" -#include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h" -#include "tensorflow/contrib/tensorrt/resources/trt_resources.h" -#include "tensorflow/contrib/tensorrt/segment/segment.h" -#include "tensorflow/contrib/tensorrt/test/utils.h" +#include "absl/strings/str_cat.h" +#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h" +#include "tensorflow/compiler/tf2tensorrt/convert/utils.h" +#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h" +#include "tensorflow/compiler/tf2tensorrt/segment/segment.h" +#include "tensorflow/compiler/tf2tensorrt/utils/test_utils.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_resource_manager.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h" #include "tensorflow/core/common_runtime/gpu/gpu_id.h" #include 
"tensorflow/core/common_runtime/gpu/gpu_id_manager.h" #include "tensorflow/core/common_runtime/gpu/gpu_process_state.h" @@ -63,8 +64,8 @@ limitations under the License. namespace tensorflow { namespace tensorrt { namespace convert { -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; +using absl::StrAppend; +using absl::StrCat; // Returns compiled TRT version information {Maj, Min, Patch} std::vector GetLinkedTensorRTVersion() { @@ -82,56 +83,62 @@ std::vector GetLoadedTensorRTVersion() { } TrtCandidateSelector::TrtCandidateSelector( - const grappler::GraphProperties& graph_properties, int precision_mode) + const grappler::GraphProperties& graph_properties, + TrtPrecisionMode precision_mode) : graph_properties_(graph_properties), precision_mode_(precision_mode) {} Status TrtCandidateSelector::IsTensorRTCandidate(const tensorflow::Node* node) { // TODO(laigd): move this set to TrtNodeValidator where it should belong. // LINT.IfChange static const std::set candidate_ops = { - "Identity", - "Snapshot", - "Const", - "Conv2D", - "MaxPool", - "BiasAdd", - "Relu", - "Sigmoid", - "Tanh", + "Abs", "Add", - "Mul", - "Sub", - "Rsqrt", - "Pad", - "Mean", "AvgPool", + "BatchMatMul", + "BiasAdd", "ConcatV2", + "Const", + "Conv2D", + "Conv2DBackpropInput", "DepthwiseConv2dNative", - "FusedBatchNorm", - "FusedBatchNormV2", "Div", - "RealDiv", - "Rsqrt", - "Reciprocal", "Exp", + "ExpandDims", + "FusedBatchNorm", + "FusedBatchNormV2", + "Identity", + "LeakyRelu", "Log", - "Sqrt", - "Abs", - "Neg", - "Transpose", - "Reshape", "MatMul", - "BatchMatMul", - "Softmax", - "Minimum", - "Maximum", - "TopKV2", - "Sum", - "Prod", "Max", + "MaxPool", + "Maximum", + "Mean", "Min", + "Minimum", + "Mul", + "Neg", + "Pad", + "Prod", + "RealDiv", + "Reciprocal", + "Relu", "Relu6", + "Reshape", + "Rsqrt", + "Rsqrt", + "Sigmoid", + "Snapshot", + "Softmax", + "Sqrt", "Square", + "Squeeze", + "StridedSlice", + "Sub", + "Sum", + "Tanh", + "TopKV2", + "Transpose", }; bool 
is_supported_op_type = (candidate_ops.count(node->type_string()) || @@ -145,10 +152,11 @@ Status TrtCandidateSelector::IsTensorRTCandidate(const tensorflow::Node* node) { // In INT8 mode, we will always apply the quantization ranges provided by // these ops to the relevant tensors. This happens regardless of the value of // use_calibration. - if (precision_mode_ == INT8MODE && quantize_ops.count(node->type_string())) { + if (precision_mode_ == TrtPrecisionMode::INT8 && + quantize_ops.count(node->type_string())) { is_supported_op_type = true; } - // LINT.ThenChange(//tensorflow/contrib/tensorrt/convert/convert_nodes.cc) + // LINT.ThenChange(//tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc) if (!is_supported_op_type) { return errors::Unimplemented("Op type ", node->type_string(), " is not supported"); @@ -208,6 +216,7 @@ tensorflow::Status ConvertCalibGraphToInferGraph( return tensorflow::errors::FailedPrecondition( "Need to run graph with calibration data first!"); } + tensorflow::core::ScopedUnref calib_sc(cres); if (cres->calibrator_) { cres->calibrator_->waitAndSetDone(); cres->thr_->join(); @@ -224,7 +233,6 @@ tensorflow::Status ConvertCalibGraphToInferGraph( return tensorflow::errors::Unknown( "Can't get TRTCalibrator from resource manager!"); } - cres->Unref(); TF_RETURN_IF_ERROR(calib_rm->Cleanup(container_name)); } } @@ -235,7 +243,7 @@ tensorflow::Status ConvertGraphDefToTensorRT( const tensorflow::GraphDef& graph_def, const std::vector& output_names, size_t max_batch_size, size_t max_workspace_size_bytes, tensorflow::GraphDef* new_graph_def, - int precision_mode, int minimum_segment_size, bool is_dyn_op, + TrtPrecisionMode precision_mode, int minimum_segment_size, bool is_dyn_op, int max_cached_engines, std::vector cached_engine_batches, bool use_calibration) { // Create GrapplerItem. 
@@ -295,7 +303,7 @@ tensorflow::Status ConvertGraphDefToTensorRT( parameters["max_batch_size"].set_i(max_batch_size); parameters["is_dynamic_op"].set_b(is_dyn_op); parameters["max_workspace_size_bytes"].set_i(max_workspace_size_bytes); - TF_RETURN_IF_ERROR(GetPrecisionModeName( + TF_RETURN_IF_ERROR(TrtPrecisionModeToName( precision_mode, parameters["precision_mode"].mutable_s())); parameters["maximum_cached_engines"].set_i(max_cached_engines); if (!cached_engine_batches.empty()) { @@ -320,17 +328,23 @@ tensorflow::Status ConvertGraphDefToTensorRT( return Status::OK(); } +struct EdgePtrCompare { + bool operator()(const tensorflow::Edge* lhs, + const tensorflow::Edge* rhs) const { + return lhs->id() < rhs->id(); + } +}; + // Function to get subsegment information structure. tensorflow::Status GetEngineInfo( const tensorflow::Graph* g, const tensorflow::grappler::GraphProperties& graph_properties, - const std::set& segment_nodes, + const std::set& segment_nodes, const std::unordered_map& node_map, const std::vector& reverse_topo_order, EngineInfo* info) { - std::vector subgraph_node_ids; // Topologically sorted node ids. - std::set subgraph_node_names = segment_nodes; - std::set added_const_node_ids; // Used to prevent double insertion. + std::vector subgraph_nodes; // Topologically sorted nodes. + std::set added_const_nodes; // Used to prevent double insertion. 
std::set segment_devices; // Map from src_node_name+port to the unique port numbers of the TRT op, where @@ -342,26 +356,45 @@ tensorflow::Status GetEngineInfo( std::unordered_map input_to_engine_port, output_to_engine_port; for (auto it = reverse_topo_order.rbegin(); it != reverse_topo_order.rend(); ++it) { - const auto& node_name = (*it)->name(); - if (segment_nodes.count(node_name) == 0) continue; - auto node = *it; + const Node* node = *it; + if (segment_nodes.count(node) == 0) continue; auto node_device = node->requested_device(); if (!node_device.empty()) { - segment_devices.insert(node_device); + // If device is CPU, treat as if no device was assigned. Don't add CPU to + // segment_device because that would cause a segfault in + // GetDeviceAndAllocator. This is because GetDeviceAndAllocator assumes + // any already set device is a GPU. + DeviceNameUtils::ParsedName parsed_name; + DeviceNameUtils::ParseFullName(node_device, &parsed_name); + if (parsed_name.type == "CPU") { + VLOG(1) << "Node " << node->name() << " was assigned to the CPU. " + << "Attempting to place on GPU."; + } else { + segment_devices.insert(node_device); + } } else { if (node->has_assigned_device_name()) { + // It appears that nodes will not have assigned devices at this point in + // execution. segment_devices.insert(node->assigned_device_name()); } else { VLOG(2) << "Node " << node->name() << " neither have requested device nor assigned device"; } } + subgraph_nodes.push_back(node); + const int node_id = node->id(); - subgraph_node_ids.push_back(node_id); - // Create input connections. - for (const auto edge : node->in_edges()) { + const string& node_name = node->name(); + + // Create input connections. Sort edges first to make determnistic since + // in_edges is a set of pointers. 
+ std::vector in_edges(node->in_edges().begin(), + node->in_edges().end()); + std::sort(in_edges.begin(), in_edges.end(), EdgePtrCompare()); + for (const auto edge : in_edges) { auto input_node = edge->src(); - if (input_node->IsSource() || segment_nodes.count(input_node->name())) { + if (input_node->IsSource() || segment_nodes.count(input_node)) { continue; } if (edge->IsControlEdge()) { @@ -378,12 +411,11 @@ tensorflow::Status GetEngineInfo( // // Note that the segmenter already ensure that the constant data input // is valid and suppported by the engine. - if (!added_const_node_ids.insert(input_node->id()).second) { + if (!added_const_nodes.insert(input_node).second) { // Already added before. continue; } VLOG(1) << "Adding const node " << input_node->name(); - QCHECK(subgraph_node_names.insert(input_node->name()).second); // Since we already add (duplicate) the const input node to the segment // graphdef, it's now not a data dependency any more, but to make the // dependency correct we still add a control dependency. @@ -407,10 +439,14 @@ tensorflow::Status GetEngineInfo( node_id, edge->dst_input(), /*input_edge=*/true, port); } } - // Create output connections. - for (const auto edge : node->out_edges()) { + // Create output connections. Sort edges first to make determnistic since + // out_edges is a set of pointers. + std::vector out_edges(node->out_edges().begin(), + node->out_edges().end()); + std::sort(out_edges.begin(), out_edges.end(), EdgePtrCompare()); + for (const auto edge : out_edges) { auto output_node = edge->dst(); - if (output_node->IsSink() || segment_nodes.count(output_node->name())) { + if (output_node->IsSink() || segment_nodes.count(output_node)) { continue; } if (edge->IsControlEdge()) { @@ -438,12 +474,11 @@ tensorflow::Status GetEngineInfo( } // For each segment node in topological order. // Construct the const nodes first. 
- subgraph_node_ids.insert(subgraph_node_ids.begin(), - added_const_node_ids.begin(), - added_const_node_ids.end()); + subgraph_nodes.insert(subgraph_nodes.begin(), added_const_nodes.begin(), + added_const_nodes.end()); TF_RETURN_IF_ERROR(ConvertSegmentToGraphDef( - g, graph_properties, subgraph_node_names, subgraph_node_ids, - &info->connections, &info->segment_graph_def, &info->engine_name)); + g, graph_properties, subgraph_nodes, &info->connections, + &info->segment_graph_def, &info->engine_name)); // TODO(sami): This should not happen once segmenter is updated. if (segment_devices.size() == 1) { info->device = *segment_devices.begin(); @@ -564,6 +599,18 @@ tensorflow::Status CreateTRTNode(const std::vector& infos, int pos, } input_shape_protos.at(conn.port_number) = in_shape; input_shapes.at(conn.port_number) = conn.outside_shape; + // Shape must be fully defined (excluding batch dimension) for static + // mode. + if (info.engine_type == EngineInfo::EngineType::TRTStatic) { + for (int i = 1; i < conn.outside_shape.dims(); i++) { + if (conn.outside_shape.dim_size(i) <= 0) { + return tensorflow::errors::Internal( + "Input shapes must be fully defined when in static mode. " + "Please try is_dynamic_op=True (shape was ", + conn.outside_shape.DebugString(), ")"); + } + } + } // Rewrire data input if it's not found in original graph. tensorflow::Node* input_node = graph->FindNodeId(conn.outside_id); @@ -585,9 +632,17 @@ tensorflow::Status CreateTRTNode(const std::vector& infos, int pos, } } } + // We don't support segments with no inputs. Fall back to native TF here to + // avoid crash later. Constant folding should've folded the ops that make up + // these segments. 
+ if (inputs.empty()) { + return tensorflow::errors::Internal( + "Segment has no inputs (possible " + "constfold failure)"); + } const bool calibrate_int8 = - (info.precision_mode == INT8MODE && info.use_calibration); + (info.precision_mode == TrtPrecisionMode::INT8 && info.use_calibration); // Build the engine and get its serialized representation. string segment_string; if (info.engine_type == EngineInfo::EngineType::TRTStatic || calibrate_int8) { @@ -600,7 +655,8 @@ tensorflow::Status CreateTRTNode(const std::vector& infos, int pos, TrtUniquePtrType engine; // TODO(sami): What happens if 1st dim is not batch? TF_RETURN_IF_ERROR(ConvertGraphDefToEngine( - info.segment_graph_def, calibrate_int8 ? FP32MODE : info.precision_mode, + info.segment_graph_def, + calibrate_int8 ? TrtPrecisionMode::FP32 : info.precision_mode, max_batch_size, info.max_workspace_size_bytes, input_shapes, &trt_logger, alloc, /*calibrator=*/nullptr, &engine, info.use_calibration, @@ -616,14 +672,8 @@ tensorflow::Status CreateTRTNode(const std::vector& infos, int pos, segment_string = info.segment_graph_def.SerializeAsString(); } - // TODO(aaroey): use enum instead, and add a helper method to do the - // conversion. 
string prec_string; - TF_RETURN_IF_ERROR(GetPrecisionModeName(info.precision_mode, &prec_string)); - if (info.precision_mode == INT8MODE && calibrate_int8 && - !TRTResourceManager::instance()->getManager("TRTCalibration")) { - LOG(ERROR) << "Failed to construct calibration storage"; - } + TF_RETURN_IF_ERROR(TrtPrecisionModeToName(info.precision_mode, &prec_string)); tensorflow::NodeDefBuilder node_builder(info.engine_name, "TRTEngineOp"); if (!info.device.empty()) node_builder.Device(info.device); if (VLOG_IS_ON(1)) { @@ -639,7 +689,7 @@ tensorflow::Status CreateTRTNode(const std::vector& infos, int pos, } if (info.engine_type == EngineInfo::EngineType::TRTStatic && - info.cached_engine_batches.size()) { + !info.cached_engine_batches.empty()) { LOG(WARNING) << "Cached engine batches are ignored for static engines"; } tensorflow::NodeDef trt_node; @@ -653,7 +703,6 @@ tensorflow::Status CreateTRTNode(const std::vector& infos, int pos, .Attr("serialized_segment", segment_string) .Attr("calibration_data", "") .Attr("max_cached_engines_count", info.maximum_cached_engines) - .Attr("cached_engine_batches", {max_batch_size}) .Attr("workspace_size_bytes", info.max_workspace_size_bytes) .Attr("precision_mode", prec_string) .Attr("use_calibration", info.use_calibration) @@ -805,6 +854,12 @@ tensorflow::Status RegisterSegmentFunctionToFunctionLibrary( auto native_segment = fdeflib.add_function(); TF_RETURN_IF_ERROR(tensorflow::GraphToFunctionDef( sgraph, StrCat(engine_name, "_native_segment"), native_segment)); + // Set kIntsonDeviceAttr to true so that all TRTEngineOp outputs are always on + // a GPU device as expected. Otherwise, some of the tensors of type DT_INT32 + // would be on host if the op generating the tensor has host memory tag set. 
+ (*native_segment + ->mutable_attr())[FunctionLibraryDefinition::kIntsOnDeviceAttr] + .set_b(true); if (VLOG_IS_ON(7)) { VLOG(7) << engine_name << " Function_Def "; VLOG(7) << native_segment->DebugString(); @@ -926,7 +981,8 @@ tensorflow::Status ConvertAfterShapes(ConversionParams& params) { continue; } curr_engine.precision_mode = params.precision_mode; - if (params.use_calibration && params.precision_mode != INT8MODE) { + if (params.use_calibration && + params.precision_mode != TrtPrecisionMode::INT8) { return errors::InvalidArgument( "Calibration with FP32 or FP16 is not supported."); } @@ -995,27 +1051,31 @@ tensorflow::Status ConvertAfterShapes(ConversionParams& params) { cudaSetDevice(cuda_device_id); auto status = CreateTRTNode(engine_segments, i, params.max_batch_size, &graph, alloc.get(), &engine_nodes); - // If status is ok, we successfully added the node to the graph and can - // remove segment ops. Otherwise graph is not modified. + string msg = StrCat("TensorRT node ", engine.engine_name, " added for segment ", i, " consisting of ", converted_segments.at(i).first.size(), " nodes"); if (status.ok()) { LOG(INFO) << msg << " succeeded."; - for (auto node_name : converted_segments.at(i).first) { - graph.RemoveNode(node_map.at(node_name)); - } } else { // Graph is not modified. LOG(WARNING) << msg << " failed: " << status << ". Fallback to TF..."; } if (VLOG_IS_ON(1)) { msg = "Segment consists of nodes: "; - for (const string& node_name : converted_segments.at(i).first) { - StrAppend(&msg, node_name, ", "); + for (const Node* node : converted_segments.at(i).first) { + StrAppend(&msg, node->name(), ", "); } VLOG(1) << msg; } + + // If status is ok, we successfully added the node to the graph and can + // remove segment ops. Otherwise graph is not modified. 
+ if (status.ok()) { + for (const Node* node : converted_segments.at(i).first) { + graph.RemoveNode(const_cast(node)); + } + } } cudaSetDevice(old_cuda_device); graph.ToGraphDef(params.output_graph_def); diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.h b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.h similarity index 86% rename from tensorflow/contrib/tensorrt/convert/convert_graph.h rename to tensorflow/compiler/tf2tensorrt/convert/convert_graph.h index 1f39f56f6392ba33af3d74fec12c326ed4451cb6..95cf0227dcf84396b9de52194ae3a750f4acca66 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.h +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.h @@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_GRAPH_H_ -#define TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_GRAPH_H_ +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_CONVERT_GRAPH_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_CONVERT_GRAPH_H_ #include -#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h" +#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/grappler/clusters/cluster.h" #include "tensorflow/core/grappler/costs/graph_properties.h" @@ -36,7 +36,7 @@ namespace convert { class TrtCandidateSelector { public: TrtCandidateSelector(const grappler::GraphProperties& graph_properties, - int precision_mode); + TrtPrecisionMode precision_mode); // Returns OK iff 'node' is a TF-TRT conversion candidate, which will be added // to TRT subgraph and later converted into TRT engine. 
@@ -52,7 +52,7 @@ class TrtCandidateSelector { const grappler::GraphProperties& graph_properties_; // Quantization ops are only converted when using quantized precisions. - const int precision_mode_; + const TrtPrecisionMode precision_mode_; }; struct ConversionParams { @@ -61,7 +61,7 @@ struct ConversionParams { max_batch_size(1), max_workspace_size_bytes(1 << 30), output_graph_def(nullptr), - precision_mode(1), + precision_mode(TrtPrecisionMode::FP32), minimum_segment_size(3), graph_properties(nullptr), cluster(nullptr), @@ -74,7 +74,7 @@ struct ConversionParams { size_t max_batch_size; size_t max_workspace_size_bytes; tensorflow::GraphDef* output_graph_def; - int precision_mode; + TrtPrecisionMode precision_mode; int minimum_segment_size; const tensorflow::grappler::GraphProperties* graph_properties; const tensorflow::grappler::Cluster* cluster; @@ -99,9 +99,10 @@ tensorflow::Status ConvertGraphDefToTensorRT( const tensorflow::GraphDef& graph_def, const std::vector& output_names, size_t max_batch_size, size_t max_workspace_size_bytes, tensorflow::GraphDef* new_graph_def, - int precision_mode = 1, int minimum_segment_size = 3, - bool is_dyn_op = false, int max_cached_engines = 1, - std::vector cached_engine_batches = {}, bool use_calibration = true); + TrtPrecisionMode precision_mode = TrtPrecisionMode::FP32, + int minimum_segment_size = 3, bool is_dyn_op = false, + int max_cached_engines = 1, std::vector cached_engine_batches = {}, + bool use_calibration = true); // Method to call from optimization pass tensorflow::Status ConvertAfterShapes(ConversionParams& params); @@ -123,4 +124,4 @@ std::pair GetDeviceAndAllocator( #endif // GOOGLE_TENSORRT #endif // GOOGLE_CUDA -#endif // TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_GRAPH_H_ +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_CONVERT_GRAPH_H_ diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_graph_test.cc similarity index 96% rename from 
tensorflow/contrib/tensorrt/convert/convert_graph_test.cc rename to tensorflow/compiler/tf2tensorrt/convert/convert_graph_test.cc index 2d2bfeb192c1893824c7b30bfad593c62c203392..cabc6ccfa13df77b3bd26f51f35284816141423a 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph_test.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph_test.cc @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/tensorrt/convert/convert_graph.h" +#include "tensorflow/compiler/tf2tensorrt/convert/convert_graph.h" #include #include #include "tensorflow/cc/framework/scope.h" #include "tensorflow/cc/ops/standard_ops.h" -#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h" +#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/device_set.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -98,7 +98,8 @@ TEST(TrtCandidateSelector, Basics) { grappler::GraphProperties graph_properties(item); TF_EXPECT_OK(graph_properties.InferStatically(true)); - for (const int precision_mode : {FP32MODE, INT8MODE}) { + for (const TrtPrecisionMode precision_mode : + {TrtPrecisionMode::FP32, TrtPrecisionMode::INT8}) { TrtCandidateSelector selector(graph_properties, precision_mode); TF_EXPECT_OK(selector.IsTensorRTCandidate(matmul.operation.node())); ExpectStatus( @@ -113,7 +114,7 @@ TEST(TrtCandidateSelector, Basics) { matmul_with_incompatible_input.operation.node()), error::INTERNAL, "Failed to convert input with index 0 to a TRT_TensorOrWeights"); - if (precision_mode == INT8MODE) { + if (precision_mode == TrtPrecisionMode::INT8) { TF_EXPECT_OK(selector.IsTensorRTCandidate(quantize.operation.node())); } else { ExpectStatus(selector.IsTensorRTCandidate(quantize.operation.node()), diff --git 
a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc similarity index 76% rename from tensorflow/contrib/tensorrt/convert/convert_nodes.cc rename to tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc index fee095668e5aef44316ff15c1d8572b2ecd960df..79b1cba32909c119a9127c3d254a6f14a16cb660 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h" +#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h" #include #include @@ -24,11 +24,14 @@ limitations under the License. #include #include -#include "tensorflow/contrib/tensorrt/convert/utils.h" -#include "tensorflow/contrib/tensorrt/log/trt_logger.h" -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" -#include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h" -#include "tensorflow/contrib/tensorrt/resources/trt_resources.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/tf2tensorrt/convert/utils.h" +#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_resource_manager.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h" #include "tensorflow/core/framework/node_def.pb.h" // NOLINT #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/tensor.pb.h" // NOLINT @@ -43,6 +46,7 @@ limitations under the License. 
#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/tensor_coding.h" #include "tensorflow/core/platform/types.h" @@ -80,10 +84,17 @@ namespace tensorrt { const char* const kInputPHName = "TensorRTInputPH_"; const char* const kOutputPHName = "TensorRTOutputPH_"; +bool IsEngineInput(absl::string_view name) { + return absl::StartsWith(name, kInputPHName); +} +bool IsEngineOutput(absl::string_view name) { + return absl::StartsWith(name, kOutputPHName); +} + namespace convert { +using absl::StrAppend; +using absl::StrCat; using ::tensorflow::str_util::Split; -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; inline tensorflow::Status ConvertDType(tensorflow::DataType tf_dtype, nvinfer1::DataType* trt_dtype) { @@ -120,6 +131,15 @@ inline nvinfer1::Dims TensorShapeToTrtDims(const TensorShapeType& shape, return trt_dims; } +Status TensorShapeArrayToTrtDims(const std::vector& shape, + nvinfer1::Dims* out, + bool ignore_first_dim = false) { + PartialTensorShape tensor_shape; + TF_RETURN_IF_ERROR(TensorShapeUtils::MakeShape(shape, &tensor_shape)); + *out = TensorShapeToTrtDims(tensor_shape, ignore_first_dim); + return tensorflow::Status::OK(); +} + void GetOutputProperties(const grappler::GraphProperties& graph_properties, const Node* node, const int out_port, PartialTensorShape* shape, @@ -325,6 +345,41 @@ Status Converter::GetTrtBroadcastShape( return Status::OK(); } +nvinfer1::ITensor* Converter::CreateConstantLayer( + const TRT_ShapedWeights& weights, const nvinfer1::Dims& dims) { + nvinfer1::Weights trt_weights = weights.GetTrtWeights(); + nvinfer1::IConstantLayer* layer = network()->addConstant(dims, trt_weights); + if (!layer) return nullptr; + const nvinfer1::DataType trt_dtype = trt_weights.type; + nvinfer1::ITensor* trt_tensor = layer->getOutput(0); + // TODO(laigd): 
there is a bug in TensorRT 5.0 library that, if we don't set + // the data type below, it will always be kFLOAT regardless what the data type + // of the weights is. Once NVIDIA fixes this bug, we should remove the data + // type setting logic below and test should still pass. + trt_tensor->setType(trt_dtype); + return trt_tensor; +} + +tensorflow::Status CreateBroadcastableScalarConstant( + OpConverterParams* params, float value, const nvinfer1::Dims& dims, + const nvinfer1::ITensor** tensor) { + // In order to be broadcastable, the number of dims has to match. + nvinfer1::Dims broadcastable_dims(dims); + for (int i = 0; i < broadcastable_dims.nbDims; i++) { + broadcastable_dims.d[i] = 1; + } + TRT_ShapedWeights weights = params->weight_store->GetTempWeights( + tensorflow::DataType::DT_FLOAT, broadcastable_dims); + auto weights_ptr = + static_cast(const_cast(weights.GetValues())); + weights_ptr[0] = value; + *tensor = params->converter->CreateConstantLayer(weights, broadcastable_dims); + TFTRT_RETURN_ERROR_IF_NULLPTR(*tensor, params->node_def.name()); + params->converter->ProvideQuantizationRange( + const_cast(*tensor), value, value); + return Status::OK(); +} + inline bool DimsEqual(const nvinfer1::Dims& dim_l, const nvinfer1::Dims& dim_r) { if (dim_l.nbDims != dim_r.nbDims) { @@ -623,6 +678,11 @@ bool TFAttrs::get(const string& key) const { return this->at(key)->b(); } +template <> +int TFAttrs::get(const string& key) const { + return this->at(key)->i(); +} + // TODO(jie): reorder4 & reorder2 should be merged? // TODO(aaroey): fix the order of parameters. 
template @@ -834,7 +894,7 @@ Status TrtNodeValidator::ConvertConstToWeights( } Converter::Converter(nvinfer1::INetworkDefinition* trt_network, - int precision_mode, bool use_calibration) + TrtPrecisionMode precision_mode, bool use_calibration) : trt_network_(trt_network), precision_mode_(precision_mode), use_calibration_(use_calibration) { @@ -865,9 +925,11 @@ Status Converter::ConvertNode(const NodeDef& node_def) { // We need to check the name before setting it. If the input is one of the // engine input, setting the name here will overwrite engine input // bindings which will cause runtime error. + // TODO(tmorris): Remove this work-around once we use TRT's IIdentityLayer + // in ConvertIdentity. if (output.is_tensor()) { const char* tensor_name = output.tensor()->getName(); - if (!tensorflow::str_util::StartsWith(tensor_name, kInputPHName)) { + if (!IsEngineInput(tensor_name)) { // TRT initializes tensor names as "(Unnamed ITensor* N)". We rename // them to match their corresponding TensorFlow name. 
// Note: ITensors that we create internally within TF-TRT which are @@ -913,22 +975,45 @@ Status Converter::AddInputTensor(const string& name, nvinfer1::DataType dtype, } Status Converter::RenameAndMarkOutputTensors( - const std::vector>& output_tensors) { + const std::vector& output_tensors) { for (const auto& output : output_tensors) { TRT_TensorOrWeights tensor_or_weights; - TF_RETURN_IF_ERROR(GetTensorOrWeights(output.first, &tensor_or_weights)); + TF_RETURN_IF_ERROR( + GetTensorOrWeights(output.source_tensor_name, &tensor_or_weights)); if (!tensor_or_weights.is_tensor()) { - return errors::InvalidArgument("Output ", output.first, + return errors::InvalidArgument("Output ", output.source_tensor_name, " is weights not tensor"); } nvinfer1::ITensor* tensor = tensor_or_weights.tensor(); if (tensor == nullptr) { - return errors::NotFound("Output tensor not found: ", output.first); + return errors::NotFound("Output tensor not found: ", + output.source_tensor_name); } - tensor->setName(output.second.c_str()); - VLOG(1) << "Marking output tensor " << output.first << ", as output tensor " - << output.second; + // Check if this tensor has already been marked as an input or output. + // + // ConvertIdentity can cause the same tensor to be repeated in + // output_tensors, which can cause us to overwrite the name of the output + // tensor binding. For example, if we rename OutputPH_0 to OutputPH_1 then + // we won't be able to locate OutputPH_0 during runtime. To fix this, + // duplicate the tensor using no-op shuffle. + // + // TODO(tmorris): Remove this work-around once we use TRT's IIdentityLayer + // in ConvertIdentity. + if (IsEngineInput(tensor->getName()) || IsEngineOutput(tensor->getName())) { + // Using shuffle layer for identity by not setting reshape or transpose. 
+ nvinfer1::IShuffleLayer* layer = network()->addShuffle(*tensor); + TFTRT_RETURN_ERROR_IF_NULLPTR( + layer, StrCat("Output Copy for ", tensor->getName())); + MarkQuantizationRangesAsInferrable(tensor, layer->getOutput(0)); + tensor = layer->getOutput(0); + } + tensor->setName(output.dest_node_name.c_str()); network()->markOutput(*tensor); + // Set type after marking as output. TRT only supports setType for engine + // outputs and inputs (type is inferred otherwise). + tensor->setType(output.trt_dtype); + VLOG(1) << "Marking output TRT tensor " << output.source_tensor_name + << ", which feeds TF node " << output.dest_node_name; } return Status::OK(); } @@ -1072,11 +1157,9 @@ Status Converter::PrepareTensorForShape(const TRT_TensorOrWeights& input, *tensor = layer->getOutput(0); } } else { - nvinfer1::IConstantLayer* layer = - this->network()->addConstant(dims, input.weights().GetTrtWeights()); - TFTRT_RETURN_ERROR_IF_NULLPTR(layer, "TF-TRT Internal Reshape"); - *tensor = layer->getOutput(0); - if (precision_mode() == INT8MODE && !use_calibration()) { + *tensor = CreateConstantLayer(input.weights(), dims); + TFTRT_RETURN_ERROR_IF_NULLPTR(*tensor, "TF-TRT Internal Reshape"); + if (precision_mode() == TrtPrecisionMode::INT8 && !use_calibration()) { // If we are in int8 mode and not calibrating, we need to explicitly set a // quantization range for the output tensor of the IConstantLayer. Here we // set the range to [min(weights), max(weights)]. @@ -1111,7 +1194,7 @@ void Converter::ProvideQuantizationRange(nvinfer1::ITensor* tensor, } void Converter::MaybeApplyQuantizationRanges() { - if (precision_mode() != INT8MODE) return; + if (precision_mode() != TrtPrecisionMode::INT8) return; // Infer ranges across marked ops. 
PropagateQuantizationRanges(); @@ -1234,6 +1317,39 @@ Status Converter::GetInputs(const tensorflow::NodeDef& node_def, return tensorflow::Status::OK(); } +// Checks that the number of inputs match, and enforces that the inputs marked +// as true are constant weights. true means that the input must be a weight, +// while false means the input must be a tensor. In the future, false will mean +// the input can be a tensor or weight. +tensorflow::Status CheckInputsWeights( + const OpConverterParams& params, + const std::vector>& inputs_is_weight) { + const auto& inputs = params.inputs; + const auto& node_def = params.node_def; + if (inputs.size() != inputs_is_weight.size()) { + return tensorflow::errors::InvalidArgument( + node_def.op(), " got ", inputs.size(), " inputs but expected ", + inputs_is_weight.size(), ", at ", node_def.name()); + } + for (int i = 0; i < inputs.size(); i++) { + if (inputs_is_weight[i].second && inputs.at(i).is_tensor()) { + return tensorflow::errors::Unimplemented( + "The input \"", inputs_is_weight[i].first, "\" for ", node_def.op(), + " must be a constant, at ", node_def.name()); + } + // TODO(tmorris): Remove this check and provide a method to automatically + // retrive an input as a tensor, converting via CreateConstantLayer if it + // was originally a weight. We will want a caching mechanism to prevent many + // duplicate constants from being created. 
+ if (!inputs_is_weight[i].second && inputs.at(i).is_weights()) { + return tensorflow::errors::Unimplemented( + "The input \"", inputs_is_weight[i].first, "\" for ", node_def.op(), + " must be a tensor, at ", node_def.name()); + } + } + return tensorflow::Status::OK(); +} + TRT_ShapedWeights ConvertFP32ToFP16(TrtWeightStore* store, const TRT_ShapedWeights& weights_src) { auto dtype_new = tensorflow::DataType::DT_HALF; @@ -1426,7 +1542,7 @@ Status BinaryTensorOpWeight(OpConverterParams* params, const_cast(tensor), permutation, &tensor)); } - if (params->converter->precision_mode() == FP16MODE) { + if (params->converter->precision_mode() == TrtPrecisionMode::FP16) { weights = ConvertFP32ToFP16(params->weight_store, weights); } @@ -1469,7 +1585,7 @@ Status BinaryTensorOpWeight(OpConverterParams* params, // Because of this issue, fall back to BinaryTensorOpTensor if we are // doing INT8 with no calibration. There is most likely no performance // penalty by falling back here. - if (params->converter->precision_mode() == INT8MODE && + if (params->converter->precision_mode() == TrtPrecisionMode::INT8 && !params->converter->use_calibration()) { return errors::Unimplemented( "Intermediate quantization range cannot be determined without" @@ -1519,80 +1635,126 @@ Status BinaryTensorOpWeight(OpConverterParams* params, return tensorflow::Status::OK(); } -enum class ConvolutionType { DEFAULT, DEPTHWISE_CONV }; - -tensorflow::Status ConvertConv2DHelper(OpConverterParams* params, int group) { +tensorflow::Status ConvertConv2DHelper(OpConverterParams* params, int group, + bool is_conv2d_backprop_input) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; - const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); + TRT_TensorOrWeights backprop_output_size; + const nvinfer1::ITensor* tensor = nullptr; + if (is_conv2d_backprop_input) { + // In the case when Conv2dBackpropInput is used for conv2d_transpose, these + // inputs correspond to: output size, 
filter, and input. + TF_RETURN_IF_ERROR(CheckInputsWeights( + *params, + {{"input_sizes", true}, {"filter", true}, {"out_backprop", false}})); + backprop_output_size = inputs.at(0); + tensor = inputs.at(2).tensor(); + } else { + TF_RETURN_IF_ERROR( + CheckInputsWeights(*params, {{"input", false}, {"filter", true}})); + tensor = inputs.at(0).tensor(); + } + TRT_ShapedWeights weights_rsck = inputs.at(1).weights(); + if (weights_rsck.shape_.nbDims != 4) { + return tensorflow::errors::InvalidArgument( + "Conv2D expects kernel of dimension 4, at " + node_def.name()); + } TFAttrs attrs(node_def); - - int h_index = 2; - int w_index = 3; auto data_format = attrs.get("data_format"); - if (data_format == "NHWC") { + int c_index = (data_format == "NHWC") ? 3 : 1; + int h_index = (data_format == "NHWC") ? 1 : 2; + int w_index = (data_format == "NHWC") ? 2 : 3; + auto tf_dilations = attrs.get>("dilations"); + if (tf_dilations.size() != 4) { + return tensorflow::errors::InvalidArgument( + "Convolution dilations field must specify 4 dimensions, at ", + node_def.name()); + } + if (tf_dilations[0] != 1 || tf_dilations[c_index] != 1) { + return tensorflow::errors::Unimplemented( + "Dilation rate must be 1 for batch and channel dimensions, at ", + node_def.name()); + } + const nvinfer1::DimsHW dilation(tf_dilations[h_index], tf_dilations[w_index]); + if (is_conv2d_backprop_input && (dilation.d[0] != 1 || dilation.d[1] != 1)) { + return tensorflow::errors::Unimplemented( + "Dilation with Conv2DBackpropInput (conv2d_transpose) is not supported", + ", at ", node_def.name()); + } + + const auto tf_stride = attrs.get>("strides"); + if (tf_stride.size() != 4) { + return tensorflow::errors::InvalidArgument( + "Convolution strides field must specify 4 dimensions, at ", + node_def.name()); + } + if (tf_stride[0] != 1 || tf_stride[c_index] != 1) { + return tensorflow::errors::Unimplemented( + "Stride must be 1 for batch and channel dimensions, at ", + node_def.name()); + } + const 
nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]); + if (params->validation_only) return tensorflow::Status::OK(); + + // Transpose to NCHW (NCHW is required for IConvLayer). + const bool need_transpose = (data_format == "NHWC"); + if (need_transpose) { TF_RETURN_IF_ERROR(params->converter->TransposeTensor( const_cast(tensor), {0, 3, 1, 2}, &tensor)); - h_index = 1; - w_index = 2; - // TODO(jie): transpose it } - - // tensor after transpose (NCHW) + // Dimensions of transposed tensor. const auto tensor_dim = tensor->getDimensions(); - int num_groups = group; - if (num_groups == 0) num_groups = tensor_dim.d[0]; // depthwise convolution - VLOG(2) << "groups count: " << num_groups; + // group == 0 signifies that this is a depthwise convolution, so set + // num_groups to size of input's channel dim. For a non-depthwise conv, + // num_groups will be 1. + const int num_groups = (group == 0) ? tensor_dim.d[0] : group; - TRT_ShapedWeights weights_rsck = inputs.at(1).weights(); - VLOG(2) << "weight shape: " << weights_rsck.DebugString(); - if (weights_rsck.shape_.nbDims != 4) { - return tensorflow::errors::Internal( - "Conv2D expects kernel of dimension 4, at: " + node_def.name()); - } - if (params->converter->precision_mode() == FP16MODE) { - weights_rsck = - ConvertFP32ToFP16(params->weight_store, inputs.at(1).weights()); + if (params->converter->precision_mode() == TrtPrecisionMode::FP16) { + weights_rsck = ConvertFP32ToFP16(params->weight_store, weights_rsck); } - + // For conv, TF weights are RSCK, and TRT expects KCRS. + // For backprop, TF weights are RSKC, and TRT expects CKRS. + // Therefore, this reorder will work for both cases. TRT_ShapedWeights weights = params->weight_store->GetTempWeights(weights_rsck); ReorderRSCKToKCRS(weights_rsck, &weights, num_groups); TRT_ShapedWeights biases(weights.type_); - const int noutput = weights.shape_.d[0] * num_groups; + const int output_axis = is_conv2d_backprop_input ? 
1 : 0; + const int noutput = weights.shape_.d[output_axis] * num_groups; nvinfer1::DimsHW kernel_size; kernel_size.h() = weights.shape_.d[2]; kernel_size.w() = weights.shape_.d[3]; - VLOG(2) << "RSCK: " << weights.DebugString(); - VLOG(2) << "kernel size: " << kernel_size.h() << ", " << kernel_size.w(); - - // TODO(jie): stride. (NHWC/NCHW) - const auto tf_stride = attrs.get>("strides"); - VLOG(2) << "h_INDEX" << h_index << ", w_index " << w_index; - VLOG(2) << "stride: " << tf_stride[0] << tf_stride[1] << tf_stride[2] - << tf_stride[3]; - const nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]); + // Add padding. std::vector> padding; - // TODO(jie): padding. if (attrs.get("padding") == "SAME") { - // This is NCHW tensor with no batch dimension. - // 1 -> h - // 2 -> w - padding = CreateSamePadding( - stride, kernel_size, - {static_cast(tensor_dim.d[1]), static_cast(tensor_dim.d[2])}); + nvinfer1::DimsHW effective_kernel_size = kernel_size; + effective_kernel_size.h() += (kernel_size.h() - 1) * (dilation.h() - 1); + effective_kernel_size.w() += (kernel_size.w() - 1) * (dilation.w() - 1); + std::vector input_dims; + if (is_conv2d_backprop_input) { + // For backprop, calculate padding based on "input_sizes" input, which + // actually corresponds to output size. ("input_sizes" makes sense in the + // context of Conv2DBackpropInput). + // We use h_index and w_index instead of 1 and 2 because we haven't + // transposed backprop_output_size along with the input. + auto output_size_weights = static_cast( + const_cast(backprop_output_size.weights().GetValues())); + input_dims = {output_size_weights[h_index], output_size_weights[w_index]}; + } else { + // Use 1 and 2 because tensor_dim has the dimensions of the transposed + // input. 
+ input_dims = {static_cast(tensor_dim.d[1]), + static_cast(tensor_dim.d[2])}; + } + padding = CreateSamePadding(stride, effective_kernel_size, input_dims); } else { padding = {{0, 0}, {0, 0}}; } - if (padding[0].first != padding[0].second || padding[1].first != padding[1].second) { - // TODO(jie): handle asymmetric padding - VLOG(2) << "Padding!!!: " << padding[0].first << padding[0].second - << padding[1].first << padding[1].second; - VLOG(2) << "TENSOR before: " << DebugString(tensor->getDimensions()); + // Handle asymmetric padding. auto pad_layer = params->converter->network()->addPadding( *const_cast(tensor), nvinfer1::DimsHW(padding[0].first, padding[1].first), @@ -1602,24 +1764,38 @@ tensorflow::Status ConvertConv2DHelper(OpConverterParams* params, int group) { const_cast(tensor), pad_layer->getOutput(0)); padding = {{0, 0}, {0, 0}}; tensor = pad_layer->getOutput(0); - VLOG(2) << "TENSOR after: " << DebugString(tensor->getDimensions()); } - nvinfer1::IConvolutionLayer* layer = - params->converter->network()->addConvolution( - *const_cast(tensor), noutput, kernel_size, - weights.GetTrtWeights(), biases.GetTrtWeights()); - TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + // Add convolution. 
+ nvinfer1::ILayer* conv_layer = nullptr; + if (is_conv2d_backprop_input) { + nvinfer1::IDeconvolutionLayer* layer = + params->converter->network()->addDeconvolution( + *const_cast(tensor), noutput, kernel_size, + weights.GetTrtWeights(), biases.GetTrtWeights()); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + layer->setStride(stride); + layer->setPadding({padding[0].first, padding[1].first}); + layer->setName(node_def.name().c_str()); + layer->setNbGroups(num_groups); + conv_layer = layer; + } else { + nvinfer1::IConvolutionLayer* layer = + params->converter->network()->addConvolution( + *const_cast(tensor), noutput, kernel_size, + weights.GetTrtWeights(), biases.GetTrtWeights()); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + layer->setStride(stride); + layer->setPadding({padding[0].first, padding[1].first}); + layer->setName(node_def.name().c_str()); + layer->setNbGroups(num_groups); + layer->setDilation(dilation); + conv_layer = layer; + } + const nvinfer1::ITensor* output_tensor = conv_layer->getOutput(0); - layer->setStride(stride); - layer->setPadding({padding[0].first, padding[1].first}); - layer->setName(node_def.name().c_str()); - layer->setNbGroups(num_groups); - const nvinfer1::ITensor* output_tensor = layer->getOutput(0); - VLOG(2) << "TENSOR out: " << DebugString(output_tensor->getDimensions()); - VLOG(2) << "data_format: " << data_format; - if (data_format == "NHWC") { - // TODO(jie): transpose it back! + // Restore transpose. 
+ if (need_transpose) { TF_RETURN_IF_ERROR(params->converter->TransposeTensor( const_cast(output_tensor), {0, 2, 3, 1}, &output_tensor)); @@ -1629,18 +1805,6 @@ tensorflow::Status ConvertConv2DHelper(OpConverterParams* params, int group) { return tensorflow::Status::OK(); } -tensorflow::Status ConvertConv2DHelper(OpConverterParams* params, - ConvolutionType type) { - switch (type) { - case ConvolutionType::DEFAULT: - return ConvertConv2DHelper(params, 1); - case ConvolutionType::DEPTHWISE_CONV: - return ConvertConv2DHelper(params, 0); - } - return tensorflow::errors::Unimplemented("unsupported convolution type at, " + - params->node_def.name()); -} - Status BinaryTensorOpTensor(OpConverterParams* params, const TRT_TensorOrWeights& operand_l, const TRT_TensorOrWeights& operand_r) { @@ -1668,6 +1832,13 @@ Status BinaryTensorOpTensor(OpConverterParams* params, "Unsupported binary op broadcast scheme for op ", node_def.name(), ": ", status.error_message()); } + TFAttrs attrs(node_def); + nvinfer1::DataType dtype = attrs.get("T"); + if (dtype == nvinfer1::DataType::kINT32) { + return errors::Unimplemented("Binary op ", node_def.op(), + " does not support INT32, at ", + node_def.name()); + } if (params->validation_only) return Status::OK(); const nvinfer1::ITensor* tensor_l = nullptr; @@ -1684,8 +1855,6 @@ Status BinaryTensorOpTensor(OpConverterParams* params, } // Check type consistency. 
- TFAttrs attrs(node_def); - nvinfer1::DataType dtype = attrs.get("T"); TFTRT_CHECK_EQ_TYPE(tensor_l->getType(), dtype) << DebugString(tensor_l->getType()) << " vs " << DebugString(dtype); TFTRT_CHECK_EQ_TYPE(tensor_r->getType(), dtype) @@ -1745,12 +1914,8 @@ tensorflow::Status ConvertPlugin(OpConverterParams* params) { tensorflow::Status ConvertTranspose(OpConverterParams* params) { const auto& inputs = params->inputs; - if (inputs.size() != 2 || !inputs.at(0).is_tensor() || - !inputs.at(1).is_weights()) { - return tensorflow::errors::InvalidArgument( - "Input expects tensor and weights, at ", params->node_def.name()); - } - + TF_RETURN_IF_ERROR( + CheckInputsWeights(*params, {{"x", false}, {"perm", true}})); // Get the permutation from weights. TRT_ShapedWeights weights = inputs.at(1).weights(); const int* weights_ptr = @@ -1783,11 +1948,8 @@ tensorflow::Status ConvertTranspose(OpConverterParams* params) { tensorflow::Status ConvertReshape(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; - if (inputs.size() != 2 || !inputs.at(1).is_weights()) { - return tensorflow::errors::InvalidArgument( - "Input expects weights for shape, at ", node_def.name()); - } - + TF_RETURN_IF_ERROR( + CheckInputsWeights(*params, {{"tensor", false}, {"shape", true}})); TRT_TensorOrWeights input_tensor = inputs.at(0); TRT_ShapedWeights weights = inputs.at(1).weights(); if (weights.count() == 0) { @@ -1880,20 +2042,378 @@ tensorflow::Status ConvertReshape(OpConverterParams* params) { return tensorflow::Status::OK(); } +tensorflow::Status ConvertExpandDims(OpConverterParams* params) { + const auto& inputs = params->inputs; + const auto& node_def = params->node_def; + TF_RETURN_IF_ERROR( + CheckInputsWeights(*params, {{"input", false}, {"axis", true}})); + // Get input shape as vector. 
+ TRT_TensorOrWeights input_tensor = inputs.at(0); + const nvinfer1::Dims dims = input_tensor.GetTrtDims(); + std::vector input_dims(dims.d, dims.d + dims.nbDims); + // Add batch dim back. + input_dims.insert(input_dims.begin(), -1); + const int input_rank = input_dims.size(); + // Get axis to expand on. + TRT_ShapedWeights weights = inputs.at(1).weights(); + if (weights.count() != 1) { + return tensorflow::errors::InvalidArgument( + "ExpandDims axis must be a scalar, at ", node_def.name()); + } + const int* weights_ptr = + static_cast(const_cast(weights.GetValues())); + int axis = weights_ptr[0]; + // Make sure axis is valid. + if ((axis < (-input_rank - 1)) || (axis > input_rank)) { + return tensorflow::errors::InvalidArgument( + "Axis for ExpandDims is invalid, must be in the range " + "[-rank(input) - 1, rank(input)], at ", + node_def.name()); + } + // Convert negative axis to corresponding positive axis. + if (axis < 0) axis += input_rank + 1; + if (axis == 0) { + return tensorflow::errors::Unimplemented( + "Modifying batch dimension is not supported for ExpandDims, at ", + node_def.name()); + } + if (params->validation_only) return Status::OK(); + + // ExpandDims: Insert new dim of size 1. + input_dims.insert(input_dims.begin() + axis, 1); + // Reshape tensor. + nvinfer1::Dims new_dims; + TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(input_dims, &new_dims, + /*ignore_first_dim=*/true)); + const nvinfer1::ITensor* output_tensor = nullptr; + TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( + input_tensor, new_dims, &output_tensor)); + params->outputs->push_back( + TRT_TensorOrWeights(const_cast(output_tensor))); + return tensorflow::Status::OK(); +} + +tensorflow::Status ConvertSqueeze(OpConverterParams* params) { + const auto& inputs = params->inputs; + const auto& node_def = params->node_def; + TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"input", false}})); + // Get input shape. 
+ TRT_TensorOrWeights input_tensor = inputs.at(0); + const nvinfer1::Dims dims = input_tensor.GetTrtDims(); + std::vector input_dims(dims.d, dims.d + dims.nbDims); + // Add batch dim back. + input_dims.insert(input_dims.begin(), -1); + const int input_rank = input_dims.size(); + // Mark axes to remove by setting them to 0. + TFAttrs attrs(node_def); + auto squeeze_dims = attrs.get>("squeeze_dims"); + if (squeeze_dims.size() == 0) { + return tensorflow::errors::Unimplemented( + "Squeeze is only implemented for explicit dims, at ", node_def.name()); + } + for (int axis : squeeze_dims) { + // Make sure axis is valid. + if ((axis < -input_rank) || (axis >= input_rank)) { + return tensorflow::errors::InvalidArgument( + "Axis for Squeeze is invalid, must be in the range " + "[-rank(input), rank(input)), at ", + node_def.name()); + } + // Convert negative axis to corresponding positive axis. + if (axis < 0) axis += input_rank; + // Don't squeeze batch dim. + if (axis == 0) { + return tensorflow::errors::Unimplemented( + "Cannot squeeze batch dimension, at ", node_def.name()); + } + // Make sure target dimension is size 1. + if (input_dims[axis] != 1) { + return tensorflow::errors::InvalidArgument( + "Cannot squeeze a dimension which isn't size 1, at ", + node_def.name()); + } + // Mark dim for removal by setting to 0. + input_dims[axis] = 0; + } + if (params->validation_only) return Status::OK(); + + // Remove all dims which are equal to 0. + input_dims.erase(std::remove(input_dims.begin(), input_dims.end(), 0), + input_dims.end()); + // Reshape tensor. 
+ nvinfer1::Dims new_dims; + TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(input_dims, &new_dims, + /*ignore_first_dim=*/true)); + const nvinfer1::ITensor* output_tensor = nullptr; + TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( + input_tensor, new_dims, &output_tensor)); + params->outputs->push_back( + TRT_TensorOrWeights(const_cast(output_tensor))); + return tensorflow::Status::OK(); +} + +// Gets the bounds (start or end) from the weights of a StridedSlice op. +tensorflow::Status GetStridedSliceBound(const std::vector& input_dims, + const TRT_ShapedWeights& bound_weights, + int mask, bool begin, string node_name, + std::vector* output_bound) { + const string bound_name = (begin) ? "begin" : "end"; + const int* weights_ptr = static_cast(bound_weights.GetValues()); + *output_bound = + std::vector(weights_ptr, weights_ptr + bound_weights.count()); + if (output_bound->size() != input_dims.size()) { + return tensorflow::errors::InvalidArgument( + "StridedSlice \"", bound_name, "\" specified ", + std::to_string(output_bound->size()), " dimensions, but input rank is ", + std::to_string(input_dims.size()), ", at ", node_name); + } + for (int i = 0; i < output_bound->size(); i++) { + if ((1 << i) & mask) { + // Apply mask. + (*output_bound)[i] = (begin) ? 0 : input_dims[i]; + // Masked bound will always result in a valid, non-negative bound, so we + // don't need the following checks. For the common case of using masks on + // an undefined batch dim (-1), we specifically don't want to do the + // following checks because they will erroneously detect an out of range + // bound or try to correct the negative value. + continue; + } + // Make sure bound is valid. 
+ if (((*output_bound)[i] < -input_dims[i]) || + ((*output_bound)[i] > input_dims[i])) { + return tensorflow::errors::InvalidArgument( + bound_name, " value of ", std::to_string((*output_bound)[i]), + " for StridedSlice is invalid, must be in the range " + "[-dim_size(i), dim_size(i)], at ", + node_name); + } + // Convert negative values to their positive equivalent. + if ((*output_bound)[i] < 0) { + (*output_bound)[i] += input_dims[i]; + } + } + return tensorflow::Status::OK(); +} + +tensorflow::Status ConvertStridedSlice(OpConverterParams* params) { + const auto& inputs = params->inputs; + const auto& node_def = params->node_def; + TF_RETURN_IF_ERROR(CheckInputsWeights( + *params, + {{"input", false}, {"begin", true}, {"end", true}, {"strides", true}})); + // Get input dims. + nvinfer1::Dims dims = inputs.at(0).GetTrtDims(); + std::vector input_dims(dims.d, dims.d + dims.nbDims); + if (inputs.at(0).is_tensor()) { + // Temporarily add batch dimension so that indexes line up properly. + input_dims.insert(input_dims.begin(), inputs.at(0).batch_size()); + } + if (input_dims.size() > 4) { + return tensorflow::errors::Unimplemented( + "StridedSlice is not implemented for tensors with rank > 4, at ", + node_def.name()); + } + TFAttrs attrs(node_def); + // Get begin and end bounds per axis. + std::vector begin, end; + TF_RETURN_IF_ERROR(GetStridedSliceBound(input_dims, inputs.at(1).weights(), + attrs.get("begin_mask"), true, + node_def.name(), &begin)); + TF_RETURN_IF_ERROR(GetStridedSliceBound(input_dims, inputs.at(2).weights(), + attrs.get("end_mask"), false, + node_def.name(), &end)); + // Get strides per axis (must all be 1). 
+ TRT_ShapedWeights stride_weights = inputs.at(3).weights(); + const int* stride_weights_ptr = static_cast(stride_weights.GetValues()); + std::vector strides(stride_weights_ptr, + stride_weights_ptr + stride_weights.count()); + for (int x : strides) { + if (x != 1) { + return tensorflow::errors::Unimplemented( + "StridedSlice is only implemented for stride of 1, at ", + node_def.name()); + } + } + // Unsupported mask options. + for (const string& attr : + {"ellipsis_mask", "new_axis_mask", "shrink_axis_mask"}) { + int attr_val = attrs.get(attr); + if (attr_val != 0) { + return tensorflow::errors::Unimplemented( + attr, " is not supported for StridedSlice, at ", node_def.name()); + } + } + + nvinfer1::ITensor* tensor = + const_cast(inputs.at(0).tensor()); + // Reshape if necessary to 4-D, since IPaddingLayer requires a 4-D input. + const bool need_reshape = (input_dims.size() != 4); + int reshape_dims_added = 0; + nvinfer1::Dims reshape_dims; + if (need_reshape) { + // Add new dims after batch dim until tensor is 4D. + while (input_dims.size() < 4) { + input_dims.insert(input_dims.begin() + 1, 1); + begin.insert(begin.begin() + 1, 0); + end.insert(end.begin() + 1, 1); + reshape_dims_added++; + } + TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(input_dims, &reshape_dims, + /*ignore_first_dim=*/true)); + } + // Find dimensions which need to be sliced. + std::vector pad_dims; + for (int i = 0; i < input_dims.size(); i++) { + if ((begin[i] != 0) || (end[i] != input_dims[i])) { + if (i == 0) { + return tensorflow::errors::Unimplemented( + "StridedSlice can't modify batch dim, at ", node_def.name()); + } else if ((end[i] - begin[i]) < 0) { + return tensorflow::errors::InvalidArgument( + "New size of sliced dimension is negative, at ", node_def.name()); + } + pad_dims.push_back(i); + } + } + if (pad_dims.size() == 0) { + // No dimensions are changed. We could create a padding layer anyway with + // values of 0. 
+ if (params->validation_only) return Status::OK(); + params->outputs->push_back(inputs.at(0)); + return tensorflow::Status::OK(); + } else if (pad_dims.size() == 1) { + // Only one dim is modified but we have to have 2, mark a second dim which + // will have padding of 0. The dim we add is chosen to avoid an unnecessary + // transpose. + if (pad_dims[0] != 2) { + pad_dims.push_back(2); + } else { + pad_dims.push_back(3); + } + } else if (pad_dims.size() > 2) { + return tensorflow::errors::Unimplemented( + "StridedSlice can only modify 2 dimensions, at ", node_def.name()); + } + std::sort(pad_dims.begin(), pad_dims.end()); + // Convert to pre/post padding values. Since TRT does not have a StridedSlice + // or Slice layer, we instead create an IPaddingLayer with negative padding. + nvinfer1::DimsHW pre_padding, post_padding; + for (int i = 0; i < pad_dims.size(); i++) { + const int axis = pad_dims[i]; + pre_padding.d[i] = -begin[axis]; + post_padding.d[i] = end[axis] - input_dims[axis]; + } + + // IPaddingLayer will always apply the padding to dims 2,3 (input format is + // NCHW). + const bool need_transpose = !(pad_dims[0] == 2 && pad_dims[1] == 3); + std::vector transpose_order(input_dims.size()); + std::vector inv_transpose_order(input_dims.size()); + if (need_transpose) { + if (pad_dims[0] == 1 && pad_dims[1] == 3) { + transpose_order = {0, 2, 1, 3}; + inv_transpose_order = {0, 2, 1, 3}; + } else if (pad_dims[0] == 1 && pad_dims[1] == 2) { + transpose_order = {0, 3, 1, 2}; + inv_transpose_order = {0, 2, 3, 1}; + } + } + if (params->validation_only) return Status::OK(); + + // Start conversion. 
+ if (need_reshape) { + const nvinfer1::ITensor* output_tensor = nullptr; + TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( + inputs.at(0), reshape_dims, &output_tensor)); + tensor = const_cast(output_tensor); + } + if (need_transpose) { + const nvinfer1::ITensor* output_tensor = nullptr; + TF_RETURN_IF_ERROR(params->converter->TransposeTensor( + tensor, transpose_order, &output_tensor)); + tensor = const_cast(output_tensor); + } + + // Add padding layer + nvinfer1::IPaddingLayer* layer = params->converter->network()->addPadding( + *const_cast(tensor), pre_padding, post_padding); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + params->converter->MarkQuantizationRangesAsInferrable(tensor, + layer->getOutput(0)); + tensor = layer->getOutput(0); + + // Restore transpose + if (need_transpose) { + const nvinfer1::ITensor* output_tensor = nullptr; + TF_RETURN_IF_ERROR(params->converter->TransposeTensor( + tensor, inv_transpose_order, &output_tensor)); + tensor = const_cast(output_tensor); + } + // Restore reshape + if (need_reshape) { + // Calculate output dimensions + for (int i = 0; i < pad_dims.size(); i++) { + const int axis = pad_dims[i]; + input_dims[axis] = end[axis] - begin[axis]; + } + // Remove added 1 dimensions + for (int i = 0; i < reshape_dims_added; i++) { + int value = input_dims[1]; + if (value != 1) { + return tensorflow::errors::Internal( + "StridedSlice error when reshaping, at ", node_def.name()); + } + input_dims.erase(input_dims.begin() + 1); + } + + nvinfer1::Dims new_dims; + TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(input_dims, &new_dims, + /*ignore_first_dim=*/true)); + const nvinfer1::ITensor* output_tensor = nullptr; + TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( + TRT_TensorOrWeights(tensor), new_dims, &output_tensor)); + tensor = const_cast(output_tensor); + } + + params->outputs->push_back( + TRT_TensorOrWeights(const_cast(tensor))); + return tensorflow::Status::OK(); +} + tensorflow::Status 
ConvertConv2D(OpConverterParams* params) { - return ConvertConv2DHelper(params, ConvolutionType::DEFAULT); + return ConvertConv2DHelper(params, 1, /*is_conv2d_backprop_input=*/false); } tensorflow::Status ConvertConv2DDepthwise(OpConverterParams* params) { - return ConvertConv2DHelper(params, ConvolutionType::DEPTHWISE_CONV); + return ConvertConv2DHelper(params, 0, /*is_conv2d_backprop_input=*/false); +} + +tensorflow::Status ConvertConv2DBackpropInput(OpConverterParams* params) { + return ConvertConv2DHelper(params, 1, /*is_conv2d_backprop_input=*/true); } tensorflow::Status ConvertPool(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; - const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); + TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"input", false}})); + nvinfer1::PoolingType type; + if (node_def.op() == "MaxPool") { + type = nvinfer1::PoolingType::kMAX; + } else if (node_def.op() == "AvgPool") { + type = nvinfer1::PoolingType::kAVERAGE; + } else { + return tensorflow::errors::Unimplemented( + "Unsupported pooling type: ", node_def.op(), ", at ", node_def.name()); + } TFAttrs attrs(node_def); + const string padding_type = attrs.get("padding"); + if ((padding_type != "SAME") && (padding_type != "VALID")) { + return tensorflow::errors::Unimplemented( + "Unsupported padding type: ", padding_type, ", at ", node_def.name()); + } + if (params->validation_only) return Status::OK(); + const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); int h_index = 2; int w_index = 3; const auto data_format = attrs.get("data_format"); @@ -1904,16 +2424,6 @@ tensorflow::Status ConvertPool(OpConverterParams* params) { const_cast(tensor), {0, 3, 1, 2}, &tensor)); } - nvinfer1::PoolingType type; - if (node_def.op() == "MaxPool") { - type = nvinfer1::PoolingType::kMAX; - } else if (node_def.op() == "AvgPool") { - type = nvinfer1::PoolingType::kAVERAGE; - } else { - return tensorflow::errors::Unimplemented("Unsupported 
pool type: ", - node_def.op()); - } - const auto tf_stride = attrs.get>("strides"); const nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]); @@ -1922,7 +2432,6 @@ tensorflow::Status ConvertPool(OpConverterParams* params) { auto tensor_dim = tensor->getDimensions(); std::vector> padding; - const string padding_type = attrs.get("padding"); if (padding_type == "SAME") { // This is NCHW tensor with no batch dimension. // 1 -> h @@ -1932,9 +2441,6 @@ tensorflow::Status ConvertPool(OpConverterParams* params) { {static_cast(tensor_dim.d[1]), static_cast(tensor_dim.d[2])}); } else if (padding_type == "VALID") { padding = {{0, 0}, {0, 0}}; - } else { - return tensorflow::errors::Unimplemented("Unsupported padding type: ", - padding_type); } if (padding[0].first != padding[0].second || @@ -1976,7 +2482,9 @@ tensorflow::Status ConvertPool(OpConverterParams* params) { return tensorflow::Status::OK(); } -tensorflow::Status ConvertActivation(OpConverterParams* params) { +// TODO(tmorris): Use ActivationType::kLEAKY_RELU in TRT 5.1+ once perf +// improves. +tensorflow::Status ConvertLeakyRelu(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; if (inputs.size() != 1) { @@ -1988,6 +2496,47 @@ tensorflow::Status ConvertActivation(OpConverterParams* params) { node_def.op(), " is only implemented for tensors, at ", node_def.name()); } + TFAttrs attrs(node_def); + const float alpha = attrs.get("alpha"); + if (alpha < 0.0f || alpha > 1.0f) { + return tensorflow::errors::Unimplemented( + "Alpha value for LeakyRelu must be between 0 and 1, at ", + node_def.name()); + } + if (params->validation_only) return tensorflow::Status::OK(); + + // Input Tensor + const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); + // Create const for alpha. 
+ const nvinfer1::ITensor* const_alpha_tensor = nullptr; + TF_RETURN_IF_ERROR(CreateBroadcastableScalarConstant( + params, alpha, tensor->getDimensions(), &const_alpha_tensor)); + // alpha * x + nvinfer1::IElementWiseLayer* mul_layer = + params->converter->network()->addElementWise( + *const_cast(tensor), + *const_cast(const_alpha_tensor), + nvinfer1::ElementWiseOperation::kPROD); + TFTRT_RETURN_ERROR_IF_NULLPTR(mul_layer, node_def.name()); + // max(x, alpha * x) + nvinfer1::IElementWiseLayer* max_layer = + params->converter->network()->addElementWise( + *const_cast(tensor), + *const_cast(mul_layer->getOutput(0)), + nvinfer1::ElementWiseOperation::kMAX); + TFTRT_RETURN_ERROR_IF_NULLPTR(max_layer, node_def.name()); + nvinfer1::ITensor* output_tensor = max_layer->getOutput(0); + params->converter->MarkQuantizationRangesAsInferrable( + output_tensor, const_cast(mul_layer->getOutput(0))); + + params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); + return Status::OK(); +} + +tensorflow::Status ConvertActivation(OpConverterParams* params) { + const auto& inputs = params->inputs; + const auto& node_def = params->node_def; + TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"input", false}})); static const std::unordered_map ops{ {"Relu", nvinfer1::ActivationType::kRELU}, {"Sigmoid", nvinfer1::ActivationType::kSIGMOID}, @@ -2021,19 +2570,19 @@ tensorflow::Status ConvertActivation(OpConverterParams* params) { Status ConvertQuantize(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; - if ((inputs.size() == 0) || - (node_def.op() == "FakeQuantWithMinMaxArgs" && inputs.size() != 1) || - (node_def.op() == "FakeQuantWithMinMaxVars" && inputs.size() != 3) || - (node_def.op() == "QuantizeAndDequantizeV2" && inputs.size() != 3) || - (node_def.op() == "QuantizeAndDequantizeV3" && inputs.size() != 4)) { - return errors::InvalidArgument("Invalid number of inputs for ", - node_def.op(), ", at ", node_def.name()); - } - if 
(inputs.at(0).is_weights()) { - // TensorRT will automatically quantize weights, so we will ignore ranges - // for weights. - params->outputs->push_back(inputs.at(0)); - return Status::OK(); + if (node_def.op() == "FakeQuantWithMinMaxArgs") { + TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"input", false}})); + } else if (node_def.op() == "FakeQuantWithMinMaxVars") { + TF_RETURN_IF_ERROR(CheckInputsWeights( + *params, {{"input", false}, {"min", true}, {"max", true}})); + } else if (node_def.op() == "QuantizeAndDequantizeV2") { + TF_RETURN_IF_ERROR(CheckInputsWeights( + *params, {{"input", false}, {"input_min", true}, {"input_max", true}})); + } else if (node_def.op() == "QuantizeAndDequantizeV3") { + TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"input", false}, + {"input_min", true}, + {"input_max", true}, + {"num_bits", true}})); } float min_range = 0.0f; float max_range = 0.0f; @@ -2050,11 +2599,6 @@ Status ConvertQuantize(OpConverterParams* params) { node_def.op() == "QuantizeAndDequantizeV2" || node_def.op() == "QuantizeAndDequantizeV3") { // Get ranges via inputs. - if (!inputs.at(1).is_weights() || !inputs.at(2).is_weights()) { - return errors::InvalidArgument("Min and max inputs for ", node_def.op(), - " must be weights not tensors, at ", - node_def.name()); - } auto get_weights_value = [&inputs](int index) { auto raw_weights = static_cast( const_cast(inputs.at(index).weights().GetValues())); @@ -2085,20 +2629,11 @@ Status ConvertQuantize(OpConverterParams* params) { return Status::OK(); } -// TODO(pdavoodi): we should update relu6 implementation once TensorRT supports -// Relu6 natively. +// TODO(tmorris): Use ActivationType::kCLIP in TRT 5.1+ once perf improves. 
tensorflow::Status ConvertRelu6(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; - if (inputs.size() != 1) { - return tensorflow::errors::InvalidArgument( - "Invalid number of inputs for Relu6, at ", node_def.name()); - } - if (inputs.at(0).is_weights()) { - return tensorflow::errors::Unimplemented( - "Relu6 is only implemented for tensors, not weights, at ", - node_def.name()); - } + TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"input", false}})); if (params->validation_only) return Status::OK(); // *************************************************************************** // TensorRT does not implement Relu6 natively. This function converts Relu6 op @@ -2122,35 +2657,18 @@ tensorflow::Status ConvertRelu6(OpConverterParams* params) { params->converter->ProvideQuantizationRange(relu_layer->getOutput(0), 0.0f, 6.0f); - // Create a constant layer to store the floating point weight i.e. 6.0f This - // tensor will be broadcasted uniformly during elementwise `min` operation. - // The constant has to have the same rank as the input in order for TRT to - // broadcast - nvinfer1::Dims dims; - dims.nbDims = relu_layer->getOutput(0)->getDimensions().nbDims; - for (int i = 0; i < dims.nbDims; i++) { - dims.d[i] = 1; - } - TRT_ShapedWeights weights = params->weight_store->GetTempWeights( - tensorflow::DataType::DT_FLOAT, dims); - auto weights_ptr = - static_cast(const_cast(weights.GetValues())); - weights_ptr[0] = 6.0f; - nvinfer1::IConstantLayer* const6_layer = - params->converter->network()->addConstant(dims, weights.GetTrtWeights()); - TFTRT_RETURN_ERROR_IF_NULLPTR(const6_layer, node_def.name()); - params->converter->ProvideQuantizationRange(const6_layer->getOutput(0), 0.0f, - 6.0f); + // Create a constant layer to store the floating point weight i.e. 
6.0f + const nvinfer1::ITensor* const6_tensor = nullptr; + TF_RETURN_IF_ERROR(CreateBroadcastableScalarConstant( + params, 6.0f, relu_layer->getOutput(0)->getDimensions(), &const6_tensor)); // ElementWise Min Operation // Min op is a nop for INT8 execution path, as the input tensor // to this layer will only have values in range [0.f, 6.0f]. - const nvinfer1::ITensor* tensor_l = relu_layer->getOutput(0); - const nvinfer1::ITensor* tensor_r = const6_layer->getOutput(0); nvinfer1::IElementWiseLayer* relu6_layer = params->converter->network()->addElementWise( - *const_cast(tensor_l), - *const_cast(tensor_r), + *const_cast(relu_layer->getOutput(0)), + *const_cast(const6_tensor), nvinfer1::ElementWiseOperation::kMIN); TFTRT_RETURN_ERROR_IF_NULLPTR(relu6_layer, node_def.name()); nvinfer1::ITensor* output_tensor = relu6_layer->getOutput(0); @@ -2163,17 +2681,20 @@ tensorflow::Status ConvertRelu6(OpConverterParams* params) { tensorflow::Status ConvertBiasAdd(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; - if (inputs.size() != 2 || !inputs.at(0).is_tensor() || - !inputs.at(1).is_weights()) { - return errors::InvalidArgument("Input expects tensor and weights, at ", - node_def.name()); + TF_RETURN_IF_ERROR( + CheckInputsWeights(*params, {{"value", false}, {"bias", true}})); + TFAttrs attrs(node_def); + tensorflow::DataType tf_dtype = attrs.get("T"); + if (tf_dtype != DataType::DT_FLOAT && tf_dtype != DataType::DT_HALF) { + return errors::Unimplemented("Data type is not supported, for node ", + node_def.name(), " got ", + DataTypeString(tf_dtype)); } if (params->validation_only) return Status::OK(); nvinfer1::ITensor* tensor = const_cast(inputs.at(0).tensor()); const nvinfer1::Dims original_dims = tensor->getDimensions(); - TFAttrs attrs(node_def); const string data_format = attrs.get("data_format"); const int channel_index = (data_format == "NHWC" ? 
original_dims.nbDims - 1 : 0); @@ -2219,7 +2740,7 @@ tensorflow::Status ConvertBiasAdd(OpConverterParams* params) { } TRT_ShapedWeights weights = inputs.at(1).weights(); - if (params->converter->precision_mode() == FP16MODE) { + if (params->converter->precision_mode() == TrtPrecisionMode::FP16) { weights = ConvertFP32ToFP16(params->weight_store, weights); } nvinfer1::ScaleMode mode = nvinfer1::ScaleMode::kCHANNEL; @@ -2263,43 +2784,69 @@ tensorflow::Status ConvertBiasAdd(OpConverterParams* params) { return Status::OK(); } -Status GetTensorDimsWithProtoShape(const Tensor& tensor, - int tensor_proto_array_len, - nvinfer1::Dims* dims) { +void GetTensorDimsWithProtoShape(const Tensor& tensor, nvinfer1::Dims* dims) { if (tensor.dims() > 0) { *dims = GetTrtDimsForTensor(tensor); - if (TrtDimsNumElements(*dims) != tensor_proto_array_len && - tensor_proto_array_len != 1) { - return errors::InvalidArgument( - "Broadcast on weights only supports kCHANNEL and kUNIFORM"); - } } else { dims->nbDims = 1; // No dimension provided. Flatten it. - dims->d[0] = tensor_proto_array_len; + dims->d[0] = tensor.NumElements(); dims->type[0] = nvinfer1::DimensionType::kSPATIAL; for (int i = 1; i < nvinfer1::Dims::MAX_DIMS; ++i) { dims->d[i] = 0; } } - return Status::OK(); } -template -Status TfTensorToTrtWeights(const DataType dtype, const Tensor& tensor, - const CType* tensor_proto_array, - int tensor_proto_array_len, TrtWeightStore* store, +Status TfTensorToTrtWeights(const Tensor& tensor, TrtWeightStore* weight_store, TRT_ShapedWeights* weights) { + const DataType dtype = tensor.dtype(); + + // We always convert the integer constants to INT32, since TRT INT8 is for + // quantized inference. + // + // TODO(aaroey): FP16 will remain in half format and is not converted to + // FP32, but the converter currently uses all float weights as FP32. Fix + // this. + const DataType converted_dtype = + (dtype == DT_INT16 || dtype == DT_INT8 || dtype == DT_UINT8 ? 
DT_INT32 + : dtype); + + // Verify that the dtype is supported by TensorRT. Otherwise, return an error. + nvinfer1::DataType trt_dtype; + TF_RETURN_IF_ERROR(ConvertDType(converted_dtype, &trt_dtype)); + + if (tensor.NumElements() == 0) { + // Return empty weights having converted dtype. + *weights = TRT_ShapedWeights(converted_dtype); + return Status::OK(); + } + nvinfer1::Dims weight_dims; - TF_RETURN_IF_ERROR(GetTensorDimsWithProtoShape(tensor, tensor_proto_array_len, - &weight_dims)); - *weights = store->GetTempWeights(dtype, weight_dims); - void* dst = const_cast(weights->GetValues()); - if (tensor_proto_array_len == 1) { - std::fill_n((CType*)dst, TrtDimsNumElements(weight_dims), - *tensor_proto_array); + GetTensorDimsWithProtoShape(tensor, &weight_dims); + *weights = weight_store->GetTempWeights(converted_dtype, weight_dims); + + // Copy the tensor directly if the tensor does not require cast to the + // supported type. + if (converted_dtype == dtype) { + char* dst = static_cast(const_cast(weights->GetValues())); + memcpy(dst, tensor.tensor_data().data(), tensor.TotalBytes()); + return Status::OK(); + } + + // Copy tensor elements after casting them to the converted DataType. + int32* dst = static_cast(const_cast(weights->GetValues())); + if (dtype == DT_INT16) { + const int16* src = tensor.flat().data(); + std::copy(src, src + tensor.NumElements(), dst); + } else if (dtype == DT_INT8) { + const int8* src = tensor.flat().data(); + std::copy(src, src + tensor.NumElements(), dst); } else { - memcpy(dst, tensor_proto_array, weights->size_bytes()); + // dtype can only be DT_UINT8 at this point. 
+ TFTRT_CHECK_EQ_TYPE(dtype, DT_UINT8); + const uint8* src = tensor.flat().data(); + std::copy(src, src + tensor.NumElements(), dst); } return Status::OK(); } @@ -2317,15 +2864,6 @@ tensorflow::Status ConvertConst(OpConverterParams* params) { "Constant node is expected to have empty input list: ", node_def.name()); } - TFAttrs attrs(node_def); - const DataType dtype = attrs.get("dtype"); - // We always convert the integer constants to kINT32, since TRT kINT8 is for - // quantized inference. - const DataType converted_dtype = - (dtype == DT_INT16 || dtype == DT_INT8 || dtype == DT_UINT8 ? DT_INT32 - : dtype); - nvinfer1::DataType trt_dtype; - TF_RETURN_IF_ERROR(ConvertDType(converted_dtype, &trt_dtype)); // Create shaped weights as output const auto& tensor_proto = node_def.attr().at("value").tensor(); @@ -2335,78 +2873,18 @@ tensorflow::Status ConvertConst(OpConverterParams* params) { node_def.name()); } - TRT_ShapedWeights weights(converted_dtype); - if (tensor.NumElements() == 0) { - // Do nothing. - } else if (!tensor_proto.float_val().empty()) { - TF_RETURN_IF_ERROR(TfTensorToTrtWeights( - converted_dtype, tensor, tensor_proto.float_val().begin(), - tensor_proto.float_val_size(), params->weight_store, &weights)); - } else if (!tensor_proto.int_val().empty()) { - TF_RETURN_IF_ERROR(TfTensorToTrtWeights( - converted_dtype, tensor, tensor_proto.int_val().begin(), - tensor_proto.int_val_size(), params->weight_store, &weights)); - } else if (!tensor_proto.half_val().empty()) { - // TODO(aaroey): implement fp16 conversion. - return errors::Unimplemented("fp16 constant is not supported yet."); - } else if (!tensor_proto.tensor_content().empty()) { - // TODO(aaroey): fp16 will remain in half format and is not converted to - // fp32, but the converter currently uses all float weights as fp32. Fix - // this. 
- const auto& content = tensor_proto.tensor_content(); - if (content.size() > 0) { - const int dtype_size = tensorflow::DataTypeSize(dtype); - if (content.size() % dtype_size != 0) { - return errors::FailedPrecondition("Tensor content size ", - content.size(), - " is not a multiple of ", dtype_size); - } - nvinfer1::Dims weights_dim; - TF_RETURN_IF_ERROR(GetTensorDimsWithProtoShape( - tensor, content.size() / dtype_size, &weights_dim)); - const int64_t size_bytes = TrtDimsNumElements(weights_dim) * dtype_size; - if (content.size() != size_bytes) { - return errors::FailedPrecondition( - "Tensor size and TensorProto content size mismatch: ", size_bytes, - " vs ", content.size()); - } else if (tensor.NumElements() != content.size() / dtype_size) { - return errors::FailedPrecondition( - "Tensor elements count and TensorProto content size mismatch: ", - tensor.NumElements(), " vs ", content.size() / dtype_size); - } - weights = - params->weight_store->GetTempWeights(converted_dtype, weights_dim); - if (dtype_size == tensorflow::DataTypeSize(converted_dtype)) { - port::CopyToArray(content, static_cast( - const_cast(weights.GetValues()))); - } else { - // Copy out the weights as original data type. - std::vector temp_weights(content.size()); - port::CopyToArray(content, - reinterpret_cast(temp_weights.data())); - int32* dst = - static_cast(const_cast(weights.GetValues())); - // Copy to the weight store as converted data type. 
- if (dtype == DT_INT16) { - int16* data = reinterpret_cast(temp_weights.data()); - std::copy(data, data + tensor.NumElements(), dst); - } else if (dtype == DT_INT8) { - int8* data = reinterpret_cast(temp_weights.data()); - std::copy(data, data + tensor.NumElements(), dst); - } else if (dtype == DT_UINT8) { - uint8* data = reinterpret_cast(temp_weights.data()); - std::copy(data, data + tensor.NumElements(), dst); - } else { - return errors::FailedPrecondition( - "Unexpected data type: ", DataTypeString(dtype), - " at: ", node_def.name()); - } - } - } - } else { - return errors::Unimplemented("Not supported constant type, at ", - node_def.name()); + TFAttrs attrs(node_def); + const DataType dtype = attrs.get("dtype"); + if (dtype != tensor.dtype()) { + return errors::InvalidArgument("DataType mismatch between attr (", + DataTypeString(dtype), ") and tensor (", + DataTypeString(tensor.dtype()), ")"); } + + TRT_ShapedWeights weights; + TF_RETURN_IF_ERROR( + TfTensorToTrtWeights(tensor, params->weight_store, &weights)); + if (params->outputs != nullptr) { params->outputs->push_back(TRT_TensorOrWeights(weights)); } @@ -2424,9 +2902,13 @@ tensorflow::Status ConvertIdentity(OpConverterParams* params) { Status ConvertBinary(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; + // TODO(tmorris): Enable once false is updated to mean either tensor or weight + // TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false}, {"y", + // false}})); if (inputs.size() != 2) { - return errors::InvalidArgument("Binary ops require two inputs, at ", - node_def.name()); + return tensorflow::errors::InvalidArgument( + node_def.op(), " got ", inputs.size(), " inputs but expected 2, at ", + node_def.name()); } // Constant folding should have been done by TensorFlow @@ -2476,11 +2958,7 @@ tensorflow::Status ConvertUnary(OpConverterParams* params) { {"Abs", nvinfer1::UnaryOperation::kABS}, {"Reciprocal", 
nvinfer1::UnaryOperation::kRECIP}, }; - - if (inputs.size() != 1) { - return tensorflow::errors::FailedPrecondition( - "Unary ops require single tensor input, at ", node_def.name()); - } + TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false}})); // TODO(jie): check type const nvinfer1::ITensor* tensor = nullptr; @@ -2495,7 +2973,7 @@ tensorflow::Status ConvertUnary(OpConverterParams* params) { // x -> [Sqrt] -> sqrt(x) -> [Recip] -> 1/sqrt(x) // ^ // need range here - if (params->converter->precision_mode() == INT8MODE && + if (params->converter->precision_mode() == TrtPrecisionMode::INT8 && !params->converter->use_calibration()) { return errors::Unimplemented( "Intermediate quantization range cannot be determined without" @@ -2529,14 +3007,7 @@ tensorflow::Status ConvertUnary(OpConverterParams* params) { tensorflow::Status ConvertSquare(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; - if (inputs.size() != 1) { - return tensorflow::errors::InvalidArgument("Square expects one input, at ", - node_def.name()); - } - if (inputs.at(0).is_weights()) { - return tensorflow::errors::Unimplemented( - "Square is only implemented for tensors, at ", node_def.name()); - } + TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false}})); if (params->validation_only) return Status::OK(); // Constant 2 with same rank as input @@ -2549,18 +3020,15 @@ tensorflow::Status ConvertSquare(OpConverterParams* params) { auto weights_ptr = static_cast(const_cast(weights.GetValues())); weights_ptr[0] = 2.f; - nvinfer1::IConstantLayer* const2_layer = - params->converter->network()->addConstant(dims, weights.GetTrtWeights()); - TFTRT_RETURN_ERROR_IF_NULLPTR(const2_layer, node_def.name()); + nvinfer1::ITensor* const2_tensor = + params->converter->CreateConstantLayer(weights, dims); + TFTRT_RETURN_ERROR_IF_NULLPTR(const2_tensor, node_def.name()); // ElementWise Pow Operation - const nvinfer1::ITensor* tensor_l = 
inputs.at(0).tensor(); - const nvinfer1::ITensor* tensor_r = const2_layer->getOutput(0); nvinfer1::IElementWiseLayer* layer = params->converter->network()->addElementWise( - *const_cast(tensor_l), - *const_cast(tensor_r), - nvinfer1::ElementWiseOperation::kPOW); + *const_cast(inputs.at(0).tensor()), + *const2_tensor, nvinfer1::ElementWiseOperation::kPOW); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); nvinfer1::ITensor* output_tensor = layer->getOutput(0); @@ -2571,11 +3039,8 @@ tensorflow::Status ConvertSquare(OpConverterParams* params) { tensorflow::Status ConvertReduce(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; - if (inputs.size() != 2 || !inputs.at(0).is_tensor() || - !inputs.at(1).is_weights()) { - return tensorflow::errors::InvalidArgument( - "Input expects tensor and weights, at", node_def.name()); - } + TF_RETURN_IF_ERROR( + CheckInputsWeights(*params, {{"input", false}, {"axis", true}})); const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); TRT_ShapedWeights index_list = inputs.at(1).weights(); @@ -2636,12 +3101,8 @@ tensorflow::Status ConvertReduce(OpConverterParams* params) { tensorflow::Status ConvertPad(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; - // TODO(aaroey): make a routine for this check and reuse it. 
- if (inputs.size() != 2 || !inputs.at(0).is_tensor() || - !inputs.at(1).is_weights()) { - return tensorflow::errors::InvalidArgument( - "Input expects tensor and weights, at", node_def.name()); - } + TF_RETURN_IF_ERROR( + CheckInputsWeights(*params, {{"tensor", false}, {"paddings", true}})); // Implement tensor binaryOp weight [channel wise] for now; const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); @@ -2701,6 +3162,7 @@ tensorflow::Status ConvertPad(OpConverterParams* params) { return tensorflow::errors::Unimplemented( "Padding layer does not support padding on dimension 1 and 3 yet"); } + if (params->validation_only) return Status::OK(); bool legit_pad = true; nvinfer1::DimsHW pre_padding(0, 0); @@ -2804,6 +3266,7 @@ tensorflow::Status ConvertConcat(OpConverterParams* params) { inputs_vec.push_back(tensor_i); } + if (params->validation_only) return tensorflow::Status::OK(); // nvinfer1::ITensor const* tensor = inputs.at(0).tensor(); nvinfer1::IConcatenationLayer* layer = @@ -2820,17 +3283,30 @@ tensorflow::Status ConvertConcat(OpConverterParams* params) { tensorflow::Status ConvertFusedBatchNorm(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; + TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false}, + {"scale", true}, + {"offset", true}, + {"mean", true}, + {"variance", true}})); TFAttrs attrs(node_def); float epsilon = attrs.get("epsilon"); auto data_format = attrs.get("data_format"); if (data_format != "NCHW") { return tensorflow::errors::Unimplemented( - "only data_format=NCHW is supported, at " + node_def.name()); + node_def.op(), " only supports data_format=NCHW, at ", node_def.name()); } bool is_training = attrs.get("is_training"); if (is_training) { + // Trying to use batchnorm in training mode is a very common problem. + // Because the error message will only be printed in VLOG(1) by the + // segmenter, we issue a special warning so that users will actually see it. 
+ LOG(WARNING) << node_def.op() << " only supports is_training=false. If you " + << "are using Keras, please call " + << "keras.backend.set_learning_phase(0) before constructing " + << "your model. At " << node_def.name(); return tensorflow::errors::Unimplemented( - "only is_training=false is supported, at " + node_def.name()); + node_def.op(), " only supports is_training=false, at ", + node_def.name()); } nvinfer1::ITensor const* tensor = inputs.at(0).tensor(); @@ -2845,7 +3321,7 @@ tensorflow::Status ConvertFusedBatchNorm(OpConverterParams* params) { for (int i = 1; i < 5; i++) { if (inputs.at(i).weights().type_ != parameter_type) { return tensorflow::errors::Unimplemented( - "Inconsistent parameter type for batchnormis not supported, at: " + + "Inconsistent parameter type for batchnorm is not supported, at: " + node_def.name()); } } @@ -2865,6 +3341,8 @@ tensorflow::Status ConvertFusedBatchNorm(OpConverterParams* params) { "Inconsistent batchnorm parameter count, at: " + node_def.name()); } } + if (params->validation_only) return Status::OK(); + // We could technically have two weights with different shape. // that requires two addScale op, arguably less performant TRT_ShapedWeights combined_scale_weights = @@ -2986,14 +3464,9 @@ tensorflow::Status ConvertMatMulHelper(OpConverterParams* params, tensorflow::Status ConvertMatMul(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; - if (inputs.size() != 2 || !inputs.at(0).is_tensor() || - !inputs.at(1).is_weights()) { - return errors::InvalidArgument("Input expects tensor and weights, at ", - node_def.name()); - } + TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"a", false}, {"b", true}})); TFAttrs attrs(node_def); - // TODO(jie): INT32 should be converted? 
tensorflow::DataType tf_dtype = attrs.get("T"); if (tf_dtype != DataType::DT_FLOAT && tf_dtype != DataType::DT_HALF) { return errors::Unimplemented("Data type is not supported, for node ", @@ -3017,9 +3490,16 @@ tensorflow::Status ConvertMatMul(OpConverterParams* params) { tensorflow::Status ConvertBatchMatMul(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; + // TODO(tmorris): Enable once false is updated to mean either tensor or weight + // TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false}, {"y", + // false}})); + if (inputs.size() != 2) { + return tensorflow::errors::InvalidArgument( + node_def.op(), " got ", inputs.size(), " inputs but expected 2, at ", + node_def.name()); + } TFAttrs attrs(node_def); - // TODO(jie): INT32 should be converted? tensorflow::DataType tf_dtype = attrs.get("T"); if (tf_dtype != tensorflow::DataType::DT_FLOAT && tf_dtype != tensorflow::DataType::DT_HALF) { @@ -3089,6 +3569,7 @@ tensorflow::Status ConvertBatchMatMul(OpConverterParams* params) { tensorflow::Status ConvertSoftmax(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; + TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"logits", false}})); const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); int nbDims = tensor->getDimensions().nbDims; @@ -3097,6 +3578,8 @@ tensorflow::Status ConvertSoftmax(OpConverterParams* params) { "TensorRT Softmax cannot apply on batch dimension, at" + node_def.name()); } + if (params->validation_only) return Status::OK(); + nvinfer1::ISoftMaxLayer* layer = params->converter->network()->addSoftMax( *const_cast(tensor)); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); @@ -3112,31 +3595,36 @@ tensorflow::Status ConvertSoftmax(OpConverterParams* params) { tensorflow::Status ConvertTopK(OpConverterParams* params) { const auto& inputs = params->inputs; + if (inputs.size() != 2 || !inputs.at(0).is_tensor() || + 
!inputs.at(1).is_weights()) { + return errors::InvalidArgument("Input expects tensor and weights, at ", + params->node_def.name()); + } + const auto& node_def = params->node_def; + TF_RETURN_IF_ERROR( + CheckInputsWeights(*params, {{"input", false}, {"k", true}})); const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); - - int nbDims = tensor->getDimensions().nbDims; - if (nbDims == 0) { - return tensorflow::errors::InvalidArgument( - "TensorRT TopK cannot apply on batch dimension, at" + node_def.name()); + const int num_dims = tensor->getDimensions().nbDims; + if (num_dims == 0) { + return errors::InvalidArgument( + "TensorRT TopK cannot apply on batch dimension, at", node_def.name()); } TRT_ShapedWeights k_w = inputs.at(1).weights(); - int k = *(static_cast(const_cast(k_w.GetValues()))); - - nvinfer1::TopKOperation op; - uint32_t reducedAxes = 0; - if (node_def.op() == "TopKV2") { - op = nvinfer1::TopKOperation::kMAX; - reducedAxes |= 1 << (nbDims - 1); - } else { - return tensorflow::errors::Unimplemented( - "Operation: " + node_def.op() + - " not implemented, at: " + node_def.name()); + if (k_w.count() != 1) { + return errors::InvalidArgument("k value of TopK should be a scalar, at", + node_def.name()); } + // Note that ITopKLayer always have sorted outputs, so we don't need to handle + // the 'sorted' attribute of the node. 
+ if (params->validation_only) return Status::OK(); + const nvinfer1::TopKOperation op = nvinfer1::TopKOperation::kMAX; + const int k = *(static_cast(const_cast(k_w.GetValues()))); + const uint32_t reduce_axes = 1 << (num_dims - 1); nvinfer1::ITopKLayer* layer = params->converter->network()->addTopK( - *const_cast(tensor), op, k, reducedAxes); + *const_cast(tensor), op, k, reduce_axes); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); nvinfer1::ITensor* output_value_tensor = layer->getOutput(0); @@ -3150,12 +3638,22 @@ static void RegisterValidatableOpConverters( std::unordered_map* registration) { // TODO(laigd): support all op types. (*registration)["BiasAdd"] = ConvertBiasAdd; + (*registration)["ConcatV2"] = ConvertConcat; (*registration)["Const"] = ConvertConst; - (*registration)["Transpose"] = ConvertTranspose; - (*registration)["Reshape"] = ConvertReshape; + (*registration)["Conv2D"] = ConvertConv2D; + (*registration)["Conv2DBackpropInput"] = ConvertConv2DBackpropInput; + (*registration)["DepthwiseConv2dNative"] = ConvertConv2DDepthwise; + (*registration)["ExpandDims"] = ConvertExpandDims; + (*registration)["LeakyRelu"] = ConvertLeakyRelu; (*registration)["MatMul"] = ConvertMatMul; + (*registration)["Pad"] = ConvertPad; (*registration)["Relu6"] = ConvertRelu6; + (*registration)["Reshape"] = ConvertReshape; (*registration)["Square"] = ConvertSquare; + (*registration)["Squeeze"] = ConvertSqueeze; + (*registration)["StridedSlice"] = ConvertStridedSlice; + (*registration)["Transpose"] = ConvertTranspose; + (*registration)["TopKV2"] = ConvertTopK; for (auto quantization_op_type : {"QuantizeAndDequantizeV2", "QuantizeAndDequantizeV3", @@ -3169,6 +3667,12 @@ static void RegisterValidatableOpConverters( for (auto activation_op_type : {"Relu", "Sigmoid", "Tanh"}) { (*registration)[activation_op_type] = ConvertActivation; } + for (auto pool_op_type : {"AvgPool", "MaxPool"}) { + (*registration)[pool_op_type] = ConvertPool; + } + for (auto normalization_op_type : 
{"FusedBatchNorm", "FusedBatchNormV2"}) { + (*registration)[normalization_op_type] = ConvertFusedBatchNorm; + } } void TrtNodeValidator::RegisterOpValidators() { @@ -3177,21 +3681,10 @@ void TrtNodeValidator::RegisterOpValidators() { void Converter::RegisterOpConverters() { RegisterValidatableOpConverters(&op_registry_); - - op_registry_["Conv2D"] = ConvertConv2D; - op_registry_["DepthwiseConv2dNative"] = ConvertConv2DDepthwise; - op_registry_["MaxPool"] = ConvertPool; - op_registry_["AvgPool"] = ConvertPool; // TODO(ben,jie): this is a temp hack. op_registry_["Identity"] = ConvertIdentity; // Identity should be removed op_registry_["Snapshot"] = ConvertIdentity; // Snapshot should be removed - op_registry_["Pad"] = ConvertPad; - - op_registry_["ConcatV2"] = ConvertConcat; - op_registry_["FusedBatchNorm"] = ConvertFusedBatchNorm; - op_registry_["FusedBatchNormV2"] = ConvertFusedBatchNorm; - op_registry_["Rsqrt"] = ConvertUnary; op_registry_["Reciprocal"] = ConvertUnary; op_registry_["Exp"] = ConvertUnary; @@ -3207,14 +3700,13 @@ void Converter::RegisterOpConverters() { op_registry_["Mean"] = ConvertReduce; op_registry_["Softmax"] = ConvertSoftmax; op_registry_["BatchMatMul"] = ConvertBatchMatMul; - op_registry_["TopKV2"] = ConvertTopK; plugin_converter_ = ConvertPlugin; } tensorflow::Status ConvertGraphDefToEngine( - const tensorflow::GraphDef& gdef, int precision_mode, int max_batch_size, - size_t max_workspace_size_bytes, + const tensorflow::GraphDef& gdef, TrtPrecisionMode precision_mode, + int max_batch_size, size_t max_workspace_size_bytes, const std::vector& input_shapes, Logger* logger, nvinfer1::IGpuAllocator* allocator, TRTInt8Calibrator* calibrator, @@ -3229,9 +3721,9 @@ tensorflow::Status ConvertGraphDefToEngine( builder->setMaxBatchSize(max_batch_size); builder->setMaxWorkspaceSize(max_workspace_size_bytes); builder->setGpuAllocator(allocator); - if (precision_mode == FP16MODE) { + if (precision_mode == TrtPrecisionMode::FP16) { 
builder->setHalf2Mode(true); - } else if (precision_mode == INT8MODE) { + } else if (precision_mode == TrtPrecisionMode::INT8) { builder->setInt8Mode(true); if (use_calibration) { builder->setInt8Calibrator(calibrator); @@ -3251,15 +3743,14 @@ tensorflow::Status ConvertGraphDefToEngine( // Build the network VLOG(1) << "Starting engine conversion "; Converter converter(trt_network.get(), precision_mode, use_calibration); - std::vector> output_tensors; + std::vector output_tensors; // Graph nodes are already topologically sorted during construction for (const auto& node_def : gdef.node()) { string node_name = node_def.name(); VLOG(2) << "Converting op name=" << node_name << ", op=" << node_def.op(); - if (tensorflow::str_util::StartsWith(node_name, kInputPHName) && - (node_def.op() == "Placeholder")) { + if (IsEngineInput(node_name) && (node_def.op() == "Placeholder")) { int32 slot_number = -1; - if (!tensorflow::strings::safe_strto32( + if (!tensorflow::strings::safe_strto32( // non-absl ok node_name.c_str() + strlen(kInputPHName), &slot_number)) { return tensorflow::errors::InvalidArgument( "Failed to parse slot number from ", node_name); @@ -3285,18 +3776,23 @@ tensorflow::Status ConvertGraphDefToEngine( // engines offline, by calling sess.run() and cache/serialize the engines. 
TF_RETURN_IF_ERROR( converter.AddInputTensor(node_name, trt_dtype, trt_dims, batch_size)); - } else if (tensorflow::str_util::StartsWith(node_name, kOutputPHName) && - (node_def.op() == "Identity")) { + } else if (IsEngineOutput(node_name) && (node_def.op() == "Identity")) { int32 slot_number = -1; - if (!tensorflow::strings::safe_strto32( + if (!tensorflow::strings::safe_strto32( // non-absl ok node_name.c_str() + strlen(kOutputPHName), &slot_number)) { return tensorflow::errors::InvalidArgument( "Failed to parse slot number from ", node_name); } + // Get output type that TensorFlow expects + TFAttrs attrs(node_def); + tensorflow::DataType tf_dtype = attrs.get("T"); + nvinfer1::DataType trt_dtype; + TF_RETURN_IF_ERROR(ConvertDType(tf_dtype, &trt_dtype)); if (output_tensors.size() <= slot_number) { output_tensors.resize(slot_number + 1); } - output_tensors.at(slot_number) = {node_def.input(0), node_name}; + output_tensors.at(slot_number) = {node_def.input(0), node_name, + trt_dtype}; } else { VLOG(2) << "Converting node: " << node_def.name() << " , " << node_def.op(); @@ -3322,8 +3818,7 @@ tensorflow::Status ConvertGraphDefToEngine( tensorflow::Status ConvertSegmentToGraphDef( const tensorflow::Graph* graph, const tensorflow::grappler::GraphProperties& graph_properties, - const std::set& subgraph_node_names, - const std::vector& subgraph_node_ids, // In topological order + const std::vector& subgraph_nodes, // In topological order std::vector* connections, tensorflow::GraphDef* segment_def, string* common_scope) { std::set marker_nodes; @@ -3386,8 +3881,10 @@ tensorflow::Status ConvertSegmentToGraphDef( marker_nodes.insert(node_name); auto seg_node = segment_def->add_node(); tensorflow::NodeDefBuilder builder(node_name, "Identity"); - auto status = builder.Input(connection.inside_node_name, 0, dtype) - .Finalize(seg_node); + auto status = + builder + .Input(connection.inside_node_name, connection.inside_port, dtype) + .Finalize(seg_node); VLOG(1) << "Constructing 
output " << node_name << " for the edge " << connection.inside_node_name << ":" << connection.inside_port << " -> " << connection.outside_node_name << ":" @@ -3397,11 +3894,10 @@ tensorflow::Status ConvertSegmentToGraphDef( std::unordered_map old_to_new_id_map; // Copy internal nodes to new graphdef - string local_scope = graph->FindNodeId(*subgraph_node_ids.begin())->name(); - for (const auto node_id : subgraph_node_ids) { - const auto node = graph->FindNodeId(node_id); + string local_scope = subgraph_nodes.front()->name(); + for (const Node* node : subgraph_nodes) { local_scope = GetCommonNameScope(local_scope, node->name()); - old_to_new_id_map[node_id] = segment_def->node_size(); + old_to_new_id_map[node->id()] = segment_def->node_size(); auto snode = segment_def->add_node(); snode->CopyFrom(node->def()); VLOG(2) << "Copying " << snode->name() << " to subgraph"; @@ -3419,6 +3915,11 @@ tensorflow::Status ConvertSegmentToGraphDef( << placeholder_name; snode->set_input(connection.inside_port, placeholder_name); } + std::set subgraph_node_names; + for (const Node* node : subgraph_nodes) { + subgraph_node_names.insert(node->name()); + } + // Remove control inputs that are not inside the segment. for (int i = 0; i < segment_def->node_size(); ++i) { auto snode = segment_def->mutable_node(i); @@ -3429,7 +3930,7 @@ tensorflow::Status ConvertSegmentToGraphDef( TensorId input = ParseTensorName(snode->input(input_idx)); if (!subgraph_node_names.count( string(input.first.data(), input.first.size())) && - !str_util::StartsWith(input.first, kInputPHName)) { + !IsEngineInput(input.first)) { if (input.second == Graph::kControlSlot) { VLOG(1) << "... 
removing control inputs " << input.first << " from subgraph."; diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.h b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h similarity index 88% rename from tensorflow/contrib/tensorrt/convert/convert_nodes.h rename to tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h index 54e19b73957bccdae2b23bd3556de9ad00b864e5..d1e30eb848bd6ab62719ca6da561d14b05d8537d 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.h +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_NODES_H_ -#define TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_NODES_H_ +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_CONVERT_NODES_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_CONVERT_NODES_H_ #include #include @@ -22,11 +22,11 @@ limitations under the License. 
#include #include -#include "tensorflow/contrib/tensorrt/convert/utils.h" -#include "tensorflow/contrib/tensorrt/log/trt_logger.h" -#include "tensorflow/contrib/tensorrt/resources/trt_allocator.h" -#include "tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h" -#include "tensorflow/contrib/tensorrt/resources/trt_resources.h" +#include "tensorflow/compiler/tf2tensorrt/convert/utils.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/grappler/costs/graph_properties.h" @@ -92,7 +92,7 @@ struct EngineInfo { EngineInfo() : engine_type(EngineType::TRTStatic), max_workspace_size_bytes(0), - precision_mode(FP32MODE), + precision_mode(TrtPrecisionMode::FP32), use_calibration(true) {} string engine_name; @@ -109,7 +109,7 @@ struct EngineInfo { int64 max_workspace_size_bytes; int maximum_cached_engines; std::vector cached_engine_batches; - int precision_mode; + TrtPrecisionMode precision_mode; bool use_calibration; }; @@ -128,8 +128,7 @@ struct EngineInfo { tensorflow::Status ConvertSegmentToGraphDef( const tensorflow::Graph* graph, const tensorflow::grappler::GraphProperties& graph_properties, - const std::set& subgraph_node_names, - const std::vector& subgraph_node_ids, + const std::vector& subgraph_nodes, std::vector* connections, tensorflow::GraphDef* segment_def, string* common_scope); @@ -142,8 +141,8 @@ tensorflow::Status ConvertSegmentToGraphDef( // is successful. This is different than successfully building the engine: // building can still fail afterwards. 
tensorflow::Status ConvertGraphDefToEngine( - const tensorflow::GraphDef& gdef, int precision_mode, int max_batch_size, - size_t max_workspace_size_bytes, + const tensorflow::GraphDef& gdef, TrtPrecisionMode precision_mode, + int max_batch_size, size_t max_workspace_size_bytes, const std::vector& input_shapes, Logger* logger, nvinfer1::IGpuAllocator* allocator, TRTInt8Calibrator* calibrator, @@ -159,7 +158,10 @@ class OutputEdgeValidator { bool operator()(const tensorflow::Edge* out_edge) const; }; +string DebugString(const nvinfer1::DimensionType type); +string DebugString(const nvinfer1::DataType trt_dtype); string DebugString(const nvinfer1::Dims& dims); +string DebugString(const nvinfer1::Permutation& permutation, int len); string DebugString(const nvinfer1::ITensor& tensor); int64_t TrtDimsNumElements(const nvinfer1::Dims& dims); @@ -176,6 +178,8 @@ class TRT_ShapedWeights { nvinfer1::Weights GetTrtWeights() const; + // Returns the raw pointer to the underlying buffer which holds the weights + // value. void* GetValues() const { return const_cast(tensor_.tensor_data().data()); } @@ -195,6 +199,10 @@ class TRT_ShapedWeights { // underlying buffer. TRT_ShapedWeights(DataType type, nvinfer1::Dims dims, Tensor tensor); + // All weights should be stored inside TrtWeightStore to make sure lifetime of + // all the underlying tensors are available until the engine is built. For + // this reason, tensor_ should never be reassigned to a different value that + // is not already present in the TrtWeightStore. Tensor tensor_; friend class TrtWeightStore; @@ -394,8 +402,21 @@ class TrtNodeValidator { // Class to convert TF nodes to TRT network. class Converter { public: - Converter(nvinfer1::INetworkDefinition* trt_network, int precision_mode, - bool use_calibration); + // Used for Converter::RenameAndMarkOutputTensors() + struct EngineOutputInfo { + // The TRT tensor name which produces the output. 
+ string source_tensor_name; + // The TensorFlow node name which is receiving the output from the TRT + // engine. This should always be the Identity node created in + // ConvertSegmentToGraphDef. + string dest_node_name; + // Output type. TensorRT requires this to be explicitly set for engine + // outputs. + nvinfer1::DataType trt_dtype; + }; + + Converter(nvinfer1::INetworkDefinition* trt_network, + TrtPrecisionMode precision_mode, bool use_calibration); ////////////////////////////////////////////////////////////////////////////// // Methods used by the TRT engine builder to build a TRT network from a TF @@ -409,13 +430,10 @@ class Converter { Status AddInputTensor(const string& name, nvinfer1::DataType dtype, const nvinfer1::Dims& dims, int batch_size); - // Mark the tensors with names specified by output_tensors[i].first as output - // of the TRT network, and set their names in the TRT network as - // output_tensors[i].second. The tensor names (output_tensors[i].first) are - // standard TF tensor names, i.e. node names followed by output slot number - // (or just the node name if the tensor is the first output of the node). + // Mark the tensors with names specified by source_tensor_name as output of + // the TRT network, and set their names in the TRT network as dest_node_name. Status RenameAndMarkOutputTensors( - const std::vector>& output_tensors); + const std::vector& output_tensors); ////////////////////////////////////////////////////////////////////////////// // Methods used by op converters to convert individual TF node and add layers @@ -426,7 +444,7 @@ class Converter { nvinfer1::INetworkDefinition* network() { return trt_network_; } // What precision are we targeting? - int precision_mode() const { return precision_mode_; } + TrtPrecisionMode precision_mode() const { return precision_mode_; } // Calibration will be or was previously performed on this network? 
bool use_calibration() const { return use_calibration_; } @@ -469,6 +487,11 @@ class Converter { nvinfer1::Dims* operand_l_new_dims, nvinfer1::Dims* operand_r_new_dims) const; + // Creates an IConstantLayer using 'weights' whose dimensions are specified by + // 'dims', and returns the output ITensor. + nvinfer1::ITensor* CreateConstantLayer(const TRT_ShapedWeights& weights, + const nvinfer1::Dims& dims); + private: // Verify the provided batch_size is consistent with batch_size_ and update it // if necessary. @@ -523,7 +546,7 @@ class Converter { std::vector> quantization_infer_; - const int precision_mode_; + const TrtPrecisionMode precision_mode_; const bool use_calibration_; @@ -544,4 +567,4 @@ class Converter { #endif // GOOGLE_TENSORRT #endif // GOOGLE_CUDA -#endif // TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_NODES_H_ +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_CONVERT_NODES_H_ diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc similarity index 59% rename from tensorflow/contrib/tensorrt/convert/convert_nodes_test.cc rename to tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc index 443033379f0d6554784d44412a02aa8cb035ab08..77221f6d9a42a165e8f9e322e1f876b02f4db59f 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes_test.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h" +#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h" #include #include @@ -21,11 +21,16 @@ limitations under the License. 
#include #include +#include "absl/strings/match.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/scope.h" #include "tensorflow/cc/ops/standard_ops.h" -#include "tensorflow/contrib/tensorrt/log/trt_logger.h" -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" +#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" #include "tensorflow/core/framework/node_def.pb.h" // NOLINT #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.pb.h" // NOLINT @@ -35,7 +40,9 @@ limitations under the License. #include "tensorflow/core/grappler/costs/graph_properties.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/protobuf/config.pb.h" // NOLINT #include "tensorflow/core/public/session.h" @@ -50,7 +57,7 @@ namespace tensorflow { namespace tensorrt { namespace convert { -using ::tensorflow::strings::StrCat; +using absl::StrCat; using ::testing::ElementsAre; using ::testing::ElementsAreArray; @@ -152,7 +159,7 @@ void ExpectTrtDimsEqualsArray(const std::vector& lhs, } template -void ExpectArrayNear(const std::vector& lhs, const std::vector& rhs) { +void ExpectArrayNear(const std::vector& lhs, absl::Span rhs) { ASSERT_EQ(lhs.size(), rhs.size()); for (int i = 0; i < lhs.size(); i++) { EXPECT_FLOAT_EQ(lhs[i], rhs[i]); @@ -163,7 +170,7 @@ void ExpectArrayNear(const std::vector& lhs, const std::vector& rhs) { // EXPECT_FLOAT_EQ. 
template <> void ExpectArrayNear(const std::vector& lhs, - const std::vector& rhs) { + absl::Span rhs) { ASSERT_EQ(lhs.size(), rhs.size()); for (int i = 0; i < lhs.size(); i++) { EXPECT_FLOAT_EQ(Eigen::half_impl::half_to_float(lhs[i]), @@ -364,9 +371,6 @@ TEST(TRT_TensorOrWeights_Test, Basic) { EXPECT_EQ(false, ptr->is_tensor()); EXPECT_EQ(true, ptr->is_weights()); EXPECT_TRUE(TrtShapedWeightsEquals(weights, ptr->weights())); - - nvinfer1::Dims dims; - dims.nbDims = 0; ExpectTrtDimsEqualsArray({}, ptr->GetTrtDims()); } } @@ -481,8 +485,7 @@ class ConverterTest : public ::testing::Test { ConverterTest() { builder_.reset(nvinfer1::createInferBuilder(logger_)); network_.reset(builder_->createNetwork()); - converter_.reset(new Converter(network_.get(), - /*precision_mode=*/FP32MODE, + converter_.reset(new Converter(network_.get(), TrtPrecisionMode::FP32, /*use_calibration=*/false)); weight_store_ = &converter_->weight_store_; } @@ -784,7 +787,7 @@ TEST_F(ConverterTest, MaybeApplyQuantizationRanges) { // input -> infer1 -> infer2 -> infer3 FakeITensor input, infer_1, infer_2, infer_3; FakeITensor not_infer; - Converter int8_converter(/*trt_network=*/nullptr, INT8MODE, + Converter int8_converter(/*trt_network=*/nullptr, TrtPrecisionMode::INT8, /*use_calibration=*/true); int8_converter.ProvideQuantizationRange(&input, -5.0f, 5.0f); int8_converter.ProvideQuantizationRange(¬_infer, -100.0f, 100.0f); @@ -915,6 +918,97 @@ TEST_F(ConverterTest, GetTrtBroadcastShape) { "(tensor #dims 4 vs broadcast #dims 5)"); } +TEST_F(ConverterTest, CreateConstantLayer) { + for (auto dtype : {DT_FLOAT, DT_INT32}) { + TRT_ShapedWeights weights = + weight_store_->GetTempWeights(dtype, GetTestDims({2, 3, 5})); + nvinfer1::ITensor* tensor = + converter_->CreateConstantLayer(weights, GetTestDims({3, 10})); + ASSERT_NE(nullptr, tensor); + EXPECT_EQ(TfDataTypeToTrt(dtype), tensor->getType()) + << "Expected " << DebugString(TfDataTypeToTrt(dtype)) << " vs. 
actual " + << DebugString(tensor->getType()); + ExpectTrtDimsEqualsArray({3, 10}, tensor->getDimensions()); + } +} + +class ConvertGraphDefToEngineTest : public ::testing::Test { + public: + Status RunConvertGraphDefToEngine(Scope* s) { + GraphDef gdef; + TF_EXPECT_OK(s->ToGraphDef(&gdef)); + std::vector input_shapes; + int batch_size = -1; + for (const NodeDef& node : gdef.node()) { + absl::string_view node_name(node.name()); + if (str_util::ConsumePrefix(&node_name, kInputPHName)) { + int port = -1; + EXPECT_TRUE(absl::SimpleAtoi(node_name, &port)) << node.name(); + if (input_shapes.size() < port + 1) input_shapes.resize(port + 1); + input_shapes[port] = + PartialTensorShape(node.attr().at("shape").shape()); + if (batch_size == -1) { + batch_size = input_shapes[port].dim_size(0); + } else { + EXPECT_EQ(batch_size, input_shapes[port].dim_size(0)); + } + } + } + // TODO(laigd): execute the engine and get outputs. + return ConvertGraphDefToEngine( + gdef, TrtPrecisionMode::FP32, /*max_batch_size=*/1, + /*max_workspace_size_bytes=*/64 << 20, input_shapes, &logger_, + /*allocator=*/nullptr, /*calibrator=*/nullptr, &engine_, + /*use_calibration=*/false, /*convert_successfully=*/nullptr); + } + + protected: + TrtUniquePtrType engine_; + + private: + Logger logger_; +}; + +TEST_F(ConvertGraphDefToEngineTest, IdentityGraph) { + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName(StrCat(kInputPHName, 0)), DT_FLOAT, + ops::Placeholder::Shape({1, 1})); + auto output = ops::Identity(s.WithOpName("identity1"), input); + output = ops::Identity(s.WithOpName("identity2"), output); + output = ops::Identity(s.WithOpName(StrCat(kOutputPHName, 0)), output); + // If the converter marks the input tensor as output tensor, the conversion + // below will fail with: + // > TensorRTOutputPH_0 cannot be both input and output + // > Network must have at least one output + TF_EXPECT_OK(RunConvertGraphDefToEngine(&s)); +} + +// Input/output data format for 
OpConverterTest::BuildAndRun(). +struct InputOutputData { + void* Buffer() const { + return const_cast(tensor.tensor_data().data()); + } + + size_t TotalBytes() const { return tensor.TotalBytes(); } + + const char* name; + Tensor tensor; +}; + +template +Tensor ConstructTensor(int data_size, const T& value = T()) { + std::vector values(data_size, value); + return test::AsTensor(values); +} + +using DataVec = std::vector; + +template +inline absl::Span GetSpanForData(const InputOutputData& data) { + const auto& tensor_map = data.tensor.flat(); + return absl::Span(tensor_map.data(), tensor_map.size()); +} + // Class to test various op converters, using both a TrtNodeValidator and // Converter. class OpConverterTest : public ::testing::Test { @@ -940,11 +1034,11 @@ class OpConverterTest : public ::testing::Test { builder_.reset(nvinfer1::createInferBuilder(logger_)); network_.reset(builder_->createNetwork()); builder_->setMaxBatchSize(1); + builder_->setMaxWorkspaceSize(1 << 26); // Reset the validator and converter. validator_.reset(new TrtNodeValidator); - converter_.reset(new Converter(network_.get(), - /*precision_mode=*/FP32MODE, + converter_.reset(new Converter(network_.get(), TrtPrecisionMode::FP32, /*use_calibration=*/false)); // Reset other related artifacts. @@ -953,14 +1047,14 @@ class OpConverterTest : public ::testing::Test { } // TODO(laigd): test fp16 and int8 support. - template - void BuildAndRun( - const std::vector>>& - input_data, - const char* output_name, std::vector* output_data) { + void BuildAndRun(const DataVec& input_data, DataVec* output_data) { // Mark the output tensor as TRT engine output. 
- TF_EXPECT_OK(converter_->RenameAndMarkOutputTensors( - {{string(output_name), string(output_name)}})); + std::vector output_info; + for (const auto& data : *output_data) { + output_info.push_back( + {data.name, data.name, TfDataTypeToTrt(data.tensor.dtype())}); + } + TF_EXPECT_OK(converter_->RenameAndMarkOutputTensors(output_info)); // Build the TRT engine. ASSERT_EQ(nullptr, engine_.get()); @@ -968,31 +1062,44 @@ class OpConverterTest : public ::testing::Test { CHECK_NOTNULL(engine_.get()); // Execute the TRT engine. - ASSERT_LE(input_data.size() + 1, 3); - void* buffers[3]; - for (const auto name_and_data : input_data) { - const int input_size = name_and_data.second.size() * sizeof(T); - const int input_index = engine_->getBindingIndex(name_and_data.first); - ASSERT_EQ(0, cudaMalloc(&buffers[input_index], input_size)); - ASSERT_EQ( - 0, cudaMemcpyAsync(buffers[input_index], name_and_data.second.data(), - input_size, cudaMemcpyHostToDevice, stream_)); + const int num_bindings = input_data.size() + output_data->size(); + std::vector buffers(num_bindings); + + for (const auto& data : input_data) { + const int input_index = engine_->getBindingIndex(data.name); + ASSERT_EQ(0, cudaMalloc(&buffers[input_index], data.TotalBytes())); + ASSERT_EQ(0, cudaMemcpyAsync(buffers[input_index], data.Buffer(), + data.TotalBytes(), cudaMemcpyHostToDevice, + stream_)); + } + struct SizeAndIndex { + SizeAndIndex(int in_size, int in_index) + : size(in_size), index(in_index) {} + int size; + int index; + }; + std::vector output_infos; + for (const auto& data : *output_data) { + const int output_index = engine_->getBindingIndex(data.name); + output_infos.emplace_back(data.TotalBytes(), output_index); + ASSERT_EQ(0, cudaMalloc(&buffers[output_index], data.TotalBytes())); } - const int output_size = output_data->size() * sizeof(T); - const int output_index = engine_->getBindingIndex(output_name); - ASSERT_EQ(0, cudaMalloc(&buffers[output_index], output_size)); - - 
ASSERT_EQ(engine_->getNbBindings(), input_data.size() + 1); - + ASSERT_EQ(engine_->getNbBindings(), num_bindings); TrtUniquePtrType execution_context( engine_->createExecutionContext()); - execution_context->enqueue(/*batchSize=*/1, buffers, stream_, nullptr); - ASSERT_EQ(0, cudaMemcpyAsync(output_data->data(), buffers[output_index], - output_size, cudaMemcpyDeviceToHost, stream_)); + execution_context->enqueue(/*batchSize=*/1, buffers.data(), stream_, + nullptr); + + for (int i = 0; i < output_infos.size(); ++i) { + const auto& output_info = output_infos[i]; + ASSERT_EQ(0, cudaMemcpyAsync(output_data->at(i).Buffer(), + buffers[output_info.index], output_info.size, + cudaMemcpyDeviceToHost, stream_)); + } cudaStreamSynchronize(stream_); - for (int i = 0; i < input_data.size() + 1; ++i) { + for (int i = 0; i < num_bindings; ++i) { ASSERT_EQ(0, cudaFree(buffers[i])); } } @@ -1111,6 +1218,30 @@ class OpConverterTest : public ::testing::Test { std::unordered_map validator_inputs_; }; +template +void CopyTensorElements(const Tensor& tensor, protobuf::RepeatedField* out) { + out->Clear(); + if (tensor.NumElements() == 0) return; + + // TensorProto does not need to have all the elements present and can truncate + // trailing elements with the same value for compressed representation. Such + // elements are derived based on the tensor shape. 
+ const auto flat = tensor.flat(); + int64 last_index = 0; + for (int64 i = 0; i < tensor.NumElements(); ++i) { + if (flat(i) != flat(last_index)) { + last_index = i; + } + } + + int num_out_elements = last_index + 1; + out->Reserve(num_out_elements); + out->AddNAlreadyReserved(num_out_elements); + const T* src = flat.data(); + T* dst = out->mutable_data(); + std::copy(src, src + num_out_elements, dst); +} + template void TestConvertConst(OpConverterTest* test) { NodeDef node_def; @@ -1123,11 +1254,23 @@ void TestConvertConst(OpConverterTest* test) { const std::vector& expected_value) { test->Reset(); - auto& attr = *node_def.mutable_attr(); + TensorProto* tensor_attr = + (*node_def.mutable_attr())["value"].mutable_tensor(); + tensor_attr->Clear(); + if (as_tensor_content) { - tensor.AsProtoTensorContent(attr["value"].mutable_tensor()); + tensor.AsProtoTensorContent(tensor_attr); } else { - tensor.AsProtoField(attr["value"].mutable_tensor()); + tensor.shape().AsProto(tensor_attr->mutable_tensor_shape()); + tensor_attr->set_dtype(tensor.dtype()); + + if (tensor.dtype() == DT_FLOAT) { + CopyTensorElements(tensor, tensor_attr->mutable_float_val()); + } else if (tensor.dtype() == DT_INT32) { + CopyTensorElements(tensor, tensor_attr->mutable_int_val()); + } else { + tensor.AsProtoField(tensor_attr); + } } test->RunValidationAndConversion(node_def); TRT_TensorOrWeights output; @@ -1140,8 +1283,7 @@ void TestConvertConst(OpConverterTest* test) { { // By default empty tensor will pick DT_FLOAT as data type and we fix it // here. - attr["value"].mutable_tensor()->set_dtype(dtype); - Tensor t; // Empty tensor. + Tensor t(dtype); // Empty tensor. reset_and_test(t, false, {}, {}); } { @@ -1160,6 +1302,22 @@ void TestConvertConst(OpConverterTest* test) { reset_and_test(t, false, {2, 3}, {1, 2, 3, 4, 5, 6}); reset_and_test(t, true, {2, 3}, {1, 2, 3, 4, 5, 6}); } + { + // Set all tensor elements to the same value. 
Such tensors are encoded + // using a single element list in tensor proto. + Tensor t = ::tensorflow::test::AsTensor({1, 1, 1, 1, 1, 1}, + TensorShape({2, 3})); + reset_and_test(t, false, {2, 3}, {1, 1, 1, 1, 1, 1}); + reset_and_test(t, true, {2, 3}, {1, 1, 1, 1, 1, 1}); + } + { + // Set trailing tensor elements to the same value. Such tensors are + // encoded by truncating all equal elements except the first one. + Tensor t = ::tensorflow::test::AsTensor({2, 2, 1, 1, 1, 1}, + TensorShape({2, 3})); + reset_and_test(t, false, {2, 3}, {2, 2, 1, 1, 1, 1}); + reset_and_test(t, true, {2, 3}, {2, 2, 1, 1, 1, 1}); + } } TEST_F(OpConverterTest, ConvertConst) { @@ -1189,7 +1347,7 @@ TEST_F(OpConverterTest, ConvertTranspose) { NodeDef node_def = MakeNodeDef("my_transpose", "Transpose", {}); RunValidationAndConversion( node_def, error::INVALID_ARGUMENT, - "Input expects tensor and weights, at my_transpose"); + "Transpose got 0 inputs but expected 2, at my_transpose"); } // Get the NodeDef for Transpose. @@ -1205,8 +1363,8 @@ TEST_F(OpConverterTest, ConvertTranspose) { AddTestTensor("input", {1, 2, 3}); AddTestTensor("weights", {3}); RunValidationAndConversion( - node_def, error::INVALID_ARGUMENT, - "Input expects tensor and weights, at my_transpose"); + node_def, error::UNIMPLEMENTED, + "The input \"perm\" for Transpose must be a constant, at my_transpose"); } { // Transpose at batch dimension, should fail. 
@@ -1236,10 +1394,12 @@ TEST_F(OpConverterTest, ConvertTranspose) { EXPECT_TRUE(output.is_tensor()); ExpectTrtDimsEqualsArray({3, 1, 2}, output.tensor()->getDimensions()); - std::vector output_data(6); - BuildAndRun({{"input", {1, 2, 3, 4, 5, 6}}}, "my_transpose", - &output_data); - EXPECT_THAT(output_data, ElementsAre(1, 4, 2, 5, 3, 6)); + const DataVec input_data{ + {"input", test::AsTensor({1, 2, 3, 4, 5, 6})}}; + DataVec output_data{{"my_transpose", ConstructTensor(6)}}; + BuildAndRun(input_data, &output_data); + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAre(1, 4, 2, 5, 3, 6)); } } @@ -1249,7 +1409,7 @@ TEST_F(OpConverterTest, ConvertReshape) { NodeDef node_def = MakeNodeDef("my_reshape", "Reshape", {}); RunValidationAndConversion( node_def, error::INVALID_ARGUMENT, - "Input expects weights for shape, at my_reshape"); + "Reshape got 0 inputs but expected 2, at my_reshape"); } // Get the NodeDef for Reshape. @@ -1265,8 +1425,8 @@ TEST_F(OpConverterTest, ConvertReshape) { AddTestTensor("input", {1, 2, 3}); AddTestTensor("weights", {3}); RunValidationAndConversion( - node_def, error::INVALID_ARGUMENT, - "Input expects weights for shape, at my_reshape"); + node_def, error::UNIMPLEMENTED, + "The input \"shape\" for Reshape must be a constant, at my_reshape"); } { // Reshape to scalar, should fail. 
@@ -1326,10 +1486,12 @@ TEST_F(OpConverterTest, ConvertReshape) { EXPECT_TRUE(output.is_tensor()); ExpectTrtDimsEqualsArray({1, 3, 2}, output.tensor()->getDimensions()); - std::vector output_data(6); - BuildAndRun({{"input", {1, 2, 3, 4, 5, 6}}}, "my_reshape", - &output_data); - EXPECT_THAT(output_data, ElementsAre(1, 2, 3, 4, 5, 6)); + const DataVec input_data{ + {"input", test::AsTensor({1, 2, 3, 4, 5, 6})}}; + DataVec output_data{{"my_reshape", ConstructTensor(6)}}; + BuildAndRun(input_data, &output_data); + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAre(1, 2, 3, 4, 5, 6)); } } @@ -1339,7 +1501,7 @@ TEST_F(OpConverterTest, ConvertMatMul) { NodeDef node_def = MakeNodeDef("my_matmul", "MatMul", {}); RunValidationAndConversion( node_def, error::INVALID_ARGUMENT, - "Input expects tensor and weights, at my_matmul"); + "MatMul got 0 inputs but expected 2, at my_matmul"); } // Get the NodeDef for MatMul. @@ -1389,12 +1551,13 @@ TEST_F(OpConverterTest, ConvertMatMul) { EXPECT_TRUE(output.is_tensor()); ExpectTrtDimsEqualsArray({2}, output.tensor()->getDimensions()); - std::vector output_data(2); - BuildAndRun({{"input", {0, 1}}}, "my_matmul", &output_data); + const DataVec input_data{{"input", test::AsTensor({0, 1})}}; + DataVec output_data{{"my_matmul", ConstructTensor(2)}}; + BuildAndRun(input_data, &output_data); if (transpose_b) { - EXPECT_THAT(output_data, ElementsAre(1, 3)); + EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAre(1, 3)); } else { - EXPECT_THAT(output_data, ElementsAre(2, 3)); + EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAre(2, 3)); } } } @@ -1448,23 +1611,28 @@ void TestConvertBiasAdd(OpConverterTest* test) { const int num_input = TrtDimsNumElements(GetTestDims(dims_array)); ASSERT_EQ(trt_input_rank > 1 ? 6 : (data_format == "NHWC" ? 
3 : 2), num_input); - std::vector output_data(num_input); - test->BuildAndRun( - {{"input", std::vector(num_input, CType(0))}}, "my_biasadd", - &output_data); + + const DataVec input_data{ + {"input", ConstructTensor(num_input, CType(0))}}; + DataVec output_data{{"my_biasadd", ConstructTensor(num_input)}}; + test->BuildAndRun(input_data, &output_data); if (trt_input_rank == 1) { if (data_format == "NHWC") { - EXPECT_THAT(output_data, ElementsAre(CType(1), CType(2), CType(3))); + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAre(CType(1), CType(2), CType(3))); } else { - EXPECT_THAT(output_data, ElementsAre(CType(1), CType(2))); + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAre(CType(1), CType(2))); } } else { if (data_format == "NHWC") { - EXPECT_THAT(output_data, ElementsAre(CType(1), CType(2), CType(3), - CType(1), CType(2), CType(3))); + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAre(CType(1), CType(2), CType(3), CType(1), + CType(2), CType(3))); } else { - EXPECT_THAT(output_data, ElementsAre(CType(1), CType(1), CType(1), - CType(2), CType(2), CType(2))); + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAre(CType(1), CType(1), CType(1), CType(2), + CType(2), CType(2))); } } } @@ -1477,7 +1645,7 @@ TEST_F(OpConverterTest, ConvertBiasAdd) { NodeDef node_def = MakeNodeDef("my_biasadd", "BiasAdd", {}); RunValidationAndConversion( node_def, error::INVALID_ARGUMENT, - "Input expects tensor and weights, at my_biasadd"); + "BiasAdd got 0 inputs but expected 2, at my_biasadd"); } // OK. Note that kINT32 is not supported by IScaleLayer, so we don't test @@ -1542,21 +1710,25 @@ void TestBinaryTensorOpWeightNoBroadcast(OpConverterTest* test) { EXPECT_TRUE(output.is_tensor()); ExpectTrtDimsEqualsArray({1, 1, 2}, output.tensor()->getDimensions()); - std::vector output_data(2); - test->BuildAndRun( - {{"input", - /*input_data=*/swap_inputs ? 
operand2 : operand1}}, - "my_binary", &output_data); + const DataVec input_data{ + {"input", test::AsTensor(swap_inputs ? operand2 : operand1)}}; + DataVec output_data{{"my_binary", ConstructTensor(2)}}; + test->BuildAndRun(input_data, &output_data); if (node_def.op() == "Add") { - EXPECT_THAT(output_data, ElementsAre(CType(5), CType(10.5))); + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAre(CType(5), CType(10.5))); } else if (node_def.op() == "Sub") { - EXPECT_THAT(output_data, ElementsAre(CType(1), CType(4.5))); + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAre(CType(1), CType(4.5))); } else if (node_def.op() == "Mul") { - EXPECT_THAT(output_data, ElementsAre(CType(6), CType(22.5))); + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAre(CType(6), CType(22.5))); } else if (node_def.op() == "Div") { - EXPECT_THAT(output_data, ElementsAre(CType(1.5), CType(2.5))); + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAre(CType(1.5), CType(2.5))); } else if (node_def.op() == "RealDiv") { - EXPECT_THAT(output_data, ElementsAre(CType(1.5), CType(2.5))); + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAre(CType(1.5), CType(2.5))); } else { ASSERT_TRUE(false); } @@ -1591,13 +1763,14 @@ void TestBinaryTensorOpWeightWithChannelWiseBroadcast(OpConverterTest* test) { EXPECT_TRUE(output.is_tensor()); ExpectTrtDimsEqualsArray({2, 1, 2}, output.tensor()->getDimensions()); - std::vector output_data(4); - test->BuildAndRun({{"input", input}}, "my_binary", &output_data); + const DataVec input_data{{"input", test::AsTensor(input)}}; + DataVec output_data{{"my_binary", ConstructTensor(4)}}; + test->BuildAndRun(input_data, &output_data); if (weights_dims.size() == 1) { - EXPECT_THAT(output_data, + EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAre(CType(11), CType(22), CType(13), CType(24))); } else { - EXPECT_THAT(output_data, + EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAre(CType(11), CType(12), CType(23), CType(24))); } } @@ 
-1625,9 +1798,10 @@ void TestBinaryTensorOpWeightWithUniformlyBroadcast(OpConverterTest* test) { EXPECT_TRUE(output.is_tensor()); ExpectTrtDimsEqualsArray({2, 1, 2}, output.tensor()->getDimensions()); - std::vector output_data(4); - test->BuildAndRun({{"input", input}}, "my_binary", &output_data); - EXPECT_THAT(output_data, + const DataVec input_data{{"input", test::AsTensor(input)}}; + DataVec output_data{{"my_binary", ConstructTensor(4)}}; + test->BuildAndRun(input_data, &output_data); + EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAre(CType(11), CType(12), CType(13), CType(14))); } @@ -1675,17 +1849,19 @@ void TestBinaryTensorOpWeightFallback(OpConverterTest* test, // Check the result of running the engine. const int expected_num_outputs = TrtDimsNumElements(GetTestDims(expected_output_dims)); - std::vector output_data(expected_num_outputs); - test->BuildAndRun( - {{"input", - /*input_data=*/std::vector(num_inputs, CType(2))}}, - "my_binary", &output_data); + const DataVec input_data{ + {"input", ConstructTensor(num_inputs, CType(2))}}; + DataVec output_data{ + {"my_binary", ConstructTensor(expected_num_outputs)}}; + test->BuildAndRun(input_data, &output_data); if (node_def.op() == "Add") { - EXPECT_THAT(output_data, ElementsAreArray(std::vector( - expected_num_outputs, CType(3)))); + EXPECT_THAT( + GetSpanForData(output_data[0]), + ElementsAreArray(std::vector(expected_num_outputs, CType(3)))); } else if (node_def.op() == "Minimum") { - EXPECT_THAT(output_data, ElementsAreArray(std::vector( - expected_num_outputs, CType(1)))); + EXPECT_THAT( + GetSpanForData(output_data[0]), + ElementsAreArray(std::vector(expected_num_outputs, CType(1)))); } else { ASSERT_TRUE(false); } @@ -1712,32 +1888,33 @@ void TestBinaryTensorOpTensor(OpConverterTest* test) { EXPECT_TRUE(output.is_tensor()); ExpectTrtDimsEqualsArray({2, 2}, output.tensor()->getDimensions()); - std::vector output_data(4); + const DataVec input_data{ + {"input1", test::AsTensor({CType(3), CType(6)})}, 
+ {"input2", test::AsTensor({CType(2), CType(3)})}}; + DataVec output_data{{"my_binary", ConstructTensor(4)}}; // After broadcasting first input becomes {3, 6, 3, 6} and second input // becomes {2, 3, 2, 3}. - test->BuildAndRun( - {{"input1", {CType(3), CType(6)}}, {"input2", {CType(2), CType(3)}}}, - "my_binary", &output_data); + test->BuildAndRun(input_data, &output_data); if (node_def.op() == "Add") { - EXPECT_THAT(output_data, + EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAre(CType(5), CType(8), CType(6), CType(9))); } else if (node_def.op() == "Sub") { - EXPECT_THAT(output_data, + EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAre(CType(1), CType(4), CType(0), CType(3))); } else if (node_def.op() == "Mul") { - EXPECT_THAT(output_data, + EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAre(CType(6), CType(12), CType(9), CType(18))); } else if (node_def.op() == "Div") { - EXPECT_THAT(output_data, + EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAre(CType(1.5), CType(3), CType(1), CType(2))); } else if (node_def.op() == "RealDiv") { - EXPECT_THAT(output_data, + EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAre(CType(1.5), CType(3), CType(1), CType(2))); } else if (node_def.op() == "Minimum") { - EXPECT_THAT(output_data, + EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAre(CType(2), CType(2), CType(3), CType(3))); } else if (node_def.op() == "Maximum") { - EXPECT_THAT(output_data, + EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAre(CType(3), CType(6), CType(3), CType(6))); } else { ASSERT_TRUE(false); @@ -1751,7 +1928,9 @@ TEST_F(OpConverterTest, ConvertBinary) { NodeDef node_def = MakeNodeDef("my_add", "Add", {num_inputs, "input"}); AddTestTensor("input", {1}, /*batch_size=*/1, nvinfer1::DataType::kFLOAT); RunValidationAndConversion(node_def, error::INVALID_ARGUMENT, - "Binary ops require two inputs, at my_add"); + StrCat("Add got ", std::to_string(num_inputs), + " inputs but expected 2, at my_add") + .c_str()); } { // Both 
inputs are weights. @@ -1821,14 +2000,18 @@ TEST_F(OpConverterTest, ConvertBinary) { } TEST_F(OpConverterTest, ConvertQuantize) { - for (const string& op : - {"FakeQuantWithMinMaxArgs", "FakeQuantWithMinMaxVars", - "QuantizeAndDequantizeV2", "QuantizeAndDequantizeV3"}) { + const std::pair op_with_num_inputs[4] = { + {"FakeQuantWithMinMaxArgs", 1}, + {"FakeQuantWithMinMaxVars", 3}, + {"QuantizeAndDequantizeV2", 3}, + {"QuantizeAndDequantizeV3", 4}}; + for (const auto& pair : op_with_num_inputs) { // Input list is empty, should fail. - NodeDef node_def = MakeNodeDef("my_quantize", op, {}); + NodeDef node_def = MakeNodeDef("my_quantize", pair.first, {}); RunValidationAndConversion( node_def, error::INVALID_ARGUMENT, - StrCat("Invalid number of inputs for ", op, ", at my_quantize") + StrCat(pair.first, " got 0 inputs but expected ", + std::to_string(pair.second), ", at my_quantize") .c_str()); } { @@ -1915,9 +2098,9 @@ TEST_F(OpConverterTest, ConvertQuantize) { AddTestTensor("weights_min", {1}); AddTestTensor("weights_max", {1}); RunValidationAndConversion( - node_def, error::INVALID_ARGUMENT, - "Min and max inputs for QuantizeAndDequantizeV2 must be weights not " - "tensors, at my_quantize"); + node_def, error::UNIMPLEMENTED, + "The input \"input_min\" for QuantizeAndDequantizeV2 must be a constant" + ", at my_quantize"); } { // QuantizeAndDequantizeV3 ranges set via inputs, ok. @@ -1950,7 +2133,7 @@ TEST_F(OpConverterTest, ConvertRelu6) { NodeDef node_def = MakeNodeDef("my_relu6", "Relu6", {}); RunValidationAndConversion( node_def, error::INVALID_ARGUMENT, - "Invalid number of inputs for Relu6, at my_relu6"); + "Relu6 got 0 inputs but expected 1, at my_relu6"); } // Get the NodeDef for Relu6. 
@@ -1964,7 +2147,7 @@ TEST_F(OpConverterTest, ConvertRelu6) { AddTestWeights("input", {1}, {1.0f}); RunValidationAndConversion( node_def, error::UNIMPLEMENTED, - "Relu6 is only implemented for tensors, not weights, at my_relu6"); + "The input \"input\" for Relu6 must be a tensor, at my_relu6"); } { // Clip tensor values and set quantization ranges, ok. @@ -1977,10 +2160,12 @@ TEST_F(OpConverterTest, ConvertRelu6) { auto ranges = quantization_ranges(); EXPECT_EQ(ranges[output.tensor()], 6.0f); - std::vector output_data(6); - BuildAndRun({{"input", {-100, -1, 0, 3, 5, 9}}}, "my_relu6", - &output_data); - EXPECT_THAT(output_data, ElementsAre(0, 0, 0, 3, 5, 6)); + const DataVec input_data{ + {"input", test::AsTensor({-100, -1, 0, 3, 5, 9})}}; + DataVec output_data{{"my_relu6", ConstructTensor(6)}}; + BuildAndRun(input_data, &output_data); + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAre(0, 0, 0, 3, 5, 6)); } } @@ -2002,24 +2187,26 @@ void TestConvertSquare(OpConverterTest* test) { ExpectTrtDimsEqualsArray({1, 20}, output.tensor()->getDimensions()); const int num_inputs = 20; - std::vector input_data(num_inputs); - std::vector expected_output_data(num_inputs); + std::vector inputs(num_inputs); + std::vector expected_outputs(num_inputs); for (int i = 0; i < 20; i++) { const CType value = CType(i - 9); - input_data[i] = value; - expected_output_data[i] = value * value; + inputs[i] = value; + expected_outputs[i] = value * value; } - std::vector output_data(num_inputs); - test->BuildAndRun({{"input", input_data}}, "my_square", &output_data); - ExpectArrayNear(expected_output_data, output_data); + const DataVec input_data{{"input", test::AsTensor(inputs)}}; + DataVec output_data{{"my_square", ConstructTensor(num_inputs)}}; + test->BuildAndRun(input_data, &output_data); + ExpectArrayNear(expected_outputs, GetSpanForData(output_data[0])); } TEST_F(OpConverterTest, ConvertSquare) { { // Input list is empty, should fail. 
NodeDef node_def = MakeNodeDef("my_square", "Square", {}); - RunValidationAndConversion(node_def, error::INVALID_ARGUMENT, - "Square expects one input, at my_square"); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Square got 0 inputs but expected 1, at my_square"); } { // Input is weights, should fail. @@ -2031,7 +2218,7 @@ TEST_F(OpConverterTest, ConvertSquare) { AddTestWeights("input", {1, 2, 3}, {1, 2, 3, 4, -5, 6}); RunValidationAndConversion( node_def, error::UNIMPLEMENTED, - "Square is only implemented for tensors, at my_square"); + "The input \"x\" for Square must be a tensor, at my_square"); } // OK. Note that kINT32 is not supported by IElementWiseLayer, so we don't @@ -2047,7 +2234,7 @@ TEST_F(OpConverterTest, ConvertActivation) { // Input list is empty, should fail. NodeDef node_def = MakeNodeDef("my_act", "Relu", {}); RunValidationAndConversion(node_def, error::INVALID_ARGUMENT, - "Relu expects one input, at my_act"); + "Relu got 0 inputs but expected 1, at my_act"); } { // Input is weights, should fail. @@ -2059,7 +2246,7 @@ TEST_F(OpConverterTest, ConvertActivation) { AddTestWeights("input", {1, 2, 3}, {-3, -2, -1, 0, 1, 2}); RunValidationAndConversion( node_def, error::UNIMPLEMENTED, - "Relu is only implemented for tensors, at my_act"); + "The input \"input\" for Relu must be a tensor, at my_act"); } // Get nodedef for activation layer. 
@@ -2103,12 +2290,875 @@ TEST_F(OpConverterTest, ConvertActivation) { EXPECT_TRUE(output.is_tensor()); ExpectTrtDimsEqualsArray({1, 2, 3}, output.tensor()->getDimensions()); - const std::vector input_data = {-100, -2, -1, 0, 1, 100}; - std::vector output_data(6); - BuildAndRun({{"input", input_data}}, "my_act", &output_data); - for (int i = 0; i < input_data.size(); i++) { - const float expected_output = get_act_output(op_name, input_data[i]); - EXPECT_FLOAT_EQ(output_data[i], expected_output); + const std::vector input = {-100, -2, -1, 0, 1, 100}; + const DataVec input_data{{"input", test::AsTensor(input)}}; + DataVec output_data{{"my_act", ConstructTensor(6)}}; + BuildAndRun(input_data, &output_data); + for (int i = 0; i < input.size(); i++) { + const float expected_output = get_act_output(op_name, input[i]); + EXPECT_FLOAT_EQ(GetSpanForData(output_data[0])[i], + expected_output); + } + } +} + +TEST_F(OpConverterTest, ConvertExpandDims) { + { + // Input list is empty, should fail. + NodeDef node_def = MakeNodeDef("my_expanddims", "ExpandDims", {}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "ExpandDims got 0 inputs but expected 2, at my_expanddims"); + } + + // Get the NodeDef for ExpandDims. + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + auto weights = ops::Placeholder(s.WithOpName("weights"), DT_INT32); + auto expanddims = + ops::ExpandDims(s.WithOpName("my_expanddims"), input, weights); + const NodeDef& node_def = expanddims.operation.node()->def(); + { + // Input is weights, should fail. + Reset(); + AddTestWeights("input", {1, 2, 3}, {1, 2, 3, 4, 5, 6}); + AddTestWeights("weights", {1}, {1}); + RunValidationAndConversion(node_def, error::UNIMPLEMENTED, + "The input \"input\" for ExpandDims must be a " + "tensor, at my_expanddims"); + } + { + // Axis is a tensor, should fail. 
+ Reset(); + AddTestTensor("input", {1, 2, 3}); + AddTestTensor("weights", {3}); + RunValidationAndConversion(node_def, error::UNIMPLEMENTED, + "The input \"axis\" for ExpandDims must be a " + "constant, at my_expanddims"); + } + { + // Add dim at batch dimension, should fail. + Reset(); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("weights", {1}, {0}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "Modifying batch dimension is not supported for ExpandDims, at " + "my_expanddims"); + } + { + // Add dim at batch dimension via negative axis, should fail. + Reset(); + AddTestTensor("input", {1, 2, 3}); + // Input is rank 4 (batch dim included) + AddTestWeights("weights", {1}, {-5}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "Modifying batch dimension is not supported for ExpandDims, at " + "my_expanddims"); + } + { + // Axis > rank(input), should fail. + Reset(); + AddTestTensor("input", {1, 2, 3}); + // Input is rank 4 (batch dim included) + AddTestWeights("weights", {1}, {5}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Axis for ExpandDims is invalid, must be in the range " + "[-rank(input) - 1, rank(input)], at my_expanddims"); + } + { + // Axis < -rank(input)-1, should fail. + Reset(); + AddTestTensor("input", {1, 2, 3}); + // Input is rank 4 (batch dim included) + AddTestWeights("weights", {1}, {-6}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Axis for ExpandDims is invalid, must be in the range " + "[-rank(input) - 1, rank(input)], at my_expanddims"); + } + + struct TestParams { + TestParams(const std::vector& input_dims, int axis, + const std::vector& expected_output_dims) + : input_dims(input_dims), + axis(axis), + expected_output_dims(expected_output_dims) {} + std::vector input_dims; + int axis; + std::vector expected_output_dims; + }; + + // Ok. 
+ const int kExpandDimsOKCases = 8; + TestParams ok_params[kExpandDimsOKCases] = { + TestParams{{2, 3}, 1, {1, 2, 3}}, TestParams{{2, 3}, -3, {1, 2, 3}}, + TestParams{{2, 3}, 3, {2, 3, 1}}, TestParams{{2, 3}, -1, {2, 3, 1}}, + TestParams{{2, 3}, 2, {2, 1, 3}}, TestParams{{2, 3}, -2, {2, 1, 3}}, + TestParams{{6}, 1, {1, 6}}, TestParams{{6}, -1, {6, 1}}, + }; + for (int i = 0; i < kExpandDimsOKCases; ++i) { + Reset(); + AddTestTensor("input", ok_params[i].input_dims); + AddTestWeights("weights", {1}, {ok_params[i].axis}); + RunValidationAndConversion(node_def); + TRT_TensorOrWeights output; + TF_EXPECT_OK(GetTensorOrWeights("my_expanddims", &output)); + EXPECT_TRUE(output.is_tensor()); + ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims, + output.tensor()->getDimensions()); + + const DataVec input_data{ + {"input", test::AsTensor({1, 2, 3, 4, 5, 6})}}; + DataVec output_data{{"my_expanddims", ConstructTensor(6)}}; + BuildAndRun(input_data, &output_data); + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAre(1, 2, 3, 4, 5, 6)); + } +} + +TEST_F(OpConverterTest, ConvertSqueeze) { + { + // Input list is empty, should fail. + NodeDef node_def = MakeNodeDef("my_squeeze", "Squeeze", {}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Squeeze got 0 inputs but expected 1, at my_squeeze"); + } + { + // No attrs, should fail. + Reset(); + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + auto squeeze = ops::Squeeze(s.WithOpName("my_squeeze"), input); + const NodeDef& node_def = squeeze.operation.node()->def(); + AddTestTensor("input", {1, 2, 3}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "Squeeze is only implemented for explicit dims, at my_squeeze"); + } + + // Get the NodeDef for Squeeze. 
+ auto get_squeeze_nodedef = [](std::vector axis) -> NodeDef { + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + ops::Squeeze::Attrs squeeze_attrs; + squeeze_attrs.axis_ = gtl::ArraySlice(axis); // non-absl ok + auto squeeze = + ops::Squeeze(s.WithOpName("my_squeeze"), input, squeeze_attrs); + return squeeze.operation.node()->def(); + }; + + { + // Input is weights, should fail. + Reset(); + NodeDef node_def = get_squeeze_nodedef({0}); + AddTestWeights("input", {1, 2, 3}, {1, 2, 3, 4, 5, 6}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "The input \"input\" for Squeeze must be a tensor, at my_squeeze"); + } + { + // Squeeze batch dim, should fail. + Reset(); + NodeDef node_def = get_squeeze_nodedef({0}); + AddTestTensor("input", {1, 2, 3}); + RunValidationAndConversion(node_def, error::UNIMPLEMENTED, + "Cannot squeeze batch dimension, at my_squeeze"); + } + { + // Squeeze batch dim via negative axis, should fail. + Reset(); + NodeDef node_def = get_squeeze_nodedef({-4}); + AddTestTensor("input", {1, 2, 3}); + RunValidationAndConversion(node_def, error::UNIMPLEMENTED, + "Cannot squeeze batch dimension, at my_squeeze"); + } + { + // Squeeze >= rank(input), should fail. + Reset(); + NodeDef node_def = get_squeeze_nodedef({4}); + AddTestTensor("input", {1, 2, 3}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Axis for Squeeze is invalid, must be in the range " + "[-rank(input), rank(input)), at my_squeeze"); + } + { + // Squeeze < -rank(input), should fail. 
+ Reset(); + NodeDef node_def = get_squeeze_nodedef({-5}); + AddTestTensor("input", {1, 2, 3}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Axis for Squeeze is invalid, must be in the range " + "[-rank(input), rank(input)), at my_squeeze"); + } + + struct TestParams { + TestParams(const std::vector& input_dims, const std::vector& axis, + const std::vector& expected_output_dims) + : input_dims(input_dims), + axis(axis), + expected_output_dims(expected_output_dims) {} + std::vector input_dims; + std::vector axis; + std::vector expected_output_dims; + }; + + // Ok. + const int kSqueezeOKCases = 10; + TestParams ok_params[kSqueezeOKCases] = { + TestParams{{1, 2, 3}, {1}, {2, 3}}, + TestParams{{1, 2, 3}, {-3}, {2, 3}}, + TestParams{{2, 3, 1}, {3}, {2, 3}}, + TestParams{{2, 3, 1}, {-1}, {2, 3}}, + TestParams{{1, 2, 1, 3, 1}, {1, 3, 5}, {2, 3}}, + TestParams{{1, 2, 1, 3, 1}, {3, 1, 5}, {2, 3}}, + TestParams{{1, 2, 1, 3, 1}, {-1, -3, -5}, {2, 3}}, + TestParams{{1, 2, 1, 3, 1}, {1, -3, 5}, {2, 3}}, + TestParams{{1, 6}, {1}, {6}}, + TestParams{{6, 1}, {2}, {6}}, + }; + for (int i = 0; i < kSqueezeOKCases; ++i) { + Reset(); + NodeDef node_def = get_squeeze_nodedef(ok_params[i].axis); + AddTestTensor("input", ok_params[i].input_dims); + RunValidationAndConversion(node_def); + TRT_TensorOrWeights output; + TF_EXPECT_OK(GetTensorOrWeights("my_squeeze", &output)); + EXPECT_TRUE(output.is_tensor()); + ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims, + output.tensor()->getDimensions()); + + const DataVec input_data{ + {"input", test::AsTensor({1, 2, 3, 4, 5, 6})}}; + DataVec output_data{{"my_squeeze", ConstructTensor(6)}}; + BuildAndRun(input_data, &output_data); + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAre(1, 2, 3, 4, 5, 6)); + } +} + +TEST_F(OpConverterTest, ConvertStridedSlice) { + { + // Input list is empty, should fail. 
+ NodeDef node_def = MakeNodeDef("my_strided_slice", "StridedSlice", {}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "StridedSlice got 0 inputs but expected 4, at my_strided_slice"); + } + + // Get nodedef for StridedSlice layer. + auto get_strided_slice_nodedef = + [](int begin_mask = 0, int end_mask = 0, int ellipsis_mask = 0, + int new_axis_mask = 0, int shrink_axis_mask = 0) -> NodeDef { + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + auto begin = ops::Placeholder(s.WithOpName("begin"), DT_INT32); + auto end = ops::Placeholder(s.WithOpName("end"), DT_INT32); + auto strides = ops::Placeholder(s.WithOpName("strides"), DT_INT32); + ops::StridedSlice::Attrs attrs = ops::StridedSlice::Attrs() + .BeginMask(begin_mask) + .EndMask(end_mask) + .EllipsisMask(ellipsis_mask) + .NewAxisMask(new_axis_mask) + .ShrinkAxisMask(shrink_axis_mask); + auto strided_slice = ops::StridedSlice(s.WithOpName("my_strided_slice"), + input, begin, end, strides, attrs); + return strided_slice.operation.node()->def(); + }; + + { + // Input is weights, should fail. + Reset(); + NodeDef node_def = get_strided_slice_nodedef(); + AddTestWeights("input", {1, 2, 3}, {1, 2, 3, 4, 5, 6}); + AddTestWeights("begin", {4}, {0, 0, 0, 0}); + AddTestWeights("end", {4}, {1, 1, 2, 3}); + AddTestWeights("strides", {4}, {1, 1, 1, 1}); + RunValidationAndConversion(node_def, error::UNIMPLEMENTED, + "The input \"input\" for StridedSlice must be a " + "tensor, at my_strided_slice"); + } + { + // Begin, end, strides are tensors, should fail. + Reset(); + NodeDef node_def = get_strided_slice_nodedef(); + AddTestTensor("input", {1, 2, 3}); + AddTestTensor("begin", {4}); + AddTestTensor("end", {4}); + AddTestTensor("strides", {4}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "The input \"begin\" for StridedSlice must be a constant, at " + "my_strided_slice"); + } + { + // Non-zero ellipsis_mask, should fail. 
+ Reset(); + NodeDef node_def = get_strided_slice_nodedef( + /*begin_mask=*/0, /*end_mask=*/0, /*ellipsis_mask=*/2, + /*new_axis_mask=*/0, /*shrink_axis_mask=*/0); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("begin", {4}, {0, 0, 0, 0}); + AddTestWeights("end", {4}, {1, 1, 2, 3}); + AddTestWeights("strides", {4}, {1, 1, 1, 1}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "ellipsis_mask is not supported for StridedSlice, at " + "my_strided_slice"); + } + { + // Modify batch dim, should fail. + Reset(); + NodeDef node_def = get_strided_slice_nodedef(); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("begin", {4}, {0, 0, 0, 0}); + AddTestWeights("end", {4}, {0, 1, 2, 3}); + AddTestWeights("strides", {4}, {1, 1, 1, 1}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "StridedSlice can't modify batch dim, at my_strided_slice"); + } + { + // Stride is not 1, should fail. + Reset(); + NodeDef node_def = get_strided_slice_nodedef(); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("begin", {4}, {0, 0, 0, 0}); + AddTestWeights("end", {4}, {1, 1, 2, 3}); + AddTestWeights("strides", {4}, {1, 2, -1, 3}); + RunValidationAndConversion(node_def, error::UNIMPLEMENTED, + "StridedSlice is only implemented for stride of " + "1, at my_strided_slice"); + } + { + // Begin out of bounds, should fail. + Reset(); + NodeDef node_def = get_strided_slice_nodedef(); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("begin", {4}, {1, 2, 3, 4}); + AddTestWeights("end", {4}, {0, 1, 2, 3}); + AddTestWeights("strides", {4}, {1, 1, 1, 1}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "begin value of 2 for StridedSlice is invalid, must be in the range " + "[-dim_size(i), dim_size(i)], at my_strided_slice"); + } + { + // End out of bounds, should fail. 
+ Reset(); + NodeDef node_def = get_strided_slice_nodedef(); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("begin", {4}, {0, 0, 0, 0}); + AddTestWeights("end", {4}, {1, 2, 3, 4}); + AddTestWeights("strides", {4}, {1, 1, 1, 1}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "end value of 2 for StridedSlice is invalid, must be in the range " + "[-dim_size(i), dim_size(i)], at my_strided_slice"); + } + { + // Size of sliced dim is negative, should fail. + Reset(); + NodeDef node_def = get_strided_slice_nodedef(); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("begin", {4}, {0, 0, 2, 0}); + AddTestWeights("end", {4}, {1, 1, 0, 3}); + AddTestWeights("strides", {4}, {1, 1, 1, 1}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "New size of sliced dimension is negative, at my_strided_slice"); + } + + struct TestParams { + TestParams(const std::vector& input_dims, + const std::vector& expected_output_dims, + const std::vector& begin, const std::vector& end, + const std::vector& begin_mask, + const std::vector& end_mask, + const std::vector& expected_output) + : input_dims(input_dims), + expected_output_dims(expected_output_dims), + begin(begin), + end(end), + expected_output(expected_output) { + // Masks are provided in terms of vectors for readability. Convert them to + // binary here. + this->begin_mask = 0; + for (int i = 0; i < begin_mask.size(); i++) { + if (begin_mask[i]) this->begin_mask |= (1 << i); + } + this->end_mask = 0; + for (int i = 0; i < end_mask.size(); i++) { + if (end_mask[i]) this->end_mask |= (1 << i); + } + } + + std::vector input_dims; + std::vector expected_output_dims; + std::vector begin; + std::vector end; + int begin_mask; + int end_mask; + std::vector expected_output; + }; + + // Ok. + const int kStridedSliceOKCases = 18; + TestParams ok_params[kStridedSliceOKCases] = { + // 2D Crop. 
+ TestParams{/*input_dims=*/{1, 2, 3}, /*expected_output_dims=*/{1, 1, 2}, + /*begin=*/{0, 0, 0, 0}, /*end=*/{0, 0, 1, 2}, + /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 1, 0, 0}, + /*expected_output=*/{1, 2}}, + TestParams{/*input_dims=*/{1, 2, 3}, /*expected_output_dims=*/{1, 1, 2}, + /*begin=*/{0, 0, 1, 1}, /*end=*/{0, 0, 0, 0}, + /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 1, 1, 1}, + /*expected_output=*/{5, 6}}, + TestParams{/*input_dims=*/{1, 2, 3}, /*expected_output_dims=*/{1, 1, 2}, + /*begin=*/{0, 0, 1, 1}, /*end=*/{0, 1, 2, 3}, + /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 1, 0, 0}, + /*expected_output=*/{5, 6}}, + // 2D Crop, with transpose. + TestParams{/*input_dims=*/{2, 3, 1}, /*expected_output_dims=*/{1, 2, 1}, + /*begin=*/{0, 0, 0, 0}, /*end=*/{0, 1, 2, 1}, + /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 0, 0, 0}, + /*expected_output=*/{1, 2}}, + TestParams{/*input_dims=*/{2, 3, 1}, /*expected_output_dims=*/{1, 2, 1}, + /*begin=*/{0, 1, 1, 0}, /*end=*/{0, 2, 3, 1}, + /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 0, 0, 0}, + /*expected_output=*/{5, 6}}, + TestParams{/*input_dims=*/{2, 1, 3}, /*expected_output_dims=*/{1, 1, 2}, + /*begin=*/{0, 0, 0, 0}, /*end=*/{0, 1, 1, 2}, + /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 0, 0, 0}, + /*expected_output=*/{1, 2}}, + TestParams{/*input_dims=*/{2, 1, 3}, /*expected_output_dims=*/{1, 1, 2}, + /*begin=*/{0, 1, 0, 1}, /*end=*/{0, 2, 1, 3}, + /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 0, 0, 0}, + /*expected_output=*/{5, 6}}, + // 2D Crop, with reshape. + TestParams{/*input_dims=*/{2, 3}, /*expected_output_dims=*/{1, 2}, + /*begin=*/{0, 0, 0}, /*end=*/{0, 1, 2}, + /*begin_mask=*/{0, 0, 0}, /*end_mask=*/{1, 0, 0}, + /*expected_output=*/{1, 2}}, + TestParams{/*input_dims=*/{2, 3}, /*expected_output_dims=*/{1, 2}, + /*begin=*/{0, 1, 1}, /*end=*/{0, 0, 0}, + /*begin_mask=*/{0, 0, 0}, /*end_mask=*/{1, 1, 1}, + /*expected_output=*/{5, 6}}, + // 1D Crop. 
+ TestParams{/*input_dims=*/{1, 2, 3}, /*expected_output_dims=*/{1, 2, 2}, + /*begin=*/{0, 0, 0, 0}, /*end=*/{0, 0, 0, 2}, + /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 1, 1, 0}, + /*expected_output=*/{1, 2, 4, 5}}, + TestParams{/*input_dims=*/{1, 2, 3}, /*expected_output_dims=*/{1, 1, 3}, + /*begin=*/{0, 0, 1, 0}, /*end=*/{0, 0, 0, 0}, + /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 1, 1, 1}, + /*expected_output=*/{4, 5, 6}}, + // 1D Crop, with transpose. + TestParams{/*input_dims=*/{2, 3, 1}, /*expected_output_dims=*/{1, 3, 1}, + /*begin=*/{0, 0, 0, 0}, /*end=*/{0, 1, 0, 0}, + /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 0, 1, 1}, + /*expected_output=*/{1, 2, 3}}, + TestParams{/*input_dims=*/{2, 3, 1}, /*expected_output_dims=*/{1, 3, 1}, + /*begin=*/{0, 1, 0, 0}, /*end=*/{0, 0, 0, 0}, + /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 1, 1, 1}, + /*expected_output=*/{4, 5, 6}}, + // 1D Crop, with reshape. + TestParams{/*input_dims=*/{6}, /*expected_output_dims=*/{3}, + /*begin=*/{0, 0}, /*end=*/{0, 3}, + /*begin_mask=*/{0, 0}, /*end_mask=*/{1, 0}, + /*expected_output=*/{1, 2, 3}}, + TestParams{/*input_dims=*/{1, 6}, /*expected_output_dims=*/{1, 3}, + /*begin=*/{0, 0, 2}, /*end=*/{0, 0, 5}, + /*begin_mask=*/{0, 0, 0}, /*end_mask=*/{1, 1, 0}, + /*expected_output=*/{3, 4, 5}}, + TestParams{/*input_dims=*/{6, 1}, /*expected_output_dims=*/{3, 1}, + /*begin=*/{0, 2, 0}, /*end=*/{0, 5, 0}, + /*begin_mask=*/{0, 0, 0}, /*end_mask=*/{1, 0, 1}, + /*expected_output=*/{3, 4, 5}}, + // Negative axis. 
+ TestParams{/*input_dims=*/{6, 1}, /*expected_output_dims=*/{3, 1}, + /*begin=*/{0, -6, 0}, /*end=*/{0, -3, 0}, + /*begin_mask=*/{0, 0, 0}, /*end_mask=*/{1, 0, 1}, + /*expected_output=*/{1, 2, 3}}, + TestParams{/*input_dims=*/{6, 1}, /*expected_output_dims=*/{5, 1}, + /*begin=*/{0, 0, 0}, /*end=*/{0, -1, 0}, + /*begin_mask=*/{0, 0, 0}, /*end_mask=*/{1, 0, 1}, + /*expected_output=*/{1, 2, 3, 4, 5}}, + }; + + for (int i = 0; i < kStridedSliceOKCases; i++) { + Reset(); + NodeDef node_def = get_strided_slice_nodedef(ok_params[i].begin_mask, + ok_params[i].end_mask); + AddTestTensor("input", ok_params[i].input_dims); + AddTestWeights("begin", + {static_cast(ok_params[i].begin.size())}, + ok_params[i].begin); + AddTestWeights("end", {static_cast(ok_params[i].end.size())}, + ok_params[i].end); + std::vector strides(ok_params[i].input_dims.size(), 1); + AddTestWeights("strides", {static_cast(strides.size())}, + strides); + RunValidationAndConversion(node_def); + + TRT_TensorOrWeights output; + TF_EXPECT_OK(GetTensorOrWeights("my_strided_slice", &output)); + + const DataVec input_data{ + {"input", test::AsTensor({1, 2, 3, 4, 5, 6})}}; + DataVec output_data{ + {"my_strided_slice", + ConstructTensor(ok_params[i].expected_output.size())}}; + BuildAndRun(input_data, &output_data); + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAreArray(ok_params[i].expected_output)); + } +} + +TEST_F(OpConverterTest, ConvertConv2D) { + { + // Input list is empty, should fail. + NodeDef node_def = MakeNodeDef("my_conv2d", "Conv2D", {}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Conv2D got 0 inputs but expected 2, at my_conv2d"); + } + + // Get nodedef for Conv2D layer. 
+ auto get_conv2d_nodedef = + [](std::vector strides = {1, 1, 1, 1}, string padding = "SAME", + string data_format = "NCHW", std::vector dilations = {1, 1, 1, 1}, + bool is_conv2d_backprop_input = false) -> NodeDef { + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + auto filter = ops::Placeholder(s.WithOpName("weights"), DT_FLOAT); + if (is_conv2d_backprop_input) { + auto input_sizes = + ops::Placeholder(s.WithOpName("input_sizes"), DT_INT32); + ops::Conv2DBackpropInput::Attrs attrs = ops::Conv2DBackpropInput::Attrs() + .DataFormat(data_format) + .Dilations(dilations); + auto conv2d = + ops::Conv2DBackpropInput(s.WithOpName("my_conv2d"), input_sizes, + filter, input, strides, padding, attrs); + return conv2d.operation.node()->def(); + } else { + ops::Conv2D::Attrs attrs = + ops::Conv2D::Attrs().DataFormat(data_format).Dilations(dilations); + auto conv2d = ops::Conv2D(s.WithOpName("my_conv2d"), input, filter, + strides, padding, attrs); + return conv2d.operation.node()->def(); + } + }; + + { + // Input is weights, should fail. + Reset(); + NodeDef node_def = get_conv2d_nodedef(); + AddTestWeights("input", {1, 2, 3}, {1, 2, 3, 4, 5, 6}); + AddTestWeights("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "The input \"input\" for Conv2D must be a tensor, at my_conv2d"); + } + { + // Filter is tensor, should fail. + Reset(); + NodeDef node_def = get_conv2d_nodedef(); + AddTestTensor("input", {1, 2, 3}); + AddTestTensor("weights", {3, 3, 1, 1}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "The input \"filter\" for Conv2D must be a constant, at my_conv2d"); + } + { + // Filter is not 4D, should fail. 
+ Reset(); + NodeDef node_def = get_conv2d_nodedef(); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("weights", {3, 3, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Conv2D expects kernel of dimension 4, at my_conv2d"); + } + { + // Dilations is not 4D, should fail. + Reset(); + NodeDef node_def = + get_conv2d_nodedef({1, 1, 1, 1}, "SAME", "NCHW", {1, 1, 1}); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Convolution dilations field must specify 4 dimensions, at my_conv2d"); + } + { + // Dilation value is not 1 for channel, should fail. + Reset(); + NodeDef node_def = + get_conv2d_nodedef({1, 1, 1, 1}, "SAME", "NCHW", {1, 2, 1, 1}); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + RunValidationAndConversion(node_def, error::UNIMPLEMENTED, + "Dilation rate must be 1 for batch and channel " + "dimensions, at my_conv2d"); + } + { + // Dilation value is not 1 for channel (NHWC), should fail. + Reset(); + NodeDef node_def = + get_conv2d_nodedef({1, 1, 1, 1}, "SAME", "NHWC", {1, 1, 1, 2}); + AddTestTensor("input", {2, 3, 1}); + AddTestWeights("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + RunValidationAndConversion(node_def, error::UNIMPLEMENTED, + "Dilation rate must be 1 for batch and channel " + "dimensions, at my_conv2d"); + } + { + // Dilation + Conv2DBackpropInput, should fail. 
+ Reset(); + NodeDef node_def = + get_conv2d_nodedef({1, 1, 1, 1}, "SAME", "NHWC", {1, 1, 2, 1}, true); + AddTestTensor("input", {2, 3, 1}); + AddTestWeights("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + AddTestWeights("input_sizes", {4}, {1, 2, 3, 1}); + RunValidationAndConversion(node_def, error::UNIMPLEMENTED, + "Dilation with Conv2DBackpropInput " + "(conv2d_transpose) is not supported, " + "at my_conv2d"); + } + { + // Strides is not 4D, should fail. + Reset(); + NodeDef node_def = + get_conv2d_nodedef({1, 1, 1}, "SAME", "NCHW", {1, 1, 1, 1}); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Convolution strides field must specify 4 dimensions, at my_conv2d"); + } + { + // Stride value is not 1 for channel, should fail. + Reset(); + NodeDef node_def = + get_conv2d_nodedef({1, 2, 1, 1}, "SAME", "NCHW", {1, 1, 1, 1}); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "Stride must be 1 for batch and channel dimensions, at my_conv2d"); + } + + struct TestParams { + TestParams(const std::vector& input_dims, + const std::vector& input, + const std::vector& filter_dims, + const std::vector& filter, + const std::vector& strides, const string& padding, + const string& data_format, const std::vector& dilations, + bool is_conv2d_backprop_input, + const std::vector& expected_output_dims, + const std::vector& expected_output) + : input_dims(input_dims), + input(input), + filter_dims(filter_dims), + filter(filter), + strides(strides), + padding(padding), + data_format(data_format), + dilations(dilations), + is_conv2d_backprop_input(is_conv2d_backprop_input), + expected_output_dims(expected_output_dims), + expected_output(expected_output) {} + + std::vector input_dims; + std::vector input; + std::vector 
filter_dims; + std::vector filter; + std::vector strides; + string padding; + string data_format; + std::vector dilations; + bool is_conv2d_backprop_input; + std::vector expected_output_dims; + std::vector expected_output; + }; + + // Ok. + const int kConv2DOKCases = 7; + TestParams ok_params[kConv2DOKCases] = { + // Basic + TestParams{/*input_dims=*/{1, 2, 3}, + /*input=*/{0, 1, 2, 3, 3, 4}, + /*filter_dims=*/{1, 2, 1, 1}, + /*filter=*/{-1, 1}, + /*strides=*/{1, 1, 1, 1}, + /*padding=*/"VALID", + /*data_format=*/"NCHW", + /*dilations=*/{1, 1, 1, 1}, + /*is_conv2d_backprop_input=*/false, + /*expected_output_dims=*/{1, 2, 2}, + /*expected_output=*/{1, 1, 0, 1}}, + // SAME padding (Asymmetric) + TestParams{/*input_dims=*/{1, 2, 3}, + /*input=*/{0, 1, 2, 3, 3, 4}, + /*filter_dims=*/{1, 2, 1, 1}, + /*filter=*/{-1, 1}, + /*strides=*/{1, 1, 1, 1}, + /*padding=*/"SAME", + /*data_format=*/"NCHW", + /*dilations=*/{1, 1, 1, 1}, + /*is_conv2d_backprop_input=*/false, + /*expected_output_dims=*/{1, 2, 3}, + /*expected_output=*/{1, 1, -2, 0, 1, -4}}, + // SAME padding (Symmetric) + TestParams{/*input_dims=*/{1, 2, 3}, + /*input=*/{0, 1, 2, 3, 3, 4}, + /*filter_dims=*/{1, 3, 1, 1}, + /*filter=*/{-1, 0, 1}, + /*strides=*/{1, 1, 1, 1}, + /*padding=*/"SAME", + /*data_format=*/"NCHW", + /*dilations=*/{1, 1, 1, 1}, + /*is_conv2d_backprop_input=*/false, + /*expected_output_dims=*/{1, 2, 3}, + /*expected_output=*/{1, 2, -1, 3, 1, -3}}, + // NHWC + TestParams{/*input_dims=*/{2, 3, 1}, + /*input=*/{0, 1, 2, 3, 3, 4}, + /*filter_dims=*/{1, 2, 1, 1}, + /*filter=*/{-1, 1}, + /*strides=*/{1, 1, 1, 1}, + /*padding=*/"VALID", + /*data_format=*/"NHWC", + /*dilations=*/{1, 1, 1, 1}, + /*is_conv2d_backprop_input=*/false, + /*expected_output_dims=*/{2, 2, 1}, + /*expected_output=*/{1, 1, 0, 1}}, + // Dilated + TestParams{/*input_dims=*/{1, 2, 3}, + /*input=*/{0, 1, 2, 3, 3, 4}, + /*filter_dims=*/{1, 2, 1, 1}, + /*filter=*/{-1, 1}, + /*strides=*/{1, 1, 1, 1}, + /*padding=*/"VALID", + 
/*data_format=*/"NCHW", + /*dilations=*/{1, 1, 1, 2}, + /*is_conv2d_backprop_input=*/false, + /*expected_output_dims=*/{1, 2, 1}, + /*expected_output=*/{2, 1}}, + // Strided + TestParams{/*input_dims=*/{1, 2, 4}, + /*input=*/{0, 1, 2, 2, 3, 4, 4, 7}, + /*filter_dims=*/{1, 2, 1, 1}, + /*filter=*/{-1, 1}, + /*strides=*/{1, 1, 1, 2}, + /*padding=*/"VALID", + /*data_format=*/"NCHW", + /*dilations=*/{1, 1, 1, 1}, + /*is_conv2d_backprop_input=*/false, + /*expected_output_dims=*/{1, 2, 2}, + /*expected_output=*/{1, 0, 1, 3}}, + // Transpose Strided + TestParams{/*input_dims=*/{1, 2, 2}, + /*input=*/{0, 1, 2, 3}, + /*filter_dims=*/{1, 2, 1, 1}, + /*filter=*/{-1, 1}, + /*strides=*/{1, 1, 1, 2}, + /*padding=*/"SAME", + /*data_format=*/"NCHW", + /*dilations=*/{1, 1, 1, 1}, + /*is_conv2d_backprop_input=*/true, + /*expected_output_dims=*/{1, 2, 4}, + /*expected_output=*/{0, 0, -1, 1, -2, 2, -3, 3}}, + }; + + for (int i = 0; i < kConv2DOKCases; i++) { + Reset(); + NodeDef node_def = get_conv2d_nodedef( + ok_params[i].strides, ok_params[i].padding, ok_params[i].data_format, + ok_params[i].dilations, ok_params[i].is_conv2d_backprop_input); + AddTestTensor("input", ok_params[i].input_dims); + AddTestWeights("weights", ok_params[i].filter_dims, + ok_params[i].filter); + if (ok_params[i].is_conv2d_backprop_input) { + AddTestWeights( + "input_sizes", + {static_cast(ok_params[i].expected_output.size())}, + ok_params[i].expected_output); + } + RunValidationAndConversion(node_def); + TRT_TensorOrWeights output; + TF_EXPECT_OK(GetTensorOrWeights("my_conv2d", &output)); + EXPECT_TRUE(output.is_tensor()); + ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims, + output.tensor()->getDimensions()); + + const DataVec input_data{ + {"input", test::AsTensor(ok_params[i].input)}}; + DataVec output_data{ + {"my_conv2d", + ConstructTensor(ok_params[i].expected_output.size())}}; + BuildAndRun(input_data, &output_data); + EXPECT_THAT(GetSpanForData(output_data[0]), + 
ElementsAreArray(ok_params[i].expected_output)); + } +} + +TEST_F(OpConverterTest, ConvertTopK) { + { + // Input list is empty, should fail. + NodeDef node_def = MakeNodeDef("my_topk", "TopKV2", {}); + RunValidationAndConversion(node_def, error::INVALID_ARGUMENT, + "Input expects tensor and weights, at my_topk"); + } + + for (const auto dtype : {DT_FLOAT, DT_INT32}) { + // Get the NodeDef for TopKV2. + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), dtype); + auto weights = ops::Placeholder(s.WithOpName("weights"), DT_INT32); + auto topk = ops::TopK(s.WithOpName("my_topk"), input, weights); + const NodeDef& node_def = topk.operation.node()->def(); + { + // K is a tensor, should fail. + Reset(); + AddTestTensor("input", {1, 2, 3}, /*batch_size=*/1, + /*trt_dtype=*/TfDataTypeToTrt(dtype)); + AddTestTensor("weights", {2}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Input expects tensor and weights, at my_topk"); + } + { + // Ok. 
+ Reset(); + AddTestTensor("input", {1, 2, 5}); + AddTestWeights("weights", {1}, {2}); + RunValidationAndConversion(node_def); + TRT_TensorOrWeights outputs[2]; + TF_EXPECT_OK(GetTensorOrWeights("my_topk", &outputs[0])); + TF_EXPECT_OK(GetTensorOrWeights("my_topk:1", &outputs[1])); + for (auto& output : outputs) { + EXPECT_TRUE(output.is_tensor()); + ExpectTrtDimsEqualsArray({1, 2, 2}, output.tensor()->getDimensions()); + } + + const DataVec input_data{ + {"input", test::AsTensor({-9, 3, 5, 1, 6, -5, 7, 1, 0, -1})}}; + DataVec output_data{{"my_topk", ConstructTensor(4)}, + {"my_topk:1", ConstructTensor(4)}}; + BuildAndRun(input_data, &output_data); + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAre(6, 5, 7, 1)); + EXPECT_THAT(GetSpanForData(output_data[1]), + ElementsAre(4, 2, 1, 2)); } } } diff --git a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc b/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc similarity index 94% rename from tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc rename to tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc index c1688d4db88a270dcd202989f89a677ed10576d9..f36aa558ea2ea463983caf163e17f83ae1c38f40 100644 --- a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc @@ -12,9 +12,11 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h" -#include "tensorflow/contrib/tensorrt/convert/convert_graph.h" -#include "tensorflow/contrib/tensorrt/convert/utils.h" +#include "tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.h" + +#include "absl/strings/str_cat.h" +#include "tensorflow/compiler/tf2tensorrt/convert/convert_graph.h" +#include "tensorflow/compiler/tf2tensorrt/convert/utils.h" #include "tensorflow/core/grappler/clusters/cluster.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" @@ -30,9 +32,9 @@ namespace tensorflow { namespace tensorrt { namespace convert { // TODO(sami): Remove VLOG messages once the code matures +using absl::StrAppend; +using absl::StrCat; using tensorflow::str_util::Uppercase; -using tensorflow::strings::StrAppend; -using tensorflow::strings::StrCat; tensorflow::Status TRTOptimizationPass::Init( const tensorflow::RewriterConfig_CustomGraphOptimizer* config) { @@ -64,7 +66,7 @@ tensorflow::Status TRTOptimizationPass::Init( max_workspace_size_bytes_ = params.at("max_workspace_size_bytes").i(); } if (params.count("precision_mode")) { - TF_RETURN_IF_ERROR(GetPrecisionMode( + TF_RETURN_IF_ERROR(TrtPrecisionModeFromName( Uppercase(params.at("precision_mode").s()), &precision_mode_)); } if (params.count("use_calibration")) { @@ -225,9 +227,10 @@ tensorflow::Status TRTOptimizationPass::Optimize( TF_RETURN_IF_ERROR(static_graph_properties.InferStatically(true)); tensorflow::tensorrt::convert::ConversionParams cp; - if (use_calibration_ && precision_mode_ != INT8MODE) { - LOG(ERROR) << "Calibration with FP32 or FP16 is not implemented. " - << "Falling back to use_calibration = False."; + if (use_calibration_ && precision_mode_ != TrtPrecisionMode::INT8) { + VLOG(1) << "Calibration with FP32 or FP16 is not implemented. 
" + << "Falling back to use_calibration = False. " + << "Note that the default value of use_calibration is True."; use_calibration_ = false; } @@ -242,7 +245,7 @@ tensorflow::Status TRTOptimizationPass::Optimize( // If the last token is not an integer, it must be part of the name. // Otherwise it is port number. if (tokens.size() > 1 && - !strings::safe_strto32(tokens.back(), &dumm_port)) { + !strings::safe_strto32(tokens.back(), &dumm_port)) { // non-absl ok StrAppend(&s, ":", tokens.back()); } nodes_to_preserve.push_back(s); diff --git a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h b/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.h similarity index 87% rename from tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h rename to tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.h index 3e8dc0978e43e2e9ba07aaa09f74acfe8e59b9a7..b2aed2a37afb6c01863f5617bad0bafe004eec24 100644 --- a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h +++ b/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.h @@ -13,11 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_TENSORRT_CONVERT_TRT_OPTIMIZATION_PASS_H_ -#define TENSORFLOW_CONTRIB_TENSORRT_CONVERT_TRT_OPTIMIZATION_PASS_H_ +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_TRT_OPTIMIZATION_PASS_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_TRT_OPTIMIZATION_PASS_H_ #include +#include "tensorflow/compiler/tf2tensorrt/convert/utils.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h" #include "tensorflow/core/platform/logging.h" @@ -34,7 +35,7 @@ class TRTOptimizationPass : public tensorflow::grappler::CustomGraphOptimizer { TRTOptimizationPass(const string& name = "TRTOptimizationPass") : name_(name), minimum_segment_size_(3), - precision_mode_(0), + precision_mode_(TrtPrecisionMode::FP32), maximum_batch_size_(-1), is_dynamic_op_(false), max_cached_batches_(1), @@ -62,7 +63,7 @@ class TRTOptimizationPass : public tensorflow::grappler::CustomGraphOptimizer { private: const string name_; int minimum_segment_size_; - int precision_mode_; + TrtPrecisionMode precision_mode_; int maximum_batch_size_; bool is_dynamic_op_; std::vector batches_; @@ -77,4 +78,4 @@ class TRTOptimizationPass : public tensorflow::grappler::CustomGraphOptimizer { #endif // GOOGLE_CUDA #endif // GOOGLE_TENSORRT -#endif // TENSORFLOW_CONTRIB_TENSORRT_CONVERT_TRT_OPTIMIZATION_PASS_H_ +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_TRT_OPTIMIZATION_PASS_H_ diff --git a/tensorflow/contrib/tensorrt/convert/utils.cc b/tensorflow/compiler/tf2tensorrt/convert/utils.cc similarity index 73% rename from tensorflow/contrib/tensorrt/convert/utils.cc rename to tensorflow/compiler/tf2tensorrt/convert/utils.cc index e7a1febb8c076891596741fe30721e7acca15a73..0ca3a5a4a58e6a3e29d3d515f496b8cb5e9f7eb0 100644 --- a/tensorflow/contrib/tensorrt/convert/utils.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/utils.cc @@ -13,7 +13,7 @@ See the 
License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/tensorrt/convert/utils.h" +#include "tensorflow/compiler/tf2tensorrt/convert/utils.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" @@ -34,33 +34,32 @@ bool IsGoogleTensorRTEnabled() { #endif } -Status GetPrecisionModeName(const int precision_mode, string* name) { - switch (precision_mode) { - case FP32MODE: +Status TrtPrecisionModeToName(TrtPrecisionMode mode, string* name) { + switch (mode) { + case TrtPrecisionMode::FP32: *name = "FP32"; break; - case FP16MODE: + case TrtPrecisionMode::FP16: *name = "FP16"; break; - case INT8MODE: + case TrtPrecisionMode::INT8: *name = "INT8"; break; default: - return tensorflow::errors::OutOfRange("Unknown precision mode"); + return errors::OutOfRange("Unknown precision mode"); } return Status::OK(); } -Status GetPrecisionMode(const string& name, int* precision_mode) { +Status TrtPrecisionModeFromName(const string& name, TrtPrecisionMode* mode) { if (name == "FP32") { - *precision_mode = FP32MODE; + *mode = TrtPrecisionMode::FP32; } else if (name == "FP16") { - *precision_mode = FP16MODE; + *mode = TrtPrecisionMode::FP16; } else if (name == "INT8") { - *precision_mode = INT8MODE; + *mode = TrtPrecisionMode::INT8; } else { - return tensorflow::errors::InvalidArgument("Invalid precision mode name: ", - name); + return errors::InvalidArgument("Invalid precision mode name: ", name); } return Status::OK(); } diff --git a/tensorflow/contrib/tensorrt/convert/utils.h b/tensorflow/compiler/tf2tensorrt/convert/utils.h similarity index 72% rename from tensorflow/contrib/tensorrt/convert/utils.h rename to tensorflow/compiler/tf2tensorrt/convert/utils.h index 0592f31462af2b20f3a13fe5119e89c2ba42dd8a..0aa602dda2f3e98095bf72b5810a246c690d6741 100644 --- a/tensorflow/contrib/tensorrt/convert/utils.h 
+++ b/tensorflow/compiler/tf2tensorrt/convert/utils.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_TENSORRT_CONVERT_UTILS_H_ -#define TENSORFLOW_CONTRIB_TENSORRT_CONVERT_UTILS_H_ +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_UTILS_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_UTILS_H_ #include @@ -35,16 +35,13 @@ using TrtUniquePtrType = std::unique_ptr>; bool IsGoogleTensorRTEnabled(); -// TODO(aaroey): use an enum instead. -const int FP32MODE = 0; -const int FP16MODE = 1; -const int INT8MODE = 2; +enum class TrtPrecisionMode { FP32, FP16, INT8 }; -Status GetPrecisionModeName(const int precision_mode, string* name); +Status TrtPrecisionModeToName(TrtPrecisionMode mode, string* name); -Status GetPrecisionMode(const string& name, int* precision_mode); +Status TrtPrecisionModeFromName(const string& name, TrtPrecisionMode* mode); } // namespace tensorrt } // namespace tensorflow -#endif // TENSORFLOW_CONTRIB_TENSORRT_CONVERT_UTILS_H_ +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_UTILS_H_ diff --git a/tensorflow/compiler/tf2tensorrt/kernels/get_serialized_resource_op.cc b/tensorflow/compiler/tf2tensorrt/kernels/get_serialized_resource_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..eae1f8e7525f1816d1c50072ebe4ba6713c96e47 --- /dev/null +++ b/tensorflow/compiler/tf2tensorrt/kernels/get_serialized_resource_op.cc @@ -0,0 +1,73 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_KERNELS_GET_SERIALIZED_RESOURCE_OP_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_KERNELS_GET_SERIALIZED_RESOURCE_OP_H_ + +#include +#include + +#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/lib/core/refcount.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT + +namespace tensorflow { +namespace tensorrt { + +class GetSerializedResourceOp : public OpKernel { + public: + explicit GetSerializedResourceOp(OpKernelConstruction* context) + : OpKernel(context) {} + + ~GetSerializedResourceOp() override {} + + void Compute(OpKernelContext* context) override { + // TODO(laigd): it will allocate the tensor on the device and copy the + // serialized string to that tensor, and later sess.run() will copy it back + // to host. We need to optimize this. + const string& container = context->input(0).scalar()(); + const string& resource_name = context->input(1).scalar()(); + + // Get the resource. + SerializableResourceBase* resource = nullptr; + OP_REQUIRES_OK(context, context->resource_manager()->Lookup( + container, resource_name, &resource)); + ::tensorflow::core::ScopedUnref sc(resource); + + // Serialize the resource as output. 
+ string serialized_resource; + OP_REQUIRES_OK(context, resource->SerializeToString(&serialized_resource)); + + Tensor* output = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(0, TensorShape({}), &output)); + output->scalar()() = serialized_resource; + } +}; + +REGISTER_KERNEL_BUILDER(Name("GetSerializedResourceOp").Device(DEVICE_GPU), + GetSerializedResourceOp); + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_KERNELS_GET_SERIALIZED_RESOURCE_OP_H_ diff --git a/tensorflow/compiler/tf2tensorrt/kernels/get_serialized_resource_op_test.cc b/tensorflow/compiler/tf2tensorrt/kernels/get_serialized_resource_op_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..ec038ebda073c8050321d5668b15a2c6faa72a4b --- /dev/null +++ b/tensorflow/compiler/tf2tensorrt/kernels/get_serialized_resource_op_test.cc @@ -0,0 +1,80 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include +#include + +#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h" +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/platform/test.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT + +namespace tensorflow { +namespace tensorrt { + +class GetSerializedResourceOpTest : public OpsTestBase {}; + +TEST_F(GetSerializedResourceOpTest, Basic) { + // Create the GPU device. + std::unique_ptr device( + DeviceFactory::NewDevice("GPU", {}, "/job:worker/replica:0/task:0")); + + // Create the resource. + class MySerializableResource : public SerializableResourceBase { + public: + string DebugString() const override { return ""; } + Status SerializeToString(string* serialized) override { + *serialized = "my_serialized_str"; + return Status::OK(); + } + }; + const string container = "mycontainer"; + const string resource_name = "myresource"; + SerializableResourceBase* resource = new MySerializableResource(); + ResourceMgr* rm = device->resource_manager(); + EXPECT_TRUE(rm->Create(container, resource_name, resource).ok()); + + // Create the op. + SetDevice(DEVICE_GPU, std::move(device)); + TF_ASSERT_OK(NodeDefBuilder("op", "GetSerializedResourceOp") + .Input(FakeInput(DT_STRING)) + .Input(FakeInput(DT_STRING)) + .Finalize(node_def())); + TF_ASSERT_OK(InitOp()); + + // Execute the op. + AddInputFromArray(TensorShape({}), {container}); + AddInputFromArray(TensorShape({}), {resource_name}); + TF_ASSERT_OK(RunOpKernel()); + + // Verify the result. + // TODO(laigd): OpsTestBase::GetOutput() doesn't work. 
+ Tensor* output = context_->mutable_output(0); + EXPECT_EQ("my_serialized_str", output->scalar()()); +} + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc similarity index 59% rename from tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc rename to tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc index bad568644bb1f8d01d4cb0a7c853ec47d6f19e45..bc5335ef5aa35633a68e69f7de7903b4f498531a 100644 --- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc +++ b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc @@ -12,35 +12,45 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/tensorrt/kernels/trt_engine_op.h" - #include - -#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h" -#include "tensorflow/contrib/tensorrt/convert/utils.h" -#include "tensorflow/contrib/tensorrt/log/trt_logger.h" -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" -#include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h" -#include "tensorflow/contrib/tensorrt/resources/trt_resources.h" -#include "tensorflow/contrib/tensorrt/test/utils.h" +#include +#include + +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h" +#include "tensorflow/compiler/tf2tensorrt/convert/utils.h" +#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h" +#include "tensorflow/compiler/tf2tensorrt/utils/test_utils.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" +#include 
"tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_resource_manager.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h" +#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph_to_functiondef.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/stream_executor.h" +#include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/platform/types.h" #if GOOGLE_CUDA #if GOOGLE_TENSORRT #include "cuda/include/cuda_runtime_api.h" +#include "tensorrt/include/NvInfer.h" namespace tensorflow { namespace tensorrt { static Logger logger; +using absl::StrAppend; +using absl::StrCat; using ::nvinfer1::IRuntime; -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; // A helper class to call done() when destructed for asynchronous execution. // Helps simultaneous execution of native and TRT engines. @@ -53,6 +63,83 @@ class AsyncHelper : public tensorflow::core::RefCounted { AsyncOpKernel::DoneCallback done_; }; +// This OP can construct TRTEngine on the fly and if construction of engine +// fails, executes equivalent subgraph as a TensorFlow function. 
+class TRTEngineOp : public AsyncOpKernel { + public: + explicit TRTEngineOp(OpKernelConstruction* context); + + void ComputeAsync(OpKernelContext* context, + AsyncOpKernel::DoneCallback done) override; + + private: + // Execute calibration + void ExecuteCalibration(OpKernelContext* ctx, AsyncHelper* helper); + + // Construct a function handle for executing native funcdef graph + Status ConstructFunctionHandle(OpKernelContext* ctx); + + // Execute replaced native segment as function Op. + void ExecuteNativeSegment(OpKernelContext* ctx, AsyncHelper* helper); + + // Execute the tensorrt engine. Returns whether we need to retry by running + // the native segment. + bool ExecuteTrtEngine(OpKernelContext* ctx, EngineContext* engine_context); + + // Allocate necessary resources for calibration + Status AllocateCalibrationResources(OpKernelContext* ctx, + SerializableResourceBase** cr); + + // Get engine for the input shape + EngineContext* GetEngine(const std::vector& input_shapes, + OpKernelContext* ctx); + + // Return engine batch in cached_engne_batch_sizes_ which is closest to input + // batch. + bool GetCompatibleCachedEngine( + const std::vector& actual_input_shapes, + std::vector* engine_input_shapes); + + std::vector input_nodes_; + std::vector output_nodes_; + + // serialized protobuf segment or trt engine depending on static_engine_ flag. + string serialized_segment_; + + // Name of the function for TF native execution of the segment. + string funcdef_name_; + + // GraphDef representation of the segment. + GraphDef segment_graph_; + + // Engine Precision mode. + TrtPrecisionMode precision_mode_; + + // Whether engine is constructed during the conversion or needs to be + // constructed from protobuf segment. + bool static_engine_; + + // Whether to calibrate INT8 engine. 
+ bool calibration_mode_; + + // Batches of the cached engines + std::vector cached_engine_batches_; + + // Maximum number of cached engines + int max_cached_engines_; + + int64 workspace_size_; + mutex engine_mutex_; + FunctionLibraryRuntime::Handle native_func_; + + // The finalized calibrator for inference. + std::unique_ptr calibrator_; + + // If true, create calibration graph for INT8 mode. Otherwise, we are using + // user-provided quantization ranges. + bool use_calibration_; +}; + #define TYPECASE(dt, X, Y) \ case dt: { \ return (void*)X->flat::Type>().data(); \ @@ -123,11 +210,13 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) context->GetAttr("calibration_data", &calibration_data)); OP_REQUIRES_OK(context, context->GetAttr("segment_funcdef_name", &funcdef_name_)); - OP_REQUIRES_OK(context, GetPrecisionMode(precision_string, &precision_mode_)); + OP_REQUIRES_OK(context, + TrtPrecisionModeFromName(precision_string, &precision_mode_)); OP_REQUIRES_OK(context, context->GetAttr("use_calibration", &use_calibration_)); - calibration_mode_ = (use_calibration_ && precision_mode_ == INT8MODE && - calibration_data.size() == 0); + calibration_mode_ = + (use_calibration_ && precision_mode_ == TrtPrecisionMode::INT8 && + calibration_data.size() == 0); if (calibration_data.size()) { calibrator_.reset(new TRTInt8Calibrator(calibration_data)); calibration_data.resize(0); @@ -135,8 +224,6 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) native_func_ = tensorflow::kInvalidHandle; OP_REQUIRES_OK(context, context->GetAttr("max_cached_engines_count", &max_cached_engines_)); - OP_REQUIRES_OK(context, - context->GetAttr("fixed_input_size", &fixed_input_size_)); OP_REQUIRES_OK(context, context->GetAttr("cached_engine_batches", &cached_engine_batches_)); std::sort(cached_engine_batches_.begin(), cached_engine_batches_.end()); @@ -175,11 +262,13 @@ void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx, lib->Run(opts, native_func_, inputs, outputs, [this, 
ctx, outputs, helper](const tensorflow::Status& s) { tensorflow::core::ScopedUnref sc(helper); - VLOG(1) << "Native Segment completed"; if (!s.ok()) { + LOG(ERROR) << "Failed to execute native segment " << this->name() + << ": " << s; ctx->SetStatus(s); return; } + VLOG(1) << "Native Segment completed"; for (size_t t = 0; t < outputs->size(); ++t) { ctx->set_output(t, outputs->at(t)); } @@ -194,18 +283,37 @@ void TRTEngineOp::ExecuteCalibration(OpKernelContext* ctx, VLOG(1) << "Executing TRT calibration: " << name(); helper->Ref(); tensorflow::core::ScopedUnref sc(helper); - // TODO(aaroey): remove the ResourceMgr singleton. - auto trt_rm = TRTResourceManager::instance(); - auto res_mgr = trt_rm->getManager("TRTCalibration"); + auto res_mgr = ctx->resource_manager(); TRTCalibrationResource* calib_res = nullptr; - auto status = res_mgr->LookupOrCreate( - funcdef_name_, "Calibrator", &calib_res, - {[ctx, this](TRTCalibrationResource** cr) -> tensorflow::Status { - return this->AllocateCalibrationResources(ctx, cr); - }}); - if (!status.ok()) { - ctx->SetStatus(status); - return; + OP_REQUIRES_OK( + ctx, + res_mgr->LookupOrCreate( + "TF_TRT_Calibration", name(), + reinterpret_cast(&calib_res), + {[ctx, this](SerializableResourceBase** cr) -> tensorflow::Status { + return this->AllocateCalibrationResources(ctx, cr); + }})); + tensorflow::core::ScopedUnref calib_sc(calib_res); + // TODO(aaroey): here we also add the resource to the ResourceMgr singleton. + // This is needed before we migrate all uses of calib_graph_to_infer_graph() + // to the new calibration workflow. After that we'll remove this block. + { + auto deprecated_rm = + TRTResourceManager::instance()->getManager("TRTCalibration"); + TRTCalibrationResource* copied_resource = nullptr; + // Check whether the resource exists, and create it if not. + if (deprecated_rm->Lookup(funcdef_name_, "Calibrator", &copied_resource) + .ok()) { + // Do nothing if the resource exists. 
+ copied_resource->Unref(); + } else { + copied_resource = calib_res; + // Increase the refcount by 1 then transfer the ownership of that refcount + // to the ResourceMgr singleton. + copied_resource->Ref(); + OP_REQUIRES_OK(ctx, deprecated_rm->Create(funcdef_name_, "Calibrator", + copied_resource)); + } } int num_inputs = ctx->num_inputs(); // Pass input data to calibrator @@ -219,7 +327,8 @@ void TRTEngineOp::ExecuteCalibration(OpKernelContext* ctx, return; } // Check the allocated buffer is sufficient for input - const auto device_tensor = dev_tensors_.at(i).AccessTensor(ctx); + const auto device_tensor = + calib_res->device_tensors_.at(i).AccessTensor(ctx); CHECK_EQ(t.TotalBytes(), device_tensor->TotalBytes()); input_data.emplace(StrCat(kInputPHName, i), data_address); } @@ -236,32 +345,34 @@ void TRTEngineOp::ExecuteCalibration(OpKernelContext* ctx, ExecuteNativeSegment(ctx, helper); } -int TRTEngineOp::GetEngineBatch(OpKernelContext* ctx) { - int num_batch = ctx->input(0).shape().dim_size(0); - int smallest_engine = 0; - for (const auto i : cached_engine_batches_) { - if (i >= num_batch) { - smallest_engine = i; - break; - } - } - // TODO(sami): Need an LRU here - if (smallest_engine == 0) { - if (max_cached_engines_ > cached_engine_batches_.size()) { - smallest_engine = num_batch; - cached_engine_batches_.push_back(num_batch); - VLOG(1) << "Running with batch size " << num_batch; - } else { - string msg = - StrCat("Engine buffer is full. buffer limit=", max_cached_engines_, - ", current entries="); - for (auto i : cached_engine_batches_) StrAppend(&msg, i, ","); - StrAppend(&msg, " requested batch=", num_batch); - LOG(WARNING) << msg; - return -1; +bool TRTEngineOp::GetCompatibleCachedEngine( + const std::vector& actual_input_shapes, + std::vector* engine_input_shapes) { + const int batch_size = actual_input_shapes[0].dim_size(0); + int smallest_batch_size = -1; + // Output shape will always be the same as the input but we will overwrite the + // batch size. 
+ *engine_input_shapes = actual_input_shapes; + for (const int cached_batch_size : cached_engine_batches_) { + // Check if compatible: batch <= cached batch. + // + // TODO(laigd): here it only compare the first dim a.k.a the batch size, + // we'll need to to support non-batch dimensions as well. This will be done + // as part of the offline conversion implementation. + if (batch_size <= cached_batch_size) { + // First case: first compatible engine found + // Second case: smaller batch size engine found + if ((smallest_batch_size == -1) || + (cached_batch_size < smallest_batch_size)) { + smallest_batch_size = cached_batch_size; + // Overwrite batch size for output + for (int i = 0; i < engine_input_shapes->size(); i++) { + (*engine_input_shapes)[i].set_dim(0, smallest_batch_size); + } + } } } - return smallest_engine; + return (smallest_batch_size != -1); } void TRTEngineOp::ComputeAsync(OpKernelContext* ctx, @@ -272,25 +383,20 @@ void TRTEngineOp::ComputeAsync(OpKernelContext* ctx, ExecuteCalibration(ctx, helper); return; } - const int smallest_engine = GetEngineBatch(ctx); - if (smallest_engine < 0) { - LOG(WARNING) << "Failed to get engine batch, running native segment for " - << name(); - ExecuteNativeSegment(ctx, helper); - return; + // Get shapes of inputs to engine. + std::vector input_shapes; + for (int i = 0; i < ctx->num_inputs(); ++i) { + input_shapes.emplace_back(ctx->input(i).shape()); } - - const int num_batch = ctx->input(0).shape().dim_size(0); - auto& engine_ctx_pair = GetEngine(smallest_engine, ctx); - auto& trt_engine_ptr = engine_ctx_pair.first; - if (!trt_engine_ptr) { - LOG(WARNING) << "Engine retrieval for batch size " << num_batch + EngineContext* engine_context = GetEngine(input_shapes, ctx); + if (!engine_context->cuda_engine) { + LOG(WARNING) << "Engine retrieval for input shapes: " + << TensorShapeUtils::ShapeListString(input_shapes) << " failed. 
Running native segment for " << name(); ExecuteNativeSegment(ctx, helper); return; } - const bool retry = ExecuteTrtEngine(ctx, num_batch, trt_engine_ptr.get(), - engine_ctx_pair.second.get()); + const bool retry = ExecuteTrtEngine(ctx, engine_context); if (retry) { LOG(WARNING) << "Failed to execute engine, " << "retrying with native segment for " << name(); @@ -299,18 +405,19 @@ void TRTEngineOp::ComputeAsync(OpKernelContext* ctx, } } -bool TRTEngineOp::ExecuteTrtEngine( - OpKernelContext* ctx, const int num_batch, - nvinfer1::ICudaEngine* trt_engine_ptr, - nvinfer1::IExecutionContext* trt_execution_context_ptr) { +bool TRTEngineOp::ExecuteTrtEngine(OpKernelContext* ctx, + EngineContext* engine_context) { VLOG(1) << "Executing TRT engine: " << name(); + auto& cuda_engine = engine_context->cuda_engine; const bool kRetry = true; + // All inputs must have the same batch size, so just get it from the first + // input. + const int num_batch = ctx->input(0).shape().dim_size(0); const int num_binding = ctx->num_inputs() + ctx->num_outputs(); std::vector buffers(num_binding); for (int i = 0; i < ctx->num_inputs(); i++) { const string input_name = StrCat(kInputPHName, i); - const int binding_index = - trt_engine_ptr->getBindingIndex(input_name.c_str()); + const int binding_index = cuda_engine->getBindingIndex(input_name.c_str()); if (binding_index == -1) { LOG(ERROR) << "Input node not found, at " << input_name; return kRetry; @@ -323,7 +430,7 @@ bool TRTEngineOp::ExecuteTrtEngine( << " vs " << input_shape.dim_size(0); return kRetry; } - auto dtype = trt_engine_ptr->getBindingDataType(binding_index); + auto dtype = cuda_engine->getBindingDataType(binding_index); switch (dtype) { case nvinfer1::DataType::kFLOAT: buffers[binding_index] = (void*)(input_tensor.flat().data()); @@ -346,13 +453,12 @@ bool TRTEngineOp::ExecuteTrtEngine( for (int i = 0; i < ctx->num_outputs(); i++) { // Create an output tensor const string output_name = StrCat(kOutputPHName, i); - const int 
binding_index = - trt_engine_ptr->getBindingIndex(output_name.c_str()); + const int binding_index = cuda_engine->getBindingIndex(output_name.c_str()); Tensor* output_tensor = nullptr; TensorShape output_shape; if (binding_index != -1) { - auto dims = trt_engine_ptr->getBindingDimensions(binding_index); + auto dims = cuda_engine->getBindingDimensions(binding_index); std::vector trt_shape(dims.nbDims + 1); trt_shape[0] = num_batch; for (int j = 0; j < dims.nbDims; j++) trt_shape[j + 1] = dims.d[j]; @@ -374,7 +480,7 @@ bool TRTEngineOp::ExecuteTrtEngine( // TODO(aaroey): ideally we should retry, fix this. return !kRetry; } - auto dtype = trt_engine_ptr->getBindingDataType(binding_index); + auto dtype = cuda_engine->getBindingDataType(binding_index); switch (dtype) { case nvinfer1::DataType::kFLOAT: buffers[binding_index] = @@ -402,9 +508,12 @@ bool TRTEngineOp::ExecuteTrtEngine( ->implementation() ->GpuStreamMemberHack())); + // nvinfer1::IExecutionContext::enqueue is not thread safe and we need a mutex + // for it. + tensorflow::mutex_lock lock(engine_context->mu); // TODO(jie): trt enqueue does not return error - auto ret = trt_execution_context_ptr->enqueue(num_batch, &buffers[0], *stream, - nullptr); + auto ret = engine_context->execution_context->enqueue(num_batch, &buffers[0], + *stream, nullptr); if (!ret) { LOG(WARNING) << "Failed to enqueue batch for TRT engine: " << name(); return kRetry; @@ -414,50 +523,45 @@ bool TRTEngineOp::ExecuteTrtEngine( return !kRetry; } -TRTEngineOp::~TRTEngineOp() { - // We need to manually destroy the engine and execution context before - // the allocator is destructed. - for (auto& eng : engine_map_) { - eng.second.first.reset(); - eng.second.second.reset(); +EngineContext* TRTEngineOp::GetEngine( + const std::vector& input_shapes, OpKernelContext* ctx) { + static EngineContext empty_context; + tensorflow::mutex_lock lock(engine_mutex_); + // TODO(tmorris): using first input to get batch size - is this reliable? 
+ const int batch_size = input_shapes[0].dim_size(0); + + // Get engine cache + TRTEngineCacheResource* cache_res = nullptr; + auto status = ctx->resource_manager()->LookupOrCreate( + "TRTEngineCache", funcdef_name_, &cache_res, + {[this, ctx](TRTEngineCacheResource** cr) -> tensorflow::Status { + *cr = new TRTEngineCacheResource(ctx, this->max_cached_engines_); + return Status::OK(); + }}); + if (!status.ok()) { + ctx->SetStatus(status); + return &empty_context; } - allocator_.reset(); -} - -nvinfer1::IGpuAllocator* TRTEngineOp::GetAllocator(OpKernelContext* ctx) { - if (allocator_) return allocator_.get(); - auto device = ctx->device(); - auto alloc = device->GetAllocator(tensorflow::AllocatorAttributes()); - if (!alloc) { - LOG(ERROR) << "Can't find device allocator for gpu device " - << device->name(); - return nullptr; + tensorflow::core::ScopedUnref sc(cache_res); + auto& cache = cache_res->cache_; + auto allocator = cache_res->allocator_.get(); + if (allocator == nullptr) { + return &empty_context; } - allocator_.reset(new TRTDeviceAllocator(alloc)); - return allocator_.get(); -} - -TRTEngineOp::EngineCtxPair& TRTEngineOp::GetEngine(int batch_size, - OpKernelContext* ctx) { - static EngineCtxPair null_pair = { - TrtUniquePtrType(nullptr), - TrtUniquePtrType(nullptr)}; - // TODO(sami): This method needs to be re-written to use resource manager and - // with LRU mechanism option. - tensorflow::mutex_lock lock(engine_mutex_); + // Handle the static engine case. For static engines, the cache will have a + // single element containing the only engine. if (static_engine_) { - if (engine_map_.size()) { - if (engine_map_.begin()->first >= batch_size) { - return engine_map_.begin()->second; + if (cache.size()) { + // Batch size of engine must be >= the input batch size + // TODO(tmorris): use match compatible function? 
+ if (cache.begin()->first[0].dim_size(0) >= batch_size) { + return cache.begin()->second.get(); } - return null_pair; + return &empty_context; } + TrtUniquePtrType infer(nvinfer1::createInferRuntime(logger)); - auto allocator = GetAllocator(ctx); - if (allocator == nullptr) { - return null_pair; - } infer->setGpuAllocator(allocator); TrtUniquePtrType static_engine( infer->deserializeCudaEngine(serialized_segment_.c_str(), @@ -465,62 +569,87 @@ TRTEngineOp::EngineCtxPair& TRTEngineOp::GetEngine(int batch_size, PluginFactoryTensorRT::GetInstance())); auto raw_static_engine = static_engine.get(); const auto max_batch_size = raw_static_engine->getMaxBatchSize(); - engine_map_[max_batch_size] = { - std::move(static_engine), - TrtUniquePtrType( - raw_static_engine->createExecutionContext())}; + // Static engine will have max_batch_size for batch size so that all inputs + // will map to this single engine. + std::vector engine_input_shapes(input_shapes); + for (int i = 0; i < engine_input_shapes.size(); i++) { + // TODO(tmorris): will all inputs have batch size as first dimension?? + engine_input_shapes[i].set_dim(0, max_batch_size); + } + // TODO(laigd): here we assume engine_input_shapes matches the actual input + // shapes of the engine, we should verify that. + cache.emplace(engine_input_shapes, + absl::make_unique( + std::move(static_engine), + TrtUniquePtrType( + raw_static_engine->createExecutionContext()))); // Runtime is safe to delete after engine creation serialized_segment_.clear(); if (max_batch_size < batch_size) { - return null_pair; + return &empty_context; } - return engine_map_.at(max_batch_size); + return cache.at(engine_input_shapes).get(); } // static_engine_ // Handle the dynamic engine case. 
- auto engine_it = engine_map_.find(batch_size); - if (engine_it == engine_map_.end() && - engine_map_.size() < (size_t)max_cached_engines_) { - nvinfer1::IGpuAllocator* allocator = nullptr; - allocator = GetAllocator(ctx); - if (allocator == nullptr) { - return null_pair; - } - std::vector shapes; - for (int i = 0; i < ctx->num_inputs(); ++i) { - shapes.emplace_back(ctx->input(i).shape()); + // See if there is a compatible engine cached. The batch size should be <= the + // cached batch size. + std::vector engine_input_shapes; + const bool matched_successfully = + GetCompatibleCachedEngine(input_shapes, &engine_input_shapes); + // If matched, use that engine. Otherwise, we will look in cache for that + // exact shape and possibly create a new engine if it is not in cache. + if (!matched_successfully) { + engine_input_shapes = input_shapes; + if (!cached_engine_batches_.empty()) { + // If user has explicitly defined cached_engine_batches, we should + // warn them that their input was non-compatible (batch size too high) + LOG(WARNING) << "No compatible cached engine was found for batch size: " + << batch_size << ". A new engine will be created."; + cached_engine_batches_.push_back(batch_size); } + } + + if (!cache.count(engine_input_shapes)) { TrtUniquePtrType engine; bool convert_successfully = false; LOG(INFO) << "Building a new TensorRT engine for " << name() - << " with batch size " << batch_size; + << " input shapes: " + << TensorShapeUtils::ShapeListString(engine_input_shapes); + // Convert to partial shapes + std::vector partial_shapes; + for (int i = 0; i < engine_input_shapes.size(); i++) { + partial_shapes.emplace_back(engine_input_shapes[i]); + } // Up to this point, calibrator_ can never be empty, since otherwise it // means calibration_mode_ is true and this path won't get executed. 
auto status = convert::ConvertGraphDefToEngine( - segment_graph_, precision_mode_, batch_size, workspace_size_, shapes, - &logger, allocator, calibrator_.get(), &engine, use_calibration_, - &convert_successfully); + segment_graph_, precision_mode_, batch_size, workspace_size_, + partial_shapes, &logger, allocator, calibrator_.get(), &engine, + use_calibration_, &convert_successfully); if (!status.ok()) { if (convert_successfully) { // This means it fail to build the engine even when the network is built // successfully, probably due to internal issues. In this case we don't // retry in the future. - engine_map_[batch_size] = {nullptr, nullptr}; + cache.emplace(engine_input_shapes, absl::make_unique()); } LOG(WARNING) << "Engine creation for batch size " << batch_size << " failed " << status; - return null_pair; + return &empty_context; } VLOG(1) << "Conversion is done"; TrtUniquePtrType exec_context( engine->createExecutionContext()); - engine_map_[batch_size] = {std::move(engine), std::move(exec_context)}; + cache.emplace(engine_input_shapes, + absl::make_unique(std::move(engine), + std::move(exec_context))); } - return engine_map_.at(batch_size); + return cache.at(engine_input_shapes).get(); } tensorflow::Status TRTEngineOp::AllocateCalibrationResources( - OpKernelContext* ctx, TRTCalibrationResource** cr) { + OpKernelContext* ctx, SerializableResourceBase** cr) { auto cres = new TRTCalibrationResource(); *cr = cres; // Get the allocator. 
@@ -536,7 +665,7 @@ tensorflow::Status TRTEngineOp::AllocateCalibrationResources( const int batch_size = ctx->input(0).dim_size(0); const int num_inputs = ctx->num_inputs(); std::vector shapes; - dev_tensors_.resize(num_inputs); + cres->device_tensors_.resize(num_inputs); VLOG(1) << " Constructing calibrator"; for (int i = 0; i < num_inputs; i++) { // allocate workspace on device for inputs @@ -544,19 +673,19 @@ tensorflow::Status TRTEngineOp::AllocateCalibrationResources( shapes.emplace_back(t.shape()); Tensor* device_tensor; TF_RETURN_IF_ERROR(ctx->allocate_persistent( - t.dtype(), t.shape(), &dev_tensors_.at(i), &device_tensor)); + t.dtype(), t.shape(), &cres->device_tensors_.at(i), &device_tensor)); CHECK_EQ(t.TotalBytes(), device_tensor->TotalBytes()); void* device_address = GetTensorAddress(device_tensor); if (device_address == nullptr) { return tensorflow::errors::InvalidArgument( "Unsupported data type encountered in input ", i); } - device_buffers_.emplace( + cres->device_buffers_.emplace( StrCat(kInputPHName, i), std::pair(device_address, device_tensor->TotalBytes())); } cres->calibrator_.reset( - new TRTInt8Calibrator(device_buffers_, batch_size, name())); + new TRTInt8Calibrator(cres->device_buffers_, batch_size, name())); const string label(name()); auto segment_graph = &segment_graph_; const int platform_gpu_id = @@ -585,9 +714,10 @@ tensorflow::Status TRTEngineOp::AllocateCalibrationResources( // TODO(aaroey): maybe setting the max batch size using the python // calibration wrapper class. 
auto s = convert::ConvertGraphDefToEngine( - *segment_graph, INT8MODE, cres->calibrator_->getBatchSize(), - workspace_size_bytes, shapes, &cres->logger_, cres->allocator_.get(), - cres->calibrator_.get(), &cres->engine_, + *segment_graph, TrtPrecisionMode::INT8, + cres->calibrator_->getBatchSize(), workspace_size_bytes, shapes, + &cres->logger_, cres->allocator_.get(), cres->calibrator_.get(), + &cres->engine_, /*use_calibration=*/true, /*convert_successfully=*/nullptr); if (!s.ok()) { diff --git a/tensorflow/compiler/tf2tensorrt/ops/get_serialized_resource_op.cc b/tensorflow/compiler/tf2tensorrt/ops/get_serialized_resource_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..59da73f5efc8eedc20c35cf35cb1eae6cda136c9 --- /dev/null +++ b/tensorflow/compiler/tf2tensorrt/ops/get_serialized_resource_op.cc @@ -0,0 +1,40 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/tensor_shape.h" + +namespace tensorflow { + +REGISTER_OP("GetSerializedResourceOp") + .Input("container: string") + .Input("resource_name: string") + .Output("serialized_resource: string") + .SetShapeFn(shape_inference::ScalarShape) + .SetIsStateful() + .Doc(R"doc( +Gets a resource from a container managed by the resource manager and returns +its serialized representation. +)doc"); + +} // namespace tensorflow + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA diff --git a/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc b/tensorflow/compiler/tf2tensorrt/ops/trt_engine_op.cc similarity index 80% rename from tensorflow/contrib/tensorrt/ops/trt_engine_op.cc rename to tensorflow/compiler/tf2tensorrt/ops/trt_engine_op.cc index 92405906eb76b043bc08b68e25e16ab40197dddf..b84d2fe0b8cef3475f2a7d0f5383d5e11cde099a 100644 --- a/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc +++ b/tensorflow/compiler/tf2tensorrt/ops/trt_engine_op.cc @@ -28,16 +28,22 @@ namespace shape_inference { extern Status TRTEngineOpShapeInference(InferenceContext* c); } +// NOTE: please try NOT to add/modify/remove attributes or inputs/outputs to the +// list below, this will break backward compatibility! +// +// TODO(laigd): consider making this op stateful. The only problem is it uses TF +// function which has to be stateless, but we can use function library as the +// key to cache the instantiated functions for different executor subgraphs. 
REGISTER_OP("TRTEngineOp") .Attr("serialized_segment: string") .Attr("input_shapes: list(shape)") .Attr("output_shapes: list(shape)") .Attr("segment_funcdef_name: string") - .Attr("InT: list({int8,float16,float32})") - .Attr("OutT: list({int8,float16,float32})") + .Attr("InT: list({int8,float16,float32,int32})") + .Attr("OutT: list({int8,float16,float32,int32})") .Attr("static_engine: bool = true") .Attr("fixed_input_size: bool = true") - .Attr("cached_engine_batches: list(int) = []") + .Attr("cached_engine_batches: list(int) >= 0 = []") .Attr("max_cached_engines_count: int = 1") .Attr("workspace_size_bytes: int") .Attr("precision_mode: {'FP32', 'FP16', 'INT8'}") diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin.cc b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.cc similarity index 96% rename from tensorflow/contrib/tensorrt/plugin/trt_plugin.cc rename to tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.cc index 062f86e8bb4dc753925e4e2baf0bc80a5312a94f..a4341c530fffca88c82813cc2ace2c0ae1df5345 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin.cc +++ b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.cc @@ -13,10 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" +#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.h" + #include #include -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h" + +#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_utils.h" #if GOOGLE_CUDA #if GOOGLE_TENSORRT diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin.h b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.h similarity index 92% rename from tensorflow/contrib/tensorrt/plugin/trt_plugin.h rename to tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.h index 754920b60ca7439513a91ad0354833a2482b29c1..f495d857037c79a1783f8eb232fb57c20e229169 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin.h +++ b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_H_ -#define TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_H_ +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_PLUGIN_TRT_PLUGIN_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_PLUGIN_TRT_PLUGIN_H_ #include #include @@ -71,4 +71,4 @@ class PluginTensorRT : public nvinfer1::IPlugin { #endif // GOOGLE_TENSORRT #endif // GOOGLE_CUDA -#endif // TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_H_ +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_PLUGIN_TRT_PLUGIN_H_ diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.cc similarity index 96% rename from tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc rename to tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.cc index cccc91226265ed139fb8db0b71c40b868f729562..871fb1210bd495dc3f5e8153bb6c3a361bf569f5 100644 --- 
a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc +++ b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" +#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h" #if GOOGLE_CUDA #if GOOGLE_TENSORRT diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h similarity index 91% rename from tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h rename to tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h index bbae9fb65c22cf69d2e7954436fd04dd16f7f6c8..9aa99a40b80de92a4d9b9ad36e88e693b8aa42dc 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h +++ b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h @@ -13,14 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_FACTORY_H_ -#define TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_FACTORY_H_ +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_PLUGIN_TRT_PLUGIN_FACTORY_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_PLUGIN_TRT_PLUGIN_FACTORY_H_ #include #include -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h" +#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.h" +#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_utils.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mutex.h" @@ -99,4 +99,4 @@ class TrtPluginRegistrar { #endif // GOOGLE_TENSORRT #endif // GOOGLE_CUDA -#endif // TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_FACTORY_H_ +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_PLUGIN_TRT_PLUGIN_FACTORY_H_ diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory_test.cc b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory_test.cc similarity index 96% rename from tensorflow/contrib/tensorrt/plugin/trt_plugin_factory_test.cc rename to tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory_test.cc index 129bdcdbc2f8d9d5215f45f381bcadf35e4fa75e..7d9c465c22beed0e252cbc26d6c533a0789d4f49 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory_test.cc +++ b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory_test.cc @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" +#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h" -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" +#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_utils.cc similarity index 94% rename from tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc rename to tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_utils.cc index a8f60886c03c174a612e7a135b6eb7bb7cb9997a..f3d6b4ff476139693a5251ddf58a3200d8af8efc 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc +++ b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_utils.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h" +#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_utils.h" #include #if GOOGLE_CUDA diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_utils.h similarity index 82% rename from tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h rename to tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_utils.h index 274ce42fec9283c643004d45fba461879fc5f2dc..e5eff15c19694093c7a5ea933a41375e8e01c8b9 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h +++ b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_utils.h @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_UTILS_H_ -#define TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_UTILS_H_ +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_PLUGIN_TRT_PLUGIN_UTILS_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_PLUGIN_TRT_PLUGIN_UTILS_H_ #include -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" +#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.h" #include "tensorflow/core/platform/types.h" #if GOOGLE_CUDA @@ -43,4 +43,4 @@ string ExtractOpName(const void* serial_data, size_t serial_length, #endif // GOOGLE_TENSORRT #endif // GOOGLE_CUDA -#endif // TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_UTILS_H_ +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_PLUGIN_TRT_PLUGIN_UTILS_H_ diff --git a/tensorflow/contrib/tensorrt/python/ops/trt_engine_op.py b/tensorflow/compiler/tf2tensorrt/python/ops/trt_ops.py similarity index 84% rename from tensorflow/contrib/tensorrt/python/ops/trt_engine_op.py rename to tensorflow/compiler/tf2tensorrt/python/ops/trt_ops.py index 31a313182be9a2fca7457a539670dbc911ccabb1..86bfabf99e08a8e447a28504c72eebca4d3a582c 100644 --- a/tensorflow/contrib/tensorrt/python/ops/trt_engine_op.py +++ b/tensorflow/compiler/tf2tensorrt/python/ops/trt_ops.py @@ -22,13 +22,13 @@ import platform if platform.system() != "Windows": # pylint: disable=wildcard-import,unused-import,g-import-not-at-top - from tensorflow.contrib.tensorrt.ops.gen_trt_engine_op import * + from tensorflow.compiler.tf2tensorrt.ops.gen_trt_ops import * - from tensorflow.contrib.util import loader + from tensorflow.python.framework import load_library from tensorflow.python.platform import resource_loader # pylint: enable=wildcard-import,unused-import,g-import-not-at-top - _trt_engine_op = loader.load_op_library( - resource_loader.get_path_to_datafile("_trt_engine_op.so")) + _trt_ops = load_library.load_op_library( + resource_loader.get_path_to_datafile("_trt_ops.so")) 
else: raise RuntimeError("Windows platforms are not supported") diff --git a/tensorflow/contrib/tensorrt/segment/segment.cc b/tensorflow/compiler/tf2tensorrt/segment/segment.cc similarity index 93% rename from tensorflow/contrib/tensorrt/segment/segment.cc rename to tensorflow/compiler/tf2tensorrt/segment/segment.cc index 6abc5226ccf96e472df77269bee6186726e5768d..4a8a4ac7589a4b68b129e8e88ee999e8a2495728 100644 --- a/tensorflow/contrib/tensorrt/segment/segment.cc +++ b/tensorflow/compiler/tf2tensorrt/segment/segment.cc @@ -13,14 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/tensorrt/segment/segment.h" +#include "tensorflow/compiler/tf2tensorrt/segment/segment.h" #include #include #include #include -#include "tensorflow/contrib/tensorrt/segment/union_find.h" +#include "absl/strings/str_cat.h" +#include "tensorflow/compiler/tf2tensorrt/segment/union_find.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_constructor.h" @@ -32,8 +33,8 @@ limitations under the License. namespace tensorflow { namespace tensorrt { namespace segment { -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; +using absl::StrAppend; +using absl::StrCat; // A simple graph representation to mirror tensorflow::Graph. This structure // helps saving memory since segmenter modifies the graph in place, preventing @@ -225,6 +226,24 @@ SimpleGraph::~SimpleGraph() { for (auto x : edges_) delete x; } +// Define comparison functions for std::set with pointer keys so that behavior +// is deterministic. When using std::set with pointer key types, the items are +// sorted by pointer address which is non-deterministic. 
This can cause issues +// for INT8 mode because the graph is converted twice and non-determinism may +// cause a mismatch between the calibration tables of the conversions. +struct SimpleEdgePtrCompare { + bool operator()(const SimpleEdge* lhs, const SimpleEdge* rhs) const { + return lhs->id() < rhs->id(); + } +}; + +struct NodePtrCompare { + bool operator()(const tensorflow::Node* lhs, + const tensorflow::Node* rhs) const { + return lhs->name() < rhs->name(); + } +}; + namespace { // Copied from TF ReverseDFS, which only works for tensorflow::Graph. @@ -476,7 +495,7 @@ tensorflow::Status SegmentGraph( // nodes. Iterate since combining two nodes may unblock other // combining. while (true) { - std::set contract_edges; + std::set contract_edges; for (const SimpleEdge* out_edge : node->out_edges()) { VLOG(3) << "... out node " << out_edge->dst()->name() << " ( " << out_edge->dst()->id() << " <- " << node->id() << " )"; @@ -530,7 +549,7 @@ tensorflow::Status SegmentGraph( // A map from the segment identifier (currently the name of the root node of // the segment tree) to the segment nodes set. - std::map> sg_map; + std::map> sg_map; // A map from the segment identifier (currently the name of the root node of // the segment tree) to the device names that the nodes in the segment are @@ -566,7 +585,8 @@ tensorflow::Status SegmentGraph( // --------------------------------- Step 2 --------------------------------- // Remove ineligible input/output nodes. for (auto& itr : sg_map) { - std::set& segment_nodes = itr.second; + std::set& segment_nodes = + itr.second; VLOG(1) << "Segment original size: " << segment_nodes.size(); while (true) { std::deque in_nodes_que, out_nodes_que; @@ -618,8 +638,9 @@ tensorflow::Status SegmentGraph( bool is_input_nodes, std::deque* que) { // Run a BFS on the queue to find all the input/output nodes. 
- std::set visited; - std::set logged(que->begin(), que->end()); + std::set visited; + std::set logged(que->begin(), + que->end()); while (!que->empty()) { auto node = que->front(); que->pop_front(); @@ -653,9 +674,11 @@ tensorflow::Status SegmentGraph( // --------------------------------- Step 3 --------------------------------- // Convert the segments into the expected return format for (const auto& itr : sg_map) { - const std::set& segment_nodes = itr.second; + const string& segment_root = itr.first; + // Return format does not require set comparator. + std::set segment_nodes(itr.second.begin(), itr.second.end()); if (VLOG_IS_ON(1)) { - string s = "parent=" + itr.first + ":"; + string s = "parent=" + segment_root + ":"; for (auto node : segment_nodes) s += " " + node->name(); VLOG(1) << "Segment " << segments->size() << ": " << s; } @@ -668,12 +691,10 @@ tensorflow::Status SegmentGraph( } // TODO(sami): Make segmenter placement aware once trtscopes are in place - std::set segment_node_names; - for (auto node : itr.second) segment_node_names.insert(node->name()); - const auto& dev_itr = device_maps.find(itr.first); + const auto& dev_itr = device_maps.find(segment_root); if (dev_itr == device_maps.end() || dev_itr->second.empty()) { VLOG(1) << "No device assigned to segment " << segments->size(); - segments->emplace_back(std::make_pair(segment_node_names, string())); + segments->emplace_back(std::make_pair(segment_nodes, string())); } else if (dev_itr->second.size() > 1) { string s("Segment "); StrAppend(&s, segments->size(), " has multiple devices attached: "); @@ -682,10 +703,10 @@ tensorflow::Status SegmentGraph( } LOG(WARNING) << s << " choosing " << *(dev_itr->second.begin()); segments->emplace_back( - std::make_pair(segment_node_names, *(dev_itr->second.begin()))); + std::make_pair(segment_nodes, *(dev_itr->second.begin()))); } else { segments->emplace_back( - std::make_pair(segment_node_names, *(dev_itr->second.begin()))); + std::make_pair(segment_nodes, 
*(dev_itr->second.begin()))); } } if (VLOG_IS_ON(1)) { diff --git a/tensorflow/contrib/tensorrt/segment/segment.h b/tensorflow/compiler/tf2tensorrt/segment/segment.h similarity index 83% rename from tensorflow/contrib/tensorrt/segment/segment.h rename to tensorflow/compiler/tf2tensorrt/segment/segment.h index b9693aad1b764515459db6833b05221ea5b3a2d1..9a0ccc9aef475edfb0ffb83a2be21d4d4ca0e028 100644 --- a/tensorflow/contrib/tensorrt/segment/segment.h +++ b/tensorflow/compiler/tf2tensorrt/segment/segment.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_SEGMENT_H_ -#define TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_SEGMENT_H_ +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_SEGMENT_SEGMENT_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_SEGMENT_SEGMENT_H_ #include #include @@ -29,10 +29,10 @@ namespace tensorflow { namespace tensorrt { namespace segment { -// Vector of segments, each entry contains a set of node names and a device name -// in the segment. -// TODO(aaroey): use node pointer instead of node name. -using SegmentNodesVector = std::vector, string>>; +// Vector of segments, each entry contains a set of node pointers and a device +// name in the segment. +using SegmentNodesVector = + std::vector, string>>; struct SegmentOptions { // Segment must contain at least this many nodes. 
@@ -60,4 +60,4 @@ tensorflow::Status SegmentGraph( } // namespace tensorrt } // namespace tensorflow -#endif // TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_SEGMENT_H_ +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_SEGMENT_SEGMENT_H_ diff --git a/tensorflow/contrib/tensorrt/segment/segment_test.cc b/tensorflow/compiler/tf2tensorrt/segment/segment_test.cc similarity index 97% rename from tensorflow/contrib/tensorrt/segment/segment_test.cc rename to tensorflow/compiler/tf2tensorrt/segment/segment_test.cc index 4805ef9c61a7784a1c08cf5eaf504691bc9dbedc..58512d3b09d7c6f523710bc09843c628a5838b53 100644 --- a/tensorflow/contrib/tensorrt/segment/segment_test.cc +++ b/tensorflow/compiler/tf2tensorrt/segment/segment_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/tensorrt/segment/segment.h" +#include "tensorflow/compiler/tf2tensorrt/segment/segment.h" #include "tensorflow/cc/framework/scope.h" #include "tensorflow/cc/ops/standard_ops.h" @@ -75,7 +75,10 @@ class SegmentTest : public ::testing::Test { const std::vector>& expected_segments) { EXPECT_EQ(expected_segments.size(), segments.size()); for (int i = 0; i < segments.size(); ++i) { - const auto& segment_node_names = segments[i].first; + std::set segment_node_names; + for (const Node* node : segments[i].first) { + segment_node_names.insert(node->name()); + } const auto& expected = expected_segments[i]; for (const auto& name : expected) { EXPECT_TRUE(segment_node_names.count(name)) diff --git a/tensorflow/contrib/tensorrt/segment/union_find.h b/tensorflow/compiler/tf2tensorrt/segment/union_find.h similarity index 92% rename from tensorflow/contrib/tensorrt/segment/union_find.h rename to tensorflow/compiler/tf2tensorrt/segment/union_find.h index 1c64ebbb0ae532a4776ab8963515d19fd3b23b4c..6458ae692fd7c922b5fc3bea2e55b613447dbde0 100644 --- 
a/tensorflow/contrib/tensorrt/segment/union_find.h +++ b/tensorflow/compiler/tf2tensorrt/segment/union_find.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_UNION_FIND_H_ -#define TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_UNION_FIND_H_ +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_SEGMENT_UNION_FIND_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_SEGMENT_UNION_FIND_H_ namespace tensorflow { namespace tensorrt { @@ -76,4 +76,4 @@ UnionFind* UnionFind::FindRoot() { } // namespace tensorrt } // namespace tensorflow -#endif // TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_UNION_FIND_H_ +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_SEGMENT_UNION_FIND_H_ diff --git a/tensorflow/contrib/tensorrt/tensorrt_test.cc b/tensorflow/compiler/tf2tensorrt/tensorrt_test.cc similarity index 100% rename from tensorflow/contrib/tensorrt/tensorrt_test.cc rename to tensorflow/compiler/tf2tensorrt/tensorrt_test.cc diff --git a/tensorflow/contrib/tensorrt/test/utils.cc b/tensorflow/compiler/tf2tensorrt/utils/test_utils.cc similarity index 97% rename from tensorflow/contrib/tensorrt/test/utils.cc rename to tensorflow/compiler/tf2tensorrt/utils/test_utils.cc index 276308b3a0a6ce864969afb0179c6a3f00d6b70b..3bcca99afbff8b84d2dd628ae9211ee94e86af2a 100644 --- a/tensorflow/contrib/tensorrt/test/utils.cc +++ b/tensorflow/compiler/tf2tensorrt/utils/test_utils.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/contrib/tensorrt/test/utils.h" +#include "tensorflow/compiler/tf2tensorrt/utils/test_utils.h" #include #include diff --git a/tensorflow/contrib/tensorrt/test/utils.h b/tensorflow/compiler/tf2tensorrt/utils/test_utils.h similarity index 89% rename from tensorflow/contrib/tensorrt/test/utils.h rename to tensorflow/compiler/tf2tensorrt/utils/test_utils.h index 4bb4120206cfaae70107e55d1818e3af2f02717a..bcd628b62f0320f7ce9dfe6240316d876f1d5a20 100644 --- a/tensorflow/contrib/tensorrt/test/utils.h +++ b/tensorflow/compiler/tf2tensorrt/utils/test_utils.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_TENSORRT_TEST_UTILS_H_ -#define TENSORFLOW_CONTRIB_TENSORRT_TEST_UTILS_H_ +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TEST_UTILS_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TEST_UTILS_H_ #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" @@ -41,4 +41,4 @@ string GetTestValue(const string& label); } // namespace tensorrt } // namespace tensorflow -#endif // TENSORFLOW_CONTRIB_TENSORRT_TEST_UTILS_H_ +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TEST_UTILS_H_ diff --git a/tensorflow/contrib/tensorrt/resources/trt_allocator.cc b/tensorflow/compiler/tf2tensorrt/utils/trt_allocator.cc similarity index 98% rename from tensorflow/contrib/tensorrt/resources/trt_allocator.cc rename to tensorflow/compiler/tf2tensorrt/utils/trt_allocator.cc index 7a2e93414aed56525eaeac876cdac20404bcf6ab..1636cdc30c4df157ed124b160449af645f917252 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_allocator.cc +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_allocator.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the 
License. ==============================================================================*/ -#include "tensorflow/contrib/tensorrt/resources/trt_allocator.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h" #include "tensorflow/core/platform/logging.h" diff --git a/tensorflow/contrib/tensorrt/resources/trt_allocator.h b/tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h similarity index 93% rename from tensorflow/contrib/tensorrt/resources/trt_allocator.h rename to tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h index f857a9de055ee7668f0bf9bc97e030354505081b..59ffb42bad348c78cde32035aff8c7081528b3a6 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_allocator.h +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_ALLOCATOR_H_ -#define TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_ALLOCATOR_H_ +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_ALLOCATOR_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_ALLOCATOR_H_ #include @@ -81,4 +81,4 @@ class TRTDeviceAllocator : public TRTBaseAllocator { #endif // GOOGLE_TENSORRT #endif // GOOGLE_CUDA -#endif // TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_ALLOCATOR_H_ +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_ALLOCATOR_H_ diff --git a/tensorflow/contrib/tensorrt/resources/trt_allocator_test.cc b/tensorflow/compiler/tf2tensorrt/utils/trt_allocator_test.cc similarity index 80% rename from tensorflow/contrib/tensorrt/resources/trt_allocator_test.cc rename to tensorflow/compiler/tf2tensorrt/utils/trt_allocator_test.cc index ad6b1d7d4c57d696d3dee3b479733e152e669211..e457c64928e5df84c7e2726ba3621420f013dbc9 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_allocator_test.cc +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_allocator_test.cc 
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/tensorrt/resources/trt_allocator.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h" #include "tensorflow/core/platform/test.h" @@ -48,11 +48,14 @@ TEST(TRTAllocatorTest, Align) { 513ul, 700ul, 12345ul, 1ul << 32}) { for (uint64_t alignment = 1; alignment <= space * 4; alignment *= 2) { for (const uintptr_t ptr_val : - {1ul, alignment == 1 ? 1ul : alignment - 1, alignment, alignment + 1, - alignment + (alignment / 2)}) { + {static_cast(1), + alignment == 1 ? static_cast(1) : alignment - 1, + alignment, alignment + 1, alignment + (alignment / 2)}) { if (ptr_val % alignment == 0) { for (const uint64_t size : - {1ul, space == 1 ? 1ul : space - 1, space, space + 1}) { + {static_cast(1), + space == 1 ? static_cast(1) : space - 1, space, + space + 1}) { EXPECT_EQ(space >= size, RunTest(alignment, size, ptr_val, space)); } } else { @@ -62,8 +65,10 @@ TEST(TRTAllocatorTest, Align) { EXPECT_TRUE( RunTest(alignment, space - diff, ptr_val + diff, space - diff)); for (const uint64_t size : - {1ul, space - diff > 1 ? space - diff - 1 : 1ul, space - diff, - space - diff + 1, space - 1}) { + {static_cast(1), + space - diff > 1 ? 
space - diff - 1 + : static_cast(1), + space - diff, space - diff + 1, space - 1}) { EXPECT_EQ(space - diff >= size, RunTest(alignment, size, ptr_val, space)); } diff --git a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc b/tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.cc similarity index 98% rename from tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc rename to tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.cc index dab1dd9343be7d5b033a3e04bf0b49fbbf37e9e5..bf111d3a2ee2fbec9151d12bbb6ff7181761c2aa 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h" #include #include diff --git a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h b/tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h similarity index 93% rename from tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h rename to tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h index 65466c9741989fda5f82fc27d813d026f35fe386..10587e99624acfb97730bbbd9dfbcde020ffc669 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_INT8_CALIBRATOR_H_ -#define TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_INT8_CALIBRATOR_H_ +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_INT8_CALIBRATOR_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_INT8_CALIBRATOR_H_ #include #include @@ -96,4 +96,4 @@ struct TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator { #endif #endif -#endif // TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_INT8_CALIBRATOR_H_ +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_INT8_CALIBRATOR_H_ diff --git a/tensorflow/contrib/tensorrt/log/trt_logger.cc b/tensorflow/compiler/tf2tensorrt/utils/trt_logger.cc similarity index 91% rename from tensorflow/contrib/tensorrt/log/trt_logger.cc rename to tensorflow/compiler/tf2tensorrt/utils/trt_logger.cc index dda0dc9e712eb726800abfb6084f4f708d04825b..f454f55f2cb4ee65b97891ae8dd58d809d36f099 100644 --- a/tensorflow/contrib/tensorrt/log/trt_logger.cc +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_logger.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/tensorrt/log/trt_logger.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" #if GOOGLE_CUDA #if GOOGLE_TENSORRT @@ -26,6 +26,9 @@ namespace tensorrt { void Logger::log(Severity severity, const char* msg) { // Suppress info-level messages switch (severity) { +#if NV_TENSORRT_MAJOR > 5 || (NV_TENSORRT_MAJOR == 5 && NV_TENSORRT_MINOR >= 1) + case Severity::kVERBOSE: +#endif case Severity::kINFO: { // Mark TRT info messages as debug! 
VLOG(2) << name_ << " " << msg; break; diff --git a/tensorflow/contrib/tensorrt/log/trt_logger.h b/tensorflow/compiler/tf2tensorrt/utils/trt_logger.h similarity index 86% rename from tensorflow/contrib/tensorrt/log/trt_logger.h rename to tensorflow/compiler/tf2tensorrt/utils/trt_logger.h index 96ccacb791e40143c5c4d9d691bb353702f9a28b..22f4de970a80765b0e1e7e8816134d83aaec7c73 100644 --- a/tensorflow/contrib/tensorrt/log/trt_logger.h +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_logger.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_TENSORRT_LOG_TRT_LOGGER_H_ -#define TENSORFLOW_CONTRIB_TENSORRT_LOG_TRT_LOGGER_H_ +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_LOGGER_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_LOGGER_H_ #include "tensorflow/core/platform/types.h" @@ -41,4 +41,4 @@ class Logger : public nvinfer1::ILogger { #endif // GOOGLE_TENSORRT #endif // GOOGLE_CUDA -#endif // TENSORFLOW_CONTRIB_TENSORRT_LOG_TRT_LOGGER_H_ +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_LOGGER_H_ diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h b/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h new file mode 100644 index 0000000000000000000000000000000000000000..09c47b36b0ad8074e749342e7d08f139da7ea1f4 --- /dev/null +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h @@ -0,0 +1,192 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_LRU_CACHE_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_LRU_CACHE_H_ + +#include +#include + +#include "tensorflow/compiler/tf2tensorrt/convert/utils.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/lib/core/errors.h" + +#if GOOGLE_CUDA && GOOGLE_TENSORRT +#include "tensorrt/include/NvInfer.h" +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT + +namespace tensorflow { +namespace tensorrt { + +template +class LRUCache { + public: + typedef Value value_type; + typedef Key key_type; + typedef HashFunction hasher; + typedef typename std::unordered_map map_type; + typedef typename map_type::iterator iterator; + typedef typename map_type::const_iterator const_iterator; + + LRUCache() : capacity_(0) {} + explicit LRUCache(size_t capacity) : capacity_(capacity) {} + + size_t capacity() const { return capacity_; } + + void reserve(size_t capacity) { + capacity_ = capacity; + DiscardOld(); + } + + size_t size() const { return objects_.size(); } + + size_t count(const key_type& key) const { return objects_.count(key); } + + value_type& at(const key_type& key) { return Touch(key); } + + const_iterator begin() const { return objects_.begin(); } + const_iterator end() const { return objects_.end(); } + + iterator begin() { return objects_.begin(); } + iterator end() { return objects_.end(); } + + template + std::pair emplace(Args&&... 
args) { + DiscardOld(1); + std::pair result = + objects_.emplace(std::forward(args)...); + key_type key = result.first->first; + if (result.second) { + keys_.push_front(key); + } else { + TouchNoCheck(key); // The key must exist in this case. + } + return result; + } + + private: + std::unordered_map objects_; + std::list keys_; + size_t capacity_; + value_type not_found_value_; + + value_type& Touch(const key_type& key) { + // Check that the key exists, and let it return std::out_of_range error if + // not. + value_type& value = objects_.at(key); + TouchNoCheck(key); + return value; + } + + void TouchNoCheck(const key_type& key) { + auto rank = std::find(keys_.begin(), keys_.end(), key); + if (rank != keys_.begin()) { + keys_.erase(rank); + keys_.push_front(key); + } + } + + // Creates n free positions in cache + tensorflow::Status DiscardOld(size_t n = 0) { + if (n > capacity_) { + return tensorflow::errors::Internal( + "Insufficient capacity in cache (capacity = ", capacity_, + ", requested ", n, ")"); + } + while (objects_.size() > (capacity_ - n)) { + key_type discard_key = keys_.back(); + keys_.pop_back(); + objects_.erase(discard_key); + } + return tensorflow::Status::OK(); + } +}; + +// Define a hash function for vector because it is used as the key +// for the engine cache. +struct VectorTensorShapeHasher { + std::size_t operator()( + const std::vector& key) const { + return std::hash()(TensorShapeUtils::ShapeListString(key)); + } +}; + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT + +struct EngineContext { + EngineContext() {} // Creates an empty context. 
+ EngineContext( + TrtUniquePtrType&& input_cuda_engine, + TrtUniquePtrType&& input_execution_context) + : cuda_engine(std::move(input_cuda_engine)), + execution_context(std::move(input_execution_context)) {} + + mutex mu; + TrtUniquePtrType cuda_engine; + TrtUniquePtrType execution_context + GUARDED_BY(mu); +}; + +class TRTEngineCacheResource : public tensorflow::ResourceBase { + public: + TRTEngineCacheResource(OpKernelContext* ctx, size_t capacity) + : cache_(capacity) { + auto device = ctx->device(); + auto alloc = device->GetAllocator(tensorflow::AllocatorAttributes()); + if (!alloc) { + LOG(ERROR) << "Can't find device allocator for gpu device " + << device->name(); + allocator_ = nullptr; + } else { + allocator_.reset(new TRTDeviceAllocator(alloc)); + } + } + + string DebugString() const override { + std::stringstream oss; + using std::dec; + using std::endl; + using std::hex; + oss << "TRTEngineCacheResource: "; + oss << "TRTBaseAllocator = " << hex << allocator_.get() << dec << ", "; + oss << "LRUCache = " << hex << &cache_ << dec << endl; + oss << "Containing " << cache_.size() << " entries: " << endl; + for (const auto& item : cache_) { + oss << TensorShapeUtils::ShapeListString(item.first) << ": " << hex + << "ICudaEngine: " << item.second.get()->cuda_engine.get() << ", " + << "IExecutionContext: " << item.second.get()->execution_context.get() + << dec << endl; + } + return oss.str(); + } + + // Keep device allocator for TRT. + std::unique_ptr allocator_; + + // Declare cache after allocator so that it is destroyed before allocator is. 
+ LRUCache, std::unique_ptr, + VectorTensorShapeHasher> + cache_; +}; + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA + +} // namespace tensorrt +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_LRU_CACHE_H_ diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache_test.cc b/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..0aa5eb8f7d4ad062c2d8622fa5aa55f823f80dd5 --- /dev/null +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache_test.cc @@ -0,0 +1,57 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h" + +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace tensorrt { + +TEST(LRUCacheTest, Basic) { + LRUCache> cache; + cache.reserve(2); + // Insert 10 + cache.emplace(10, 100); + EXPECT_EQ(cache.size(), 1); + EXPECT_EQ(cache.count(10), 1); + EXPECT_EQ(cache.at(10), 100); + EXPECT_EQ(cache.count(100), 0); + // Insert 20 + cache.emplace(20, 200); + EXPECT_EQ(cache.size(), 2); + EXPECT_EQ(cache.count(10), 1); + EXPECT_EQ(cache.count(20), 1); + EXPECT_EQ(cache.at(10), 100); + EXPECT_EQ(cache.at(20), 200); + EXPECT_EQ(cache.count(100), 0); + EXPECT_EQ(cache.count(200), 0); + // Insert 30, Evicting 10 + cache.emplace(30, 300); + EXPECT_EQ(cache.count(10), 0); + EXPECT_EQ(cache.count(20), 1); + EXPECT_EQ(cache.count(30), 1); + // Touch 20 + cache.at(20); + // Insert 40, Evicting 30 + cache.emplace(40, 400); + EXPECT_EQ(cache.count(10), 0); + EXPECT_EQ(cache.count(20), 1); + EXPECT_EQ(cache.count(30), 0); + EXPECT_EQ(cache.count(40), 1); +} + +} // namespace tensorrt +} // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/resources/trt_resource_manager.cc b/tensorflow/compiler/tf2tensorrt/utils/trt_resource_manager.cc similarity index 96% rename from tensorflow/contrib/tensorrt/resources/trt_resource_manager.cc rename to tensorflow/compiler/tf2tensorrt/utils/trt_resource_manager.cc index 9c3698e5d1cc5d6d8d31a8fcaf03d103f1e1915d..0a72a88bc740101bcbadb40bfe106a5b8d284bbf 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_resource_manager.cc +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_resource_manager.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_resource_manager.h" #include "tensorflow/core/platform/logging.h" namespace tensorflow { diff --git a/tensorflow/contrib/tensorrt/resources/trt_resource_manager.h b/tensorflow/compiler/tf2tensorrt/utils/trt_resource_manager.h similarity index 87% rename from tensorflow/contrib/tensorrt/resources/trt_resource_manager.h rename to tensorflow/compiler/tf2tensorrt/utils/trt_resource_manager.h index 19f39e6d3db1571573fb290dd2c30fd43ea604ef..03879ffff2fa724b05cb1919753e4aaa99e2e702 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_resource_manager.h +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_resource_manager.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_RESOURCE_MANAGER_H_ -#define TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_RESOURCE_MANAGER_H_ +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_RESOURCE_MANAGER_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_RESOURCE_MANAGER_H_ #include #include @@ -42,4 +42,4 @@ class TRTResourceManager { } // namespace tensorrt } // namespace tensorflow -#endif // TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_RESOURCE_MANAGER_H_ +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_RESOURCE_MANAGER_H_ diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_resources.cc b/tensorflow/compiler/tf2tensorrt/utils/trt_resources.cc new file mode 100644 index 0000000000000000000000000000000000000000..37f7fe99fbb2b9e121953fc0de211db1bbf34b7a --- /dev/null +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_resources.cc @@ -0,0 +1,61 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT + +namespace tensorflow { +namespace tensorrt { + +TRTCalibrationResource::~TRTCalibrationResource() { + VLOG(0) << "Destroying Calibration Resource " << std::endl << DebugString(); + builder_.reset(); + engine_.reset(); + // We need to manually destroy the builder and engine before the allocator + // is destroyed. 
+ allocator_.reset(); +} + +string TRTCalibrationResource::DebugString() const { + std::stringstream oss; + using std::dec; + using std::endl; + using std::hex; + oss << " Calibrator = " << hex << calibrator_.get() << dec << endl + << " Builder = " << hex << builder_.get() << dec << endl + << " Engine = " << hex << engine_.get() << dec << endl + << " Logger = " << hex << &logger_ << dec << endl + << " Allocator = " << hex << allocator_.get() << dec << endl + << " Thread = " << hex << thr_.get() << dec << endl; + return oss.str(); +} + +Status TRTCalibrationResource::SerializeToString(string* serialized) { + calibrator_->waitAndSetDone(); + thr_->join(); + *serialized = calibrator_->getCalibrationTableAsString(); + if (!serialized->size()) { + return tensorflow::errors::Unknown("Calibration table is empty."); + } + return Status::OK(); +} + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_resources.h b/tensorflow/compiler/tf2tensorrt/utils/trt_resources.h new file mode 100644 index 0000000000000000000000000000000000000000..5e8d4b3b738df09b0c2ea82dcc06e9b23a708385 --- /dev/null +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_resources.h @@ -0,0 +1,73 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_RESOURCES_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_RESOURCES_H_ + +#include +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/tf2tensorrt/convert/utils.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/resource_mgr.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT +#include "tensorrt/include/NvInfer.h" + +namespace tensorflow { +namespace tensorrt { + +class SerializableResourceBase : public tensorflow::ResourceBase { + public: + virtual Status SerializeToString(string* serialized) = 0; +}; + +class TRTCalibrationResource : public SerializableResourceBase { + public: + ~TRTCalibrationResource() override; + + string DebugString() const override; + + Status SerializeToString(string* serialized) override; + + // Lookup table for temporary staging areas of input tensors for calibration. + std::unordered_map> device_buffers_; + + // Temporary staging areas for calibration inputs. + std::vector device_tensors_; + + std::unique_ptr calibrator_; + TrtUniquePtrType builder_; + TrtUniquePtrType engine_; + std::unique_ptr allocator_; + tensorflow::tensorrt::Logger logger_; + // TODO(sami): Use threadpool threads! 
+ std::unique_ptr thr_; +}; + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_RESOURCES_H_ diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 25a84fb1b6609106213231db1ca1ce54da8bd960..02de95141da1a28e59d3155742217efdf163e8dd 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -1,6 +1,6 @@ licenses(["notice"]) # Apache 2.0 -load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test") +load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test", "tf_cuda_cc_test") package_group( name = "internal", @@ -204,6 +204,7 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", @@ -224,6 +225,7 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", + "@com_google_absl//absl/types:variant", ], alwayslink = 1, ) @@ -244,6 +246,7 @@ cc_library( deps = [ "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", @@ -314,11 +317,13 @@ tf_cc_test( ":tf2xla_util", "//tensorflow/cc:cc_ops", "//tensorflow/cc:function_ops", + "//tensorflow/cc:functional_ops", "//tensorflow/cc:ops", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:math_ops_op_lib", + "//tensorflow/core:ops", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", @@ -445,14 +450,9 @@ cc_library( ], deps = [ "//tensorflow/compiler/jit:flags", - 
"//tensorflow/compiler/xla:parse_flags_from_env", - "//tensorflow/core:core_cpu", - "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", + "//tensorflow/core:graph", "//tensorflow/core:protos_all_cc", - "@com_google_absl//absl/strings", ], ) @@ -673,8 +673,31 @@ cc_library( name = "side_effect_util", srcs = ["side_effect_util.cc"], hdrs = ["side_effect_util.h"], + visibility = [":friends"], deps = [ "//tensorflow/core:core_cpu", "@com_google_absl//absl/strings", ], ) + +tf_cuda_cc_test( + name = "fused_batchnorm_reserve_space_test", + size = "medium", + srcs = ["fused_batchnorm_reserve_space_test.cc"], + deps = [ + "//tensorflow/cc:cc_ops", + "//tensorflow/compiler/jit", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:tensorflow", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/kernels:ops_testutil", + "//tensorflow/core/kernels:ops_util", + "@com_google_absl//absl/algorithm:container", + ], +) diff --git a/tensorflow/compiler/tf2xla/cpu_function_runtime.h b/tensorflow/compiler/tf2xla/cpu_function_runtime.h index dfc1e8b8aebcf3142e9f61f60171c6b58634c71d..78970fb39bae7067c7668baa2aec65732b5b2352 100644 --- a/tensorflow/compiler/tf2xla/cpu_function_runtime.h +++ b/tensorflow/compiler/tf2xla/cpu_function_runtime.h @@ -104,7 +104,7 @@ class BufferInfo { private: BufferInfo() = default; - enum class Kind : unsigned { + enum class Kind : uint64 { kConstant, kTempBuffer, kEntryParameter, diff --git a/tensorflow/compiler/tf2xla/dump_graph.cc b/tensorflow/compiler/tf2xla/dump_graph.cc index 1de85004a51bea464f8f0166511402e5dd85ac14..64fdbbebc65bff4ed0b965fcdd534cc9696472b6 100644 --- a/tensorflow/compiler/tf2xla/dump_graph.cc +++ b/tensorflow/compiler/tf2xla/dump_graph.cc @@ -18,86 
+18,26 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/dump_graph.h" -#include "absl/strings/str_cat.h" #include "tensorflow/compiler/jit/flags.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/util/dump_graph.h" namespace tensorflow { namespace dump_graph { -namespace { - -struct NameCounts { - mutex counts_mutex; - std::unordered_map counts; -}; - -string MakeUniqueFilename(string name) { - static NameCounts& instance = *new NameCounts; - - // Remove illegal characters from `name`. - for (int i = 0; i < name.size(); ++i) { - char ch = name[i]; - if (ch == '/' || ch == '[' || ch == ']' || ch == '*' || ch == '?') { - name[i] = '_'; - } - } - - int count; - { - mutex_lock lock(instance.counts_mutex); - count = instance.counts[name]++; - } - - string filename = name; - if (count > 0) { - absl::StrAppend(&filename, "_", count); - } - absl::StrAppend(&filename, ".pbtxt"); - return filename; -} - -string WriteTextProtoToUniqueFile( - Env* env, const string& name, const char* proto_type, - const ::tensorflow::protobuf::Message& proto) { - const string& dirname = GetDumpGraphFlags()->tf_dump_graph_prefix; - Status status = env->RecursivelyCreateDir(dirname); - if (!status.ok()) { - LOG(WARNING) << "Failed to create " << dirname << " for dumping " - << proto_type << ": " << status; - return "(unavailable)"; - } - string filepath = absl::StrCat(dirname, "/", MakeUniqueFilename(name)); - status = WriteTextProto(Env::Default(), filepath, proto); - if (!status.ok()) { - LOG(WARNING) << "Failed to dump " << proto_type << " to file: " << filepath - << " : " << status; - return "(unavailable)"; - } - LOG(INFO) << "Dumped " << proto_type << " to " << filepath; - return filepath; -} - -} // anonymous namespace - string DumpGraphDefToFile(const string& name, GraphDef const& graph_def) { - return WriteTextProtoToUniqueFile(Env::Default(), name, "GraphDef", - graph_def); + return 
tensorflow::DumpGraphDefToFile( + name, graph_def, GetDumpGraphFlags()->tf_dump_graph_prefix); } string DumpGraphToFile(const string& name, Graph const& graph, const FunctionLibraryDefinition* flib_def) { - GraphDef graph_def; - graph.ToGraphDef(&graph_def); - if (flib_def) { - *graph_def.mutable_library() = flib_def->ToProto(); - } - return DumpGraphDefToFile(name, graph_def); + return tensorflow::DumpGraphToFile(name, graph, flib_def, + GetDumpGraphFlags()->tf_dump_graph_prefix); } string DumpFunctionDefToFile(const string& name, FunctionDef const& fdef) { - return WriteTextProtoToUniqueFile(Env::Default(), name, "FunctionDef", fdef); + return tensorflow::DumpFunctionDefToFile( + name, fdef, GetDumpGraphFlags()->tf_dump_graph_prefix); } } // namespace dump_graph diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.cc b/tensorflow/compiler/tf2xla/functionalize_cond.cc index c693e42d26712d55852f45c806215fc1f1b9a030..c8341a2c6bb66e43fb00cb660726cf5a1979c992 100644 --- a/tensorflow/compiler/tf2xla/functionalize_cond.cc +++ b/tensorflow/compiler/tf2xla/functionalize_cond.cc @@ -34,6 +34,8 @@ limitations under the License. 
#include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/control_flow.h" #include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/strings/strcat.h" using xla::StatusOr; @@ -41,6 +43,26 @@ using xla::StatusOr; namespace tensorflow { namespace functionalize_cond { +bool AncestorNode::operator<(const AncestorNode& other) const { + return (output_tensor.node->id() < other.output_tensor.node->id()) || + (output_tensor.node->id() == other.output_tensor.node->id() && + output_tensor.index < other.output_tensor.index) || + (output_tensor.node->id() == other.output_tensor.node->id() && + output_tensor.index == other.output_tensor.index && + type < other.type); +} + +bool AncestorNode::operator==(const AncestorNode& other) const { + return output_tensor.node->id() == other.output_tensor.node->id() && + output_tensor.index == other.output_tensor.index && type == other.type; +} + +size_t AncestorNode::Hash::operator()(const AncestorNode& ancestor) const { + size_t h = std::hash()(ancestor.output_tensor.node->id()); + h = Hash64Combine(h, std::hash()(ancestor.output_tensor.index)); + return Hash64Combine(h, std::hash()(static_cast(ancestor.type))); +} + // TODO(jpienaar): Move to OutputTensor. string DebugString(const OutputTensor& tensor) { return absl::StrCat(tensor.node->name(), ":", tensor.index); @@ -145,10 +167,10 @@ size_t StateMap::Hash::operator()(const StateMap::AncestorState& map) const { if (map.empty()) return 0; // Compute hash of the front element. auto it = map.begin(); - size_t h = hash()(*it); + size_t h = AncestorNode::Hash()(*it); for (++it; it != map.end(); ++it) { // Combine the has with the different elements in the map. 
- h = Hash64Combine(h, hash()(*it)); + h = Hash64Combine(h, AncestorNode::Hash()(*it)); } return h; } @@ -229,7 +251,17 @@ string StateMap::CondStateToString(StateMap::CondId id) const { } string StateMap::AncestorStateToString(const Node* node) const { - if (auto id = LookupAncestorId(node)) return NodesToString(*id); + if (auto id = LookupAncestorId(node)) { + return absl::StrCat( + "{", + absl::StrJoin(*id, ",", + [](string* output, const AncestorNode& ancestor) { + absl::StrAppend(output, + ancestor.output_tensor.node->name(), + ":", ancestor.output_tensor.index); + }), + "}"); + } return "{}"; } @@ -247,7 +279,9 @@ class Conditional { Status AddMerge(Node* m); // Constructs an If node from the merge nodes. - Status BuildAndReplace(Graph* graph, FunctionLibraryDefinition* library); + Status BuildAndReplace( + Graph* graph, FunctionLibraryDefinition* library, + std::unordered_map* merge_to_replacement); private: // Extracts the then/else bodies: creates new graphs with the nodes @@ -262,10 +296,15 @@ class Conditional { Status BuildIfNode(Graph* graph, FunctionLibraryDefinition* library); // Adds input edges to If node. - Status AddInputEdges(Graph* graph); + Status AddInputEdges( + Graph* graph, + const std::unordered_map& merge_to_replacement); // Adds output edges from If node. - Status AddOutputEdges(Graph* graph); + // Record new output tensor for all Merge nodes in 'merge_to_replacement'. + Status AddOutputEdges( + Graph* graph, + std::unordered_map* merge_to_replacement); // Adds switch node that is part of this conditional. 
Status AddSwitch(Node* s); @@ -640,7 +679,8 @@ Status Conditional::ExtractBodies(Graph* graph) { Status Conditional::BuildIfNode(Graph* graph, FunctionLibraryDefinition* library) { VLOG(2) << "Build cond function for " << name(); - NodeDefBuilder builder(name(), "If", library); + NodeDebugInfo debug_info((*merges_.begin())->def()); + NodeDefBuilder builder(name(), "If", library, &debug_info); const string branch_name[] = {"else_branch", "then_branch"}; for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) { int branch_index = static_cast(branch); @@ -719,12 +759,29 @@ Status Conditional::BuildIfNode(Graph* graph, return Status::OK(); } -Status Conditional::AddInputEdges(Graph* graph) { +Status Conditional::AddInputEdges( + Graph* graph, + const std::unordered_map& merge_to_replacement) { VLOG(2) << "AddInputEdges for " << if_node_->name(); int index = 0; // Add predicate input. - graph->AddEdge(const_cast(predicate_.node), predicate_.index, if_node_, - index++); + if (predicate_.node->IsMerge()) { + // If the predicate is a Merge node, we should not use Merge output as + // predicate. Instead, we should use the corresponding If output in + // 'merge_to_replacement'. Otherwise, this Conditional's If node is still + // connected to the predicate Merge node; and when we call + // DeleteReachableAndDeadNodes(), the predicate Merge node and this + // Conditional's If node will be removed. + auto iter = merge_to_replacement.find(predicate_.node); + if (iter == merge_to_replacement.end()) { + return errors::Internal("Cannot find replacement for Merge node ", + predicate_.node->name()); + } + graph->AddEdge(iter->second.node, iter->second.index, if_node_, index++); + } else { + graph->AddEdge(const_cast(predicate_.node), predicate_.index, + if_node_, index++); + } // Add function body inputs. 
for (auto& arg : cond_arg_nodes_) { if (arg.src_output == Graph::kControlSlot) { @@ -739,7 +796,9 @@ Status Conditional::AddInputEdges(Graph* graph) { return Status::OK(); } -Status Conditional::AddOutputEdges(Graph* graph) { +Status Conditional::AddOutputEdges( + Graph* graph, + std::unordered_map* merge_to_replacement) { VLOG(2) << "AddOutputEdges for " << if_node_->name(); int i = 0; for (Node* node : merges_) { @@ -763,6 +822,10 @@ Status Conditional::AddOutputEdges(Graph* graph) { graph->AddEdge(if_node_, i, dst, dst_input); } } + + // Record corresponding output tensor in 'merge_to_replacement'. + (*merge_to_replacement)[node] = OutputTensor{if_node_, i}; + ++i; } for (Node* n : external_control_outputs_) { @@ -772,8 +835,9 @@ Status Conditional::AddOutputEdges(Graph* graph) { return Status::OK(); } -Status Conditional::BuildAndReplace(Graph* graph, - FunctionLibraryDefinition* library) { +Status Conditional::BuildAndReplace( + Graph* graph, FunctionLibraryDefinition* library, + std::unordered_map* merge_to_replacement) { VLOG(1) << "Build If and replace merge nodes " << NodesToString(this->merges_); if (replaced_) return Status::OK(); @@ -792,8 +856,8 @@ Status Conditional::BuildAndReplace(Graph* graph, } TF_RETURN_IF_ERROR(BuildIfNode(graph, library)); - TF_RETURN_IF_ERROR(AddInputEdges(graph)); - TF_RETURN_IF_ERROR(AddOutputEdges(graph)); + TF_RETURN_IF_ERROR(AddInputEdges(graph, *merge_to_replacement)); + TF_RETURN_IF_ERROR(AddOutputEdges(graph, merge_to_replacement)); TF_RETURN_IF_ERROR(parent_->PropagateUpdatedState(if_node_)); // Check that the if_node doesn't feed into itself. 
@@ -935,6 +999,10 @@ StatusOr FunctionalizeCond::JoinCondStatesMerge( VLOG(4) << "Joining (for merge) " << DebugString(src) << " and " << DebugString(dst); if (state_map_.IsEmpty(dst)) return src; + if (state_map_.IsEmpty(src)) { + return errors::Internal("Merge node ", merge->name(), + " has input that's not in any CondContext."); + } if (state_map_.IsDead(src)) return src; if (state_map_.IsDead(dst)) return dst; @@ -1169,8 +1237,17 @@ Status FunctionalizeCond::DetermineAncestorState(Node* dst) { if (other_id != id && other_id != nullptr) { state.insert(other_id->begin(), other_id->end()); } - if (IsSwitch(src) || IsMerge(src)) { - state.insert(src); + if (IsMerge(src)) { + state.insert({{src, 0}, AncestorNode::AncestorNodeType::kMerge}); + } else if (IsSwitch(src)) { + OutputTensor pred; + // For dead switch nodes, GetSwitchPredicate() will fail, and we use + // the switch node directly as ancestor. + if (GetSwitchPredicate(*src, &pred).ok()) { + state.insert({pred, AncestorNode::AncestorNodeType::kPred}); + } else { + state.insert({{src, 0}, AncestorNode::AncestorNodeType::kSwitch}); + } } return state_map_.GetAncestorId(state); }; @@ -1344,7 +1421,8 @@ Status FunctionalizeCond::FunctionalizeInternal() { Conditional cond(merge_to_predicate_.at(cluster.front()), this, &state_map_); for (Node* merge : cluster) TF_RETURN_IF_ERROR(cond.AddMerge(merge)); - TF_RETURN_IF_ERROR(cond.BuildAndReplace(graph_, library_)); + TF_RETURN_IF_ERROR( + cond.BuildAndReplace(graph_, library_, &merge_to_replacement_)); if (VLOG_IS_ON(4)) DumpGraphWithCondState("after_extract"); } diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.h b/tensorflow/compiler/tf2xla/functionalize_cond.h index 8525d7af61b4471e53a9ae16b081060bfd234c9c..d85800fb8ee65a354716bf6601c6bc40eca9a10d 100644 --- a/tensorflow/compiler/tf2xla/functionalize_cond.h +++ b/tensorflow/compiler/tf2xla/functionalize_cond.h @@ -43,6 +43,33 @@ enum class BranchType { kNeither = 3, }; +// When we keep track of which 
switch/merge node's feed into a node, we record +// 1) predicate for non-dead switch node, +// 2) the switch node itself for dead switch node, +// 3) the merge node itself for merge node. +// Case 1) is an optimization. With this optimization, if there are nodes from +// different switch nodes but those switch nodes have the same predicate, the +// nodes will still have same AncestorState, and they will be clustered into a +// single "If". +struct AncestorNode { + enum class AncestorNodeType { + kPred = 0, + kSwitch = 1, + kMerge = 2, + }; + + OutputTensor output_tensor; + AncestorNodeType type; + + // Compare two AncestorNodes by (node id, index, type). + bool operator<(const AncestorNode& other) const; + bool operator==(const AncestorNode& other) const; + + struct Hash { + size_t operator()(const AncestorNode&) const; + }; +}; + // StateMap is responsible for mapping from each graph Node to // * a CondState, where each CondState is a map from predicate to branch (i,e., // what predicates have to hold or not hold). @@ -68,7 +95,7 @@ class StateMap { using CondId = const CondState*; // Keep track of which switch/merge node's feed into a node's values. - using AncestorState = std::set; + using AncestorState = std::set; // Every unique ID is mapped to a AncestorState. using AncestorId = const AncestorState*; @@ -232,6 +259,9 @@ class FunctionalizeCond { // Mapping from merge nodes to predicate. std::unordered_map merge_to_predicate_; + // Mapping from merge nodes to corresponding If node outputs. 
+ std::unordered_map merge_to_replacement_; + FunctionLibraryDefinition* library_; Graph* graph_; diff --git a/tensorflow/compiler/tf2xla/functionalize_cond_test.cc b/tensorflow/compiler/tf2xla/functionalize_cond_test.cc index b0aabd63bbda784b3b7103a438ce025eea0cd93b..05fa1ee92dc172bd11cec9f99e3884996e00791f 100644 --- a/tensorflow/compiler/tf2xla/functionalize_cond_test.cc +++ b/tensorflow/compiler/tf2xla/functionalize_cond_test.cc @@ -101,6 +101,17 @@ TEST_F(FunctionalizeCondTest, JoinCondStates) { TF_EXPECT_OK(t.status()); } +TEST_F(FunctionalizeCondTest, JoinCondStatesMergeWithInputNotInCondContext) { + Tensor val_tensor(DT_INT32, TensorShape()); + val_tensor.flat().setZero(); + Node* val = test::graph::Constant(graph_.get(), val_tensor, "val"); + Node* m = test::graph::Merge(graph_.get(), val, val); + + StateMap::CondState cond_state; + auto joined_or = JoinCondStatesMerge(m, /*src=*/nullptr, &cond_state); + EXPECT_FALSE(joined_or.ok()); +} + } // namespace } // namespace functionalize_cond } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/fused_batchnorm_reserve_space_test.cc b/tensorflow/compiler/tf2xla/fused_batchnorm_reserve_space_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..4535ece374ceb801e450af98a21d5a4c5e8f2a29 --- /dev/null +++ b/tensorflow/compiler/tf2xla/fused_batchnorm_reserve_space_test.cc @@ -0,0 +1,130 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include + +#include "absl/algorithm/container.h" +#include "tensorflow/cc/ops/array_ops.h" +#include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/cc/ops/nn_ops.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/public/session.h" + +namespace tensorflow { +namespace { +Status GetTestDevice(Session* session, string* test_device) { + std::vector devices; + TF_RETURN_IF_ERROR(session->ListDevices(&devices)); + + bool found_cpu = absl::c_any_of(devices, [&](const DeviceAttributes& device) { + return device.device_type() == "CPU"; + }); + + bool found_gpu = absl::c_any_of(devices, [&](const DeviceAttributes& device) { + return device.device_type() == "GPU"; + }); + + if (!found_gpu && !found_cpu) { + return errors::Internal("Expected at least one CPU or GPU!"); + } + + *test_device = found_gpu ? "GPU" : "CPU"; + VLOG(2) << "Using test device " << *test_device; + return Status::OK(); +} + +void FillZeros(Tensor* tensor) { + auto flat = tensor->flat(); + for (int i = 0; i < flat.size(); i++) { + flat.data()[i] = 0.0f; + } +} + +// This test checks that the implementation outputs from FusedBatchnorm +// training, reserve_space_{1|2}, are what we assume them to be in the TF/XLA +// lowering. +// +// If this test starts failing then it doesn't indicate that TF/cudnn have +// violated their contract, but it indicates that we need to update the TF/XLA +// lowering for FusedBatchnorm training to match the new implementation-defined +// behavior. 
+TEST(FusedBatchnormReserveSpaceTest, Test) { + using ::tensorflow::ops::Const; + using ::tensorflow::ops::FusedBatchNorm; + + std::unique_ptr session( + tensorflow::NewSession(tensorflow::SessionOptions{})); + + string test_device; + TF_ASSERT_OK(GetTestDevice(session.get(), &test_device)); + + Scope root = tensorflow::Scope::NewRootScope(); + Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT); + + Tensor scale_data(DT_FLOAT, TensorShape({10})); + FillZeros(&scale_data); + Output scale = + Const(root.WithOpName("scale"), Input::Initializer(scale_data)); + + Tensor offset_data(DT_FLOAT, TensorShape({10})); + FillZeros(&offset_data); + Output offset = + Const(root.WithOpName("offset"), Input::Initializer(offset_data)); + + Tensor mean_data(DT_FLOAT, TensorShape({0})); + Output mean = Const(root.WithOpName("mean"), Input::Initializer(mean_data)); + + Tensor variance_data(DT_FLOAT, TensorShape({0})); + Output variance = + Const(root.WithOpName("variance"), Input::Initializer(variance_data)); + + string tf_device = absl::StrCat("/device:", test_device, ":0"); + string xla_device = absl::StrCat("/device:XLA_", test_device, ":0"); + + FusedBatchNorm fused_batch_norm_tf( + root.WithOpName("fused_batch_norm_tf").WithDevice(tf_device), input, + scale, offset, mean, variance, FusedBatchNorm::Attrs{}.IsTraining(true)); + FusedBatchNorm fused_batch_norm_xla( + root.WithOpName("fused_batch_norm_xla").WithDevice(xla_device), input, + scale, offset, mean, variance, FusedBatchNorm::Attrs{}.IsTraining(true)); + + tensorflow::GraphDef graph; + TF_ASSERT_OK(root.ToGraphDef(&graph)); + + TF_ASSERT_OK(session->Create(graph)); + + Tensor input_data(DT_FLOAT, TensorShape({10, 10, 10, 10})); + auto flat_input = input_data.flat(); + for (int i = 0; i < flat_input.size(); i++) { + flat_input.data()[i] = (i - 5) / 1000.0f; + } + + std::vector results; + TF_ASSERT_OK(session->Run({{"input", input_data}}, + {fused_batch_norm_tf.reserve_space_1.name(), + 
fused_batch_norm_xla.reserve_space_1.name(), + fused_batch_norm_tf.reserve_space_2.name(), + fused_batch_norm_xla.reserve_space_2.name()}, + {}, &results)); + + test::ExpectClose(results[0], results[1], /*atol=*/1e-4); + test::ExpectClose(results[2], results[3], /*atol=*/1e-4); +} +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc index efb75749722893100494e089c0beb96944e9f1d4..5e4699bbb6218089d2e76a36c7351bf7fbd23264 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/side_effect_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_context.h" @@ -88,6 +89,9 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph, case XlaExpression::Kind::kResource: return errors::Unimplemented( "Resource as function argument is not yet implemented."); + case XlaExpression::Kind::kTensorList: + return errors::Unimplemented( + "TensorList as function argument is not yet implemented."); case XlaExpression::Kind::kInvalid: return errors::InvalidArgument("Invalid function argument"); } @@ -191,6 +195,9 @@ Status GraphCompiler::CompileFunctionalNode(Node* n, // into the functions. 
XlaOpKernelContext xla_op_context(op_context); + XlaContext& context = XlaContext::Get(op_context); + auto* b = context.builder(); + XlaCompiler* compiler = xla_op_context.compiler(); NameAttrList func; @@ -219,8 +226,12 @@ Status GraphCompiler::CompileFunctionalNode(Node* n, TF_RETURN_IF_ERROR( PrepareArguments(&xla_op_context, graph.get(), expressions, &arguments)); + bool add_token_input_output = + HasNodeAttr(n->def(), kXlaTokenInputNodesAttrName); + XlaCompiler::CompileOptions compile_options; compile_options.is_entry_computation = false; + compile_options.add_token_input_output = add_token_input_output; XlaCompiler::CompilationResult result; TF_RETURN_IF_ERROR( compiler->CompileFunction(compile_options, func, arguments, &result)); @@ -234,9 +245,19 @@ Status GraphCompiler::CompileFunctionalNode(Node* n, } handles.push_back(expressions[i]->handle()); } - - XlaContext& context = XlaContext::Get(op_context); - auto* b = context.builder(); + if (add_token_input_output) { + std::vector token_input_nodes; + TF_RETURN_IF_ERROR( + GetNodeAttr(n->def(), kXlaTokenInputNodesAttrName, &token_input_nodes)); + std::vector token_inputs; + for (const string& node_name : token_input_nodes) { + auto token_or = compiler->GetNodeToken(node_name); + TF_RETURN_IF_ERROR(token_or.status()); + token_inputs.push_back(token_or.ConsumeValueOrDie()); + } + xla::XlaOp token_input = xla::AfterAll(b, token_inputs); + handles.push_back(token_input); + } auto output_handle = xla::Call(b, *result.computation, handles); // The output handle of `Call` computation is a tuple type. 
Unzip it so @@ -251,6 +272,10 @@ Status GraphCompiler::CompileFunctionalNode(Node* n, ++computation_output; } } + if (add_token_input_output) { + TF_RETURN_IF_ERROR(compiler->SetNodeToken( + n->name(), xla::GetTupleElement(output_handle, computation_output))); + } return b->first_error(); } diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index d85b4f5ae0cb9c7d2476158a5830f921742ae980..69353fe87d833fba2c8766ed185481f2238a190d 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -1,16 +1,11 @@ +load("//tensorflow:tensorflow.bzl", "tf_copts", "tf_kernel_library") + licenses(["notice"]) # Apache 2.0 package( default_visibility = ["//tensorflow/compiler/tf2xla:internal"], ) -load("//tensorflow:tensorflow.bzl", "tf_copts") -load("//tensorflow:tensorflow.bzl", "tf_kernel_library") -load( - "//third_party/mkl:build_defs.bzl", - "if_mkl", -) - tf_kernel_library( name = "xla_ops", srcs = [ @@ -39,6 +34,7 @@ tf_kernel_library( "dynamic_slice_ops.cc", "dynamic_stitch_op.cc", "elu_op.cc", + "empty_op.cc", "extract_image_patches_op.cc", "fake_param_op.cc", "fake_quantize_ops.cc", @@ -106,6 +102,7 @@ tf_kernel_library( "variable_ops.cc", "xla_broadcast_helper_op.cc", "xla_conv_op.cc", + "xla_dequantize_op.cc", "xla_dot_op.cc", "xla_pad_op.cc", "xla_reduce_op.cc", @@ -121,15 +118,10 @@ tf_kernel_library( ":while_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/compiler/tf2xla/lib:batch_dot", "//tensorflow/compiler/tf2xla/lib:broadcast", - "//tensorflow/compiler/tf2xla/lib:cholesky", - "//tensorflow/compiler/tf2xla/lib:qr", "//tensorflow/compiler/tf2xla/lib:random", "//tensorflow/compiler/tf2xla/lib:scatter", - "//tensorflow/compiler/tf2xla/lib:triangular_solve", "//tensorflow/compiler/tf2xla/lib:util", - "//tensorflow/compiler/tf2xla/lib:while_loop", "//tensorflow/compiler/tf2xla/ops:xla_ops", 
"//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:literal", @@ -142,19 +134,38 @@ tf_kernel_library( "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/lib:cholesky", "//tensorflow/compiler/xla/client/lib:constants", + "//tensorflow/compiler/xla/client/lib:loops", "//tensorflow/compiler/xla/client/lib:math", - "//tensorflow/compiler/xla/client/lib:numeric", + "//tensorflow/compiler/xla/client/lib:matrix", "//tensorflow/compiler/xla/client/lib:pooling", "//tensorflow/compiler/xla/client/lib:prng", + "//tensorflow/compiler/xla/client/lib:qr", + "//tensorflow/compiler/xla/client/lib:quantize", "//tensorflow/compiler/xla/client/lib:sorting", + "//tensorflow/compiler/xla/client/lib:triangular_solve", + "//tensorflow/core:bitwise_ops_op_lib", + "//tensorflow/core:control_flow_ops_op_lib", + "//tensorflow/core:data_flow_ops_op_lib", "//tensorflow/core:framework", + "//tensorflow/core:functional_ops_op_lib", "//tensorflow/core:image_ops_op_lib", "//tensorflow/core:lib", "//tensorflow/core:linalg_ops_op_lib", + "//tensorflow/core:list_ops_op_lib", + "//tensorflow/core:math_ops_op_lib", + "//tensorflow/core:nn_ops_op_lib", + "//tensorflow/core:no_op_op_lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/core:random_ops_op_lib", + "//tensorflow/core:resource_variable_ops_op_lib", + "//tensorflow/core:sendrecv_ops_op_lib", + "//tensorflow/core:sparse_ops_op_lib", "//tensorflow/core:spectral_ops_op_lib", + "//tensorflow/core:state_ops_op_lib", "//tensorflow/core:stateless_random_ops_op_lib", + "//tensorflow/core:training_ops_op_lib", "//tensorflow/core/kernels:bounds_check", "//tensorflow/core/kernels:concat_lib", "//tensorflow/core/kernels:constant_op", @@ -196,7 +207,6 @@ cc_library( "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client/lib:arithmetic", 
"//tensorflow/compiler/xla/client/lib:constants", - "//tensorflow/compiler/xla/client/lib:numeric", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", @@ -216,7 +226,6 @@ cc_library( "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client/lib:arithmetic", "//tensorflow/compiler/xla/client/lib:constants", - "//tensorflow/compiler/xla/client/lib:numeric", "//tensorflow/core:framework", "//tensorflow/core/kernels:bounds_check", "//tensorflow/core/kernels:conv_ops", diff --git a/tensorflow/compiler/tf2xla/kernels/arg_op.cc b/tensorflow/compiler/tf2xla/kernels/arg_op.cc index 795ea09831e183a26fb3498b9bbaf9c3adaef9ed..5554d7a377d38554058aa731770ee10e400bc535 100644 --- a/tensorflow/compiler/tf2xla/kernels/arg_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/arg_op.cc @@ -53,7 +53,11 @@ class XlaArgOp : public XlaOpKernel { const XlaExpression& arg = ctx->xla_context()->args()[index_]; OP_REQUIRES(ctx, arg.kind() != XlaExpression::Kind::kInvalid, errors::InvalidArgument("Invalid/missing argument expression")); - ctx->SetOutputExpression(0, arg); + if (ctx->expected_output_dtype(0) == DT_VARIANT) { + ctx->SetTensorListOutput(0, arg.handle()); + } else { + ctx->SetOutputExpression(0, arg); + } } private: @@ -63,6 +67,8 @@ class XlaArgOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(XlaArgOp); }; -REGISTER_XLA_OP(Name("_Arg").AllowResourceTypes().CompilationOnly(), XlaArgOp); +REGISTER_XLA_OP( + Name("_Arg").AllowResourceTypes().AllowVariantTypes().CompilationOnly(), + XlaArgOp); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc index 4cfe946b2e6146f034867c06e996ffae42b90705..1b254e328a8c71bd81a0ec700e2af1d81a5fa67a 100644 --- a/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc @@ -13,9 +13,11 @@ See the License for the specific language governing 
permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/tf2xla/lib/batch_dot.h" +#include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" namespace tensorflow { namespace { @@ -28,9 +30,11 @@ class BatchMatMulOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - auto result = BatchDot(ctx->Input(0), ctx->Input(1), - /*transpose_x=*/adj_x_, /*transpose_y=*/adj_y_, - /*conjugate_x=*/adj_x_, /*conjugate_y=*/adj_y_); + auto result = + xla::BatchDot(MaybeTransposeInMinorDims( + MaybeConjugate(ctx->Input(0), adj_x_), adj_x_), + MaybeTransposeInMinorDims( + MaybeConjugate(ctx->Input(1), adj_y_), adj_y_)); ctx->SetOutput(0, result); } diff --git a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc index 0e2f335f3354e3ae6008bdc0ac0b80683fe479c1..f1d78c87527eb5f818dcf92209feabe33653a625 100644 --- a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc @@ -18,6 +18,8 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/util/tensor_format.h" @@ -34,6 +36,7 @@ class FusedBatchNormOp : public XlaOpKernel { OP_REQUIRES( ctx, FormatFromString(data_format_str, &data_format_), errors::InvalidArgument("Invalid data format: ", data_format_str)); + is_on_gpu_ = ctx->device_type().type_string() == DEVICE_GPU_XLA_JIT; } void Compile(XlaOpKernelContext* ctx) override { @@ -71,7 +74,18 @@ class FusedBatchNormOp : public XlaOpKernel { // variance to the gradient. Here we maintain the same behavior by setting // them to the mean and variance calculated by BatchNormTraining. ctx->SetOutput(3, xla::GetTupleElement(output, 1)); - ctx->SetOutput(4, xla::GetTupleElement(output, 2)); + if (is_on_gpu_) { + // The last two outputs from the FusedBatchNorm training TensorFlow GPU + // op are implementation defined. 
For now we rely on the in-practice + // behavior of the op: + // output 3 is the mean + // output 4 is rsqrt(variance + epsilon) + xla::XlaOp variance = xla::GetTupleElement(output, 2); + ctx->SetOutput(4, xla::Rsqrt(xla::Add( + variance, xla::ScalarLike(variance, epsilon_)))); + } else { + ctx->SetOutput(4, xla::GetTupleElement(output, 2)); + } } else { xla::XlaOp output = xla::BatchNormInference( input, ctx->Input(1), ctx->Input(2), ctx->Input(3), ctx->Input(4), @@ -89,6 +103,7 @@ class FusedBatchNormOp : public XlaOpKernel { float epsilon_; TensorFormat data_format_; bool is_training_; + bool is_on_gpu_; }; REGISTER_XLA_OP(Name("FusedBatchNorm"), FusedBatchNormOp); @@ -104,6 +119,7 @@ class FusedBatchNormGradOp : public XlaOpKernel { OP_REQUIRES( ctx, FormatFromString(data_format_str, &data_format_), errors::InvalidArgument("Invalid data format: ", data_format_str)); + is_on_gpu_ = ctx->device_type().type_string() == DEVICE_GPU_XLA_JIT; } void Compile(XlaOpKernelContext* ctx) override { @@ -130,6 +146,22 @@ class FusedBatchNormGradOp : public XlaOpKernel { xla::XlaOp scale_backprop; xla::XlaOp offset_backprop; if (is_training_) { + if (is_on_gpu_) { + // The last two inputs to the FusedBatchNormGrad training TensorFlow GPU + // op are implementation defined. For now we rely on the in-practice + // behavior of the op: input 3 is the mean input 4 is rsqrt(variance + + // epsilon) + // + // The XLA op expects: + // input 3 is the mean + // input 4 is the variance + // + // so we adjust input 4 here. 
+ xla::XlaOp one = xla::ScalarLike(var, 1.0f); + xla::XlaOp epsilon = xla::ScalarLike(var, epsilon_); + var = xla::Sub(one / (var * var), epsilon); + } + xla::XlaOp output = xla::BatchNormGrad(activations, scale, mean, var, grad_backprop, epsilon_, feature_index); @@ -158,9 +190,8 @@ class FusedBatchNormGradOp : public XlaOpKernel { offset_backprop = XlaHelpers::ConvertElementType(reduce, scale_dtype); // scratch1 = rsqrt(pop_var + epsilon) - auto neg_half = XlaHelpers::FloatLiteral(b, scale_dtype, -0.5); - auto scratch1 = xla::Pow( - xla::Add(var, xla::ConstantR0(b, epsilon_)), neg_half); + auto epsilon = XlaHelpers::FloatLiteral(b, scale_dtype, epsilon_); + auto scratch1 = xla::Rsqrt(xla::Add(var, epsilon)); // scratch2 = sum(y_backprop * (x - mean)) auto mul = @@ -187,6 +218,7 @@ class FusedBatchNormGradOp : public XlaOpKernel { TensorFormat data_format_; float epsilon_; bool is_training_; + bool is_on_gpu_; }; REGISTER_XLA_OP(Name("FusedBatchNormGrad"), FusedBatchNormGradOp); diff --git a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc index 46e5d68c78fd9ff26a88dc2a1484c3a67b76f4f3..6b675fa8a94e0bc932baaa359565cbc8e4614ee5 100644 --- a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc @@ -39,7 +39,7 @@ void BatchToSpace(XlaOpKernelContext* ctx, const xla::XlaOp& input, OP_REQUIRES( ctx, - xla::ShapeUtil::Rank(crops.shape()) == 2 && + crops.shape().rank() == 2 && block_rank == xla::ShapeUtil::GetDimension(crops.shape(), 0) && 2 == xla::ShapeUtil::GetDimension(crops.shape(), 1), errors::InvalidArgument("crops should have shape [", block_rank, diff --git a/tensorflow/compiler/tf2xla/kernels/bias_ops.cc b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc index e7f369b761f36a717ea5fb536780af91a8955b1e..33bdf9aec3167b0277f3c1db18c9e247ed9bb5d1 100644 --- a/tensorflow/compiler/tf2xla/kernels/bias_ops.cc +++ 
b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc @@ -48,8 +48,11 @@ class BiasOp : public XlaOpKernel { OP_REQUIRES(ctx, TensorShapeUtils::IsVector(bias_shape), errors::InvalidArgument("Biases must be 1D: ", bias_shape.DebugString())); - int feature_dim = (data_format_ == FORMAT_NHWC) ? input_shape.dims() - 1 - : input_shape.dims() - 3; + + // feature_dim is the channel (C) dimension of the data. + int feature_dim = (data_format_ == FORMAT_NHWC) + ? input_shape.dims() - 1 + : /*data_format == FORMAT_NCHW*/ 1; OP_REQUIRES( ctx, feature_dim >= 0, errors::InvalidArgument("Input tensor does not have enough dimensions " @@ -91,9 +94,10 @@ class BiasAddGradOp : public XlaOpKernel { errors::InvalidArgument("Input tensor must be at least 2D: ", out_backprop_shape.DebugString())); + // feature_dim is the channel (C) dimension of the data. int feature_dim = (data_format_ == FORMAT_NHWC) ? out_backprop_shape.dims() - 1 - : out_backprop_shape.dims() - 3; + : /*data_format == FORMAT_NCHW*/ 1; OP_REQUIRES( ctx, feature_dim >= 0, errors::InvalidArgument("Input tensor does not have enough dimensions " diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc index 5e9280c1fe692037b0a842a92ef5a8c28b854a54..ad6b334326a470442c8c0d79b725345d4165be10 100644 --- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc @@ -20,7 +20,9 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.h" @@ -165,12 +167,8 @@ XLA_MAKE_BINARY( xla::Div(xla::Mul(rhs, XlaHelpers::FloatLiteral(b, input_type(0), 0.5)), lhs, extend_dimensions)); -static xla::XlaOp Square(xla::XlaBuilder* builder, const xla::XlaOp& x) { - return xla::Mul(x, x); -} - XLA_MAKE_BINARY(SquaredDifference, - Square(b, xla::Sub(lhs, rhs, extend_dimensions))); + xla::Square(xla::Sub(lhs, rhs, extend_dimensions))); XLA_MAKE_BINARY(TruncateDiv, xla::Div(lhs, rhs, extend_dimensions)); XLA_MAKE_BINARY(TruncateMod, xla::Rem(lhs, rhs, extend_dimensions)); @@ -195,8 +193,8 @@ XLA_MAKE_BINARY(SoftplusGrad, // softsigngrad(gradients, features) = gradients / (1 + abs(features)) ** 2 XLA_MAKE_BINARY(SoftsignGrad, xla::Div(lhs, - Square(b, xla::Add(XlaHelpers::One(b, input_type(0)), - xla::Abs(rhs))))); + xla::Square(xla::Add(XlaHelpers::One(b, input_type(0)), + xla::Abs(rhs))))); XLA_MAKE_BINARY(TanhGrad, xla::Mul(rhs, xla::Sub(XlaHelpers::One(b, input_type(0)), @@ -204,6 +202,8 @@ XLA_MAKE_BINARY(TanhGrad, XLA_MAKE_BINARY(Pow, xla::Pow(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(NextAfter, xla::NextAfter(lhs, rhs)); + #undef XLA_MAKE_BINARY class ApproximateEqualOp : public XlaOpKernel { diff --git a/tensorflow/compiler/tf2xla/kernels/cast_op.cc b/tensorflow/compiler/tf2xla/kernels/cast_op.cc index 8cc2479dd555380da7500abe6b2aca380110333b..ca2152d6c103e05c06809d85d9529720ff112217 100644 --- a/tensorflow/compiler/tf2xla/kernels/cast_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/cast_op.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR 
CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" @@ -19,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/kernel_def_builder.h" namespace tensorflow { @@ -31,6 +33,7 @@ class CastOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, ctx->GetAttr("DstT", &dst_dtype_)); OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(src_dtype_, &src_type_)); OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(dst_dtype_, &dst_type_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("Truncate", &use_truncation_)); } void Compile(XlaOpKernelContext* ctx) override { @@ -48,6 +51,36 @@ class CastOp : public XlaOpKernel { // imaginary part. output = xla::ConvertElementType(xla::Real(input), dst_type_); } else { + if (use_truncation_) { + OP_REQUIRES( + ctx, + xla::primitive_util::IsFloatingPointType(src_type_) && + xla::primitive_util::IsFloatingPointType(dst_type_), + errors::Unimplemented("Truncate attribute is only " + "implemented for floating point datatypes.")); + int mantissa_difference = + xla::primitive_util::SignificandWidth(src_type_) - + xla::primitive_util::SignificandWidth(dst_type_); + OP_REQUIRES(ctx, mantissa_difference > 0, + errors::Unimplemented( + "Truncate attribute is only implemented in cases where " + "dst datatype " + "has fewer mantissa bits than the src datatype")); + int src_bitwidth = xla::primitive_util::BitWidth(src_type_); + + // Bitcast to same-width integer, mask off the LSBs, bitcast back to the + // source datatype. 
+ int64 mask = ~((1LL << mantissa_difference) - 1); + xla::PrimitiveType same_width_int = + xla::primitive_util::UnsignedIntegralTypeForBitWidth(src_bitwidth); + OP_REQUIRES(ctx, same_width_int != xla::PRIMITIVE_TYPE_INVALID, + errors::Unimplemented("Unexpected type bitwidth")); + input = xla::BitcastConvertType( + xla::And( + xla::BitcastConvertType(input, same_width_int), + ::tensorflow::IntegerLiteral(builder, same_width_int, mask)), + src_type_); + } output = xla::ConvertElementType(input, dst_type_); } @@ -57,6 +90,7 @@ class CastOp : public XlaOpKernel { protected: DataType src_dtype_, dst_dtype_; xla::PrimitiveType src_type_, dst_type_; + bool use_truncation_; TF_DISALLOW_COPY_AND_ASSIGN(CastOp); }; @@ -79,8 +113,8 @@ class BitcastOp : public XlaOpKernel { if (src_dtype_ == dst_dtype_) { output = input; } else { - // The only complex type in XLA is C64, so error out if the bitcast has a - // complex source or destination type and the bitcast is not trivial. + // Error out if the bitcast has a complex source or destination type and + // the bitcast is not trivial. 
OP_REQUIRES(ctx, !xla::primitive_util::IsComplexType(src_type_) && !xla::primitive_util::IsComplexType(dst_type_), diff --git a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc index 7199b9b6feb36dd45ef51f4c38463bc715fcc38a..c2b4c28d1566f5429c5d8109db94af0c3762b131 100644 --- a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc @@ -99,8 +99,8 @@ class CategoricalOp : public XlaOpKernel { xla::PrimitiveType xla_output_type; OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(output_type(0), &xla_output_type)); - xla::XlaOp argmax = XlaHelpers::ArgMax(softmax_entries, xla_output_type, - /*axis=*/class_dimension); + xla::XlaOp argmax = xla::ArgMax(softmax_entries, xla_output_type, + /*axis=*/class_dimension); if (num_samples == 1) { argmax = xla::Reshape(argmax, {batch_size, 1}); } diff --git a/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc b/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc index 9fcbc86adc0967cbb7fb73da8bdabc58b60953da..0ed3044efa5b1060d2b0ad2d5563b0e02ebf66ec 100644 --- a/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/compiler/tf2xla/lib/cholesky.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/cholesky.h" namespace tensorflow { namespace { @@ -24,7 +24,7 @@ class CholeskyOp : public XlaOpKernel { public: explicit CholeskyOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - ctx->SetOutput(0, Cholesky(ctx->Input(0))); + ctx->SetOutput(0, xla::Cholesky(ctx->Input(0))); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/concat_op.cc b/tensorflow/compiler/tf2xla/kernels/concat_op.cc index cd7c7f4a82df7a65829787efcb1fd2f77870e945..91e4d9cea7cbf6075e30250587044174c4b8e7f4 100644 --- a/tensorflow/compiler/tf2xla/kernels/concat_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/concat_op.cc @@ -24,13 +24,13 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/kernels/concat_lib.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/const_op.cc b/tensorflow/compiler/tf2xla/kernels/const_op.cc index dff8af800229b9605bb93e0498bc5e5cf012f244..ff6c54e47c62f0555ef045e25051f6ec5a3c1d39 100644 --- a/tensorflow/compiler/tf2xla/kernels/const_op.cc +++ 
b/tensorflow/compiler/tf2xla/kernels/const_op.cc @@ -83,6 +83,17 @@ class ConstOp : public XlaOpKernel { return; } break; + case DT_COMPLEX128: + if (proto_.scomplex_val_size() == 2) { + ctx->SetOutput( + 0, + xla::Broadcast(xla::ConstantR0( + b, xla::complex128(proto_.dcomplex_val(0), + proto_.dcomplex_val(1))), + shape.dim_sizes())); + return; + } + break; case DT_INT32: if (proto_.int_val_size() == 1) { ctx->SetOutput( diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc index 641fefafb357f6ad10483c454600f3dadd4f8cb7..5f99b24e221ba6c926032ef7a1b4bf1e92df7a68 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc @@ -26,13 +26,13 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_slice.h" -#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/kernels/conv_grad_ops.h" #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/util/padding.h" @@ -212,8 +212,8 @@ Status ConvBackpropComputeDimensionsV2XlaShapes( XLAShapeToTensorShape(out_backprop_shape, &out_backprop_tensor_shape)); return ConvBackpropComputeDimensionsV2( label, num_spatial_dims, input_tensor_shape, filter_tensor_shape, - out_backprop_tensor_shape, dilations, strides, padding, data_format, - dims); + out_backprop_tensor_shape, dilations, strides, padding, + /*explicit_paddings=*/{}, data_format, dims); } } // anonymous namespace @@ -227,6 
+227,11 @@ xla::StatusOr ConvOpAttrs::Create(int num_spatial_dims, TF_RETURN_IF_ERROR(ctx->GetAttr("dilations", &attrs.dilations)); TF_RETURN_IF_ERROR(ctx->GetAttr("strides", &attrs.strides)); TF_RETURN_IF_ERROR(ctx->GetAttr("padding", &attrs.padding)); + // TODO(reedwm): Support explicit padding. + if (attrs.padding == EXPLICIT) { + return errors::Unimplemented( + "XLA does not yet support Conv2D with explicit padding."); + } string data_format; TF_RETURN_IF_ERROR(ctx->GetAttr("data_format", &data_format)); @@ -392,23 +397,31 @@ xla::StatusOr MakeXlaBackpropFilterConvOp( builder->GetShape(activations)); TF_ASSIGN_OR_RETURN(xla::Shape out_backprop_shape, builder->GetShape(gradients)); + xla::XlaOp filter_backprop; + + xla::Shape input_shape = activations_shape; + xla::Shape output_shape = out_backprop_shape; + + TensorShape input_tensor_shape, filter_tensor_shape, output_tensor_shape; + TF_RETURN_IF_ERROR(XLAShapeToTensorShape(filter_shape, &filter_tensor_shape)); + TF_RETURN_IF_ERROR(XLAShapeToTensorShape(input_shape, &input_tensor_shape)); + TF_RETURN_IF_ERROR(XLAShapeToTensorShape(output_shape, &output_tensor_shape)); + const xla::Shape expanded_filter_shape = attrs.depthwise ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape) : filter_shape; - // Reuse dimension computation logic from conv_grad_ops.cc. ConvBackpropDimensions dims; - TF_RETURN_IF_ERROR(ConvBackpropComputeDimensionsV2XlaShapes( - type_string, attrs.num_spatial_dims, activations_shape, - expanded_filter_shape, out_backprop_shape, attrs.dilations, attrs.strides, - attrs.padding, attrs.data_format, &dims)); - // The filter gradients are computed by a convolution of the input // activations and the output gradients, with some appropriate padding. // See the comment at the top of conv_grad_ops.h for details. 
- xla::ConvolutionDimensionNumbers dnums; + TF_RETURN_IF_ERROR(ConvBackpropComputeDimensionsV2XlaShapes( + type_string, attrs.num_spatial_dims, activations_shape, + expanded_filter_shape, out_backprop_shape, attrs.dilations, attrs.strides, + attrs.padding, attrs.data_format, &dims)); + // The activations (inputs) form the LHS of the convolution. // Activations have shape: [batch, in_rows, in_cols, ..., in_depth] // For the gradient computation, we flip the roles of the batch and @@ -420,6 +433,14 @@ xla::StatusOr MakeXlaBackpropFilterConvOp( int n_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format); int c_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format); + bool use_batch_group_count = + filter_tensor_shape.dim_size(num_dims - 1) == 1 && attrs.depthwise; + + std::vector> padding(attrs.num_spatial_dims); + std::vector rhs_dilation(attrs.num_spatial_dims); + std::vector window_strides(attrs.num_spatial_dims); + std::vector ones(attrs.num_spatial_dims, 1); + // Swap n_dim and c_dim in the activations. dnums.set_input_batch_dimension(c_dim); dnums.set_input_feature_dimension(n_dim); @@ -430,19 +451,21 @@ xla::StatusOr MakeXlaBackpropFilterConvOp( dnums.set_kernel_input_feature_dimension(n_dim); dnums.set_kernel_output_feature_dimension(c_dim); - std::vector> padding(attrs.num_spatial_dims); - std::vector rhs_dilation(attrs.num_spatial_dims); - std::vector window_strides(attrs.num_spatial_dims); - std::vector ones(attrs.num_spatial_dims, 1); + // The dimension swap below is needed because filter shape is KH,KW,F,DM. + if (use_batch_group_count) { + dnums.set_output_batch_dimension(attrs.num_spatial_dims + 1); + dnums.set_output_feature_dimension(attrs.num_spatial_dims); + } else { + dnums.set_output_batch_dimension(attrs.num_spatial_dims); + dnums.set_output_feature_dimension(attrs.num_spatial_dims + 1); + } // Tensorflow filter shape is [ H, W, ..., inC, outC ]. 
for (int i = 0; i < attrs.num_spatial_dims; ++i) { dnums.add_output_spatial_dimensions(i); } - dnums.set_output_batch_dimension(attrs.num_spatial_dims); - dnums.set_output_feature_dimension(attrs.num_spatial_dims + 1); - for (int i = 0; i < attrs.num_spatial_dims; ++i) { + for (int64 i = 0; i < attrs.num_spatial_dims; ++i) { int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i); dnums.add_input_spatial_dimensions(dim); dnums.add_kernel_spatial_dimensions(dim); @@ -451,7 +474,7 @@ xla::StatusOr MakeXlaBackpropFilterConvOp( // convolution, we get the right size for the filter. // The padded_in_rows should be such that when we convolve this with the // expanded_out_rows as a filter, we should get filter_rows back. - // + const int64 padded_in_size = dims.spatial_dims[i].expanded_output_size + (dims.spatial_dims[i].filter_size - 1) * attrs.dilations[dim]; @@ -496,11 +519,14 @@ xla::StatusOr MakeXlaBackpropFilterConvOp( // // This is done by specifying the window dilation factors in the // convolution HLO below. - auto filter_backprop = - xla::ConvGeneralDilated(activations, gradients, window_strides, padding, - /*lhs_dilation=*/ones, rhs_dilation, dnums); - if (attrs.depthwise) { + filter_backprop = xla::ConvGeneralDilated( + activations, gradients, window_strides, padding, /*lhs_dilation=*/ones, + rhs_dilation, dnums, + /*feature_group_count=*/1, + /*batch_group_count=*/use_batch_group_count ? dims.in_depth : 1); + + if (!use_batch_group_count && attrs.depthwise) { filter_backprop = ContractFilterForDepthwiseBackprop( filter_shape, filter_backprop, activations.builder()); } diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc index d820528a43064e327cb90e5a2889f77ab1f3f3e2..52c3c2c4a903a8c51f6b511774bc0312d39df826 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc @@ -22,16 +22,16 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_slice.h" -#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/kernels/conv_grad_ops.h" #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/util/padding.h" diff --git a/tensorflow/compiler/tf2xla/kernels/diag_op.cc b/tensorflow/compiler/tf2xla/kernels/diag_op.cc index 49c12fc232092873b69961644a059abc6035f64f..ee79cbc70da269be7586c47b4fd33c901f4fd581 100644 --- a/tensorflow/compiler/tf2xla/kernels/diag_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/diag_op.cc @@ -19,7 +19,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/framework/op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc index 6e6ba21daf5bf3eab5bfc15378e77b6dd253da7c..b119997cf39e210ed8e0ae730a08829e72b238b4 100644 --- a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc @@ -22,10 +22,10 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/kernels/bounds_check.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/empty_op.cc b/tensorflow/compiler/tf2xla/kernels/empty_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..00d2ce7c12fdc96483612059d1c792c847df04f3 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/empty_op.cc @@ -0,0 +1,66 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// XLA-specific Empty Op. + +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor_shape.h" + +namespace tensorflow { +namespace { + +class EmptyOp : public XlaOpKernel { + public: + explicit EmptyOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_)); + OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(dtype_, &type_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("init", &init_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + // The output of this Op is a tensor of shape 'shape' with each + // element set to the default value of 'dtype'. If 'init' is false then + // the result values may be left undefined, though we don't do that here. 
+ const TensorShape shape_shape = ctx->InputShape("shape"); + OP_REQUIRES( + ctx, TensorShapeUtils::IsVector(shape_shape), + errors::InvalidArgument("shape must be a vector of int32, got shape ", + shape_shape.DebugString())); + + std::vector shape; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector("shape", &shape)); + + auto default_value = xla::Zero(ctx->builder(), type_); + auto result = xla::Broadcast(default_value, shape); + ctx->SetOutput(0, result); + } + + private: + DataType dtype_; + xla::PrimitiveType type_; + bool init_; +}; + +REGISTER_XLA_OP(Name("Empty").CompileTimeConstantInput("shape"), EmptyOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc index 6df8b5367d2390e65995beb1583b225755e6ee9f..a623585aad3b1b8f1f096ca527e7694d74f1ba46 100644 --- a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc @@ -21,12 +21,12 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_slice.h" -#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/kernels/conv_grad_ops.h" #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/util/padding.h" diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc index 20b0de193dc060197f3062d3be0b8d45f7dcb9b1..6472045265e4d930a5da770a68f5c502192201ae 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc @@ -14,7 +14,6 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" -#include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" @@ -168,13 +167,13 @@ class GatherOp : public XlaOpKernel { OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(2, &axis)); const auto params_dims = input_shape.dims(); - if (axis < 0) { - axis += params_dims; - } OP_REQUIRES( - context, 0 <= axis && axis < params_dims, + context, -params_dims <= axis && axis < params_dims, errors::InvalidArgument("Expected axis in the range [", -params_dims, ", ", params_dims, "), but got ", axis)); + if (axis < 0) { + axis += params_dims; + } } DataType index_type = input_type(1); diff --git a/tensorflow/compiler/tf2xla/kernels/identity_op.cc b/tensorflow/compiler/tf2xla/kernels/identity_op.cc index 19dd38c46ef154ea74bcbb6721dd04924702efcc..8b27e8e85a37bd5aa757b0cdd7e00e9fa3c0cf6e 100644 --- a/tensorflow/compiler/tf2xla/kernels/identity_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/identity_op.cc @@ -38,9 +38,13 @@ class IdentityOp : public XlaOpKernel { // XLA_* devices also register a "real" Identity operator so we suppress the // dummy operator using CompilationOnly(). 
-REGISTER_XLA_OP(Name("Identity").AllowResourceTypes().CompilationOnly(), - IdentityOp); -REGISTER_XLA_OP(Name("IdentityN").AllowResourceTypes().CompilationOnly(), +REGISTER_XLA_OP( + Name("Identity").AllowResourceTypes().AllowVariantTypes().CompilationOnly(), + IdentityOp); +REGISTER_XLA_OP(Name("IdentityN") + .AllowResourceTypes() + .AllowVariantTypes() + .CompilationOnly(), IdentityOp); REGISTER_XLA_OP(Name("PlaceholderWithDefault"), IdentityOp); REGISTER_XLA_OP(Name("PreventGradient"), IdentityOp); diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.cc b/tensorflow/compiler/tf2xla/kernels/if_op.cc index b5e083912555c865b5eadc7697075c9ca4451ca9..aa5637e2669555da17af8bb05ab08beeba6a89c3 100644 --- a/tensorflow/compiler/tf2xla/kernels/if_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/if_op.cc @@ -56,6 +56,7 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { VLOG(1) << "Building If: " << input_types_.size() << " inputs"; std::vector arguments(input_types_.size()); + int num_resource_args = 0; for (int i = 0; i < input_types_.size(); ++i) { XlaCompiler::Argument& arg = arguments[i]; DataType type = ctx->input_type(i + 1); @@ -79,14 +80,16 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { arg.name = resource->name(); VLOG(2) << "Resource " << resource->name() << " type: " << DataTypeString(arg.type) - << " shape: " << arg.shape.DebugString() + << " shape: " << arg.HumanString() << " initialized: " << arg.initialized; + + num_resource_args++; } else { arg.kind = XlaCompiler::Argument::kParameter; arg.type = input_types_[i]; arg.shape = ctx->InputShape(i + 1); VLOG(2) << "Arg type: " << DataTypeString(arg.type) - << " shape: " << arg.shape.DebugString(); + << " shape: " << arg.HumanString(); } } @@ -147,12 +150,12 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { OP_REQUIRES(ctx, then_result.xla_input_shapes.size() == 1, errors::FailedPrecondition("Expected one input shape")); xla::Shape then_input_shape = then_result.xla_input_shapes[0]; - OP_REQUIRES(ctx, 
xla::ShapeUtil::IsTuple(then_input_shape), + OP_REQUIRES(ctx, then_input_shape.IsTuple(), errors::FailedPrecondition("Expected tuple shape")); OP_REQUIRES(ctx, else_result.xla_input_shapes.size() == 1, errors::FailedPrecondition("Expected one input shape")); xla::Shape else_input_shape = else_result.xla_input_shapes[0]; - OP_REQUIRES(ctx, xla::ShapeUtil::IsTuple(else_input_shape), + OP_REQUIRES(ctx, else_input_shape.IsTuple(), errors::FailedPrecondition("Expected tuple shape")); OP_REQUIRES(ctx, xla::ShapeUtil::Compatible(then_input_shape, else_input_shape), @@ -236,12 +239,16 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { ctx->SetOutput(i, output_handle); } if (has_token_input_output_) { - // Set token output for this "if" op. + // Set token output for this "If" op. Token output is the last output of + // XLA computation, which comes after all "normal" TF outputs and resource + // updates. For "If" node, num of resource updates equals to number of + // resource args because we set `return_updated_values_for_all_resources` + // to true in XlaCompiler option. xla::XlaOp token_output = - xla::GetTupleElement(outputs, output_types_.size()); + xla::GetTupleElement(outputs, output_types_.size() + num_resource_args); auto shape_or = b->GetShape(token_output); OP_REQUIRES_OK(ctx, shape_or.status()); - OP_REQUIRES(ctx, xla::ShapeUtil::IsToken(shape_or.ValueOrDie()), + OP_REQUIRES(ctx, shape_or.ValueOrDie().IsToken(), errors::FailedPrecondition( "Token output is not token type: ", xla::ShapeUtil::HumanString(shape_or.ValueOrDie()))); diff --git a/tensorflow/compiler/tf2xla/kernels/image_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_ops.cc index e9bb0a77e99d144863b027bd214081316d61c314..92b20fe0ba5611ca5314cd954026f7b71ea75f84 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_ops.cc @@ -13,18 +13,20 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" #include "tensorflow/compiler/tf2xla/lib/util.h" -#include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/loops.h" #include "tensorflow/compiler/xla/client/lib/sorting.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" namespace tensorflow { namespace { @@ -185,19 +187,20 @@ class AdjustContrastOpV2 : public XlaOpKernel { factor_shape.DebugString())); xla::XlaBuilder* b = context->builder(); - xla::XlaOp input = context->Input(0); - xla::XlaOp factor = context->Input(1); - DataType type = context->input_type(0); + xla::XlaOp input = context->Input(0); + xla::XlaOp factor = XlaHelpers::ConvertElementType(context->Input(1), type); + const DataType accumulation_type = XlaHelpers::SumAccumulationType(type); auto converted = XlaHelpers::ConvertElementType(input, accumulation_type); auto reduce = xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type), *context->GetOrCreateAdd(accumulation_type), {height_dim, width_dim}); - auto output = XlaHelpers::ConvertElementType(reduce, type); - output = - xla::Div(output, XlaHelpers::FloatLiteral(b, type, height * width)); + + auto output = xla::Div( + reduce, XlaHelpers::FloatLiteral(b, accumulation_type, height * width)); + output = XlaHelpers::ConvertElementType(output, type); std::vector broadcast_dims(input_shape.dims() - 2); std::iota(broadcast_dims.begin(), 
broadcast_dims.end(), 0); @@ -233,8 +236,10 @@ class AdjustSaturationOp : public XlaOpKernel { channels, " channels.")); xla::XlaBuilder* b = context->builder(); - xla::XlaOp input = context->Input(0); - xla::XlaOp scale = context->Input(1); + xla::XlaOp input = + XlaHelpers::ConvertElementType(context->Input(0), DT_FLOAT); + xla::XlaOp scale = + XlaHelpers::ConvertElementType(context->Input(1), DT_FLOAT); DataType type = context->input_type(0); @@ -249,15 +254,17 @@ class AdjustSaturationOp : public XlaOpKernel { /*dimno=*/channel_dim); TensorShape channel_shape = input_shape; channel_shape.set_dim(channel_dim, 1); - auto hsv = RGBToHSV(context, b, {red, green, blue}, context->input_type(0), - channel_shape); + auto hsv = + RGBToHSV(context, b, {red, green, blue}, DT_FLOAT, channel_shape); - hsv[1] = xla::Clamp(XlaHelpers::Zero(b, type), xla::Mul(hsv[1], scale), - XlaHelpers::One(b, type)); + hsv[1] = xla::Clamp(XlaHelpers::Zero(b, DT_FLOAT), xla::Mul(hsv[1], scale), + XlaHelpers::One(b, DT_FLOAT)); - auto rgb = HSVToRGB(context->builder(), hsv, context->input_type(0)); + auto rgb = HSVToRGB(context->builder(), hsv, DT_FLOAT); - context->SetOutput(0, xla::ConcatInDim(b, rgb, channel_dim)); + auto output = XlaHelpers::ConvertElementType( + xla::ConcatInDim(b, rgb, channel_dim), type); + context->SetOutput(0, output); } }; REGISTER_XLA_OP(Name("AdjustSaturation"), AdjustSaturationOp); @@ -283,8 +290,10 @@ class AdjustHueOp : public XlaOpKernel { channels, " channels.")); xla::XlaBuilder* b = context->builder(); - xla::XlaOp input = context->Input(0); - xla::XlaOp delta = context->Input(1); + xla::XlaOp input = + XlaHelpers::ConvertElementType(context->Input(0), DT_FLOAT); + xla::XlaOp delta = + XlaHelpers::ConvertElementType(context->Input(1), DT_FLOAT); DataType type = context->input_type(0); @@ -299,20 +308,22 @@ class AdjustHueOp : public XlaOpKernel { /*dimno=*/channel_dim); TensorShape channel_shape = input_shape; channel_shape.set_dim(channel_dim, 1); - auto 
hsv = RGBToHSV(context, b, {red, green, blue}, context->input_type(0), - channel_shape); + auto hsv = + RGBToHSV(context, b, {red, green, blue}, DT_FLOAT, channel_shape); - auto zero = XlaHelpers::Zero(b, type); - auto one = XlaHelpers::One(b, type); + auto zero = XlaHelpers::Zero(b, DT_FLOAT); + auto one = XlaHelpers::One(b, DT_FLOAT); auto& hue = hsv[0]; hue = xla::Rem(xla::Add(hsv[0], delta), one); hue = xla::Select(xla::Lt(hue, zero), xla::Rem(xla::Add(one, hue), one), hue); - auto rgb = HSVToRGB(context->builder(), hsv, context->input_type(0)); + auto rgb = HSVToRGB(context->builder(), hsv, DT_FLOAT); - context->SetOutput(0, xla::ConcatInDim(b, rgb, channel_dim)); + auto output = XlaHelpers::ConvertElementType( + xla::ConcatInDim(b, rgb, channel_dim), type); + context->SetOutput(0, output); } }; REGISTER_XLA_OP(Name("AdjustHue"), AdjustHueOp); @@ -351,24 +362,26 @@ struct SuppressBodyFn { auto num_outputs_so_far = values[1]; auto iou_mask = values[2]; auto included_iou = values[3]; - auto zero_r1 = xla::ConstantR1(builder, {0}); + auto zero = xla::ConstantR0(builder, 0); // Determine if current elem is active using a slice. - auto row_idx_r1 = xla::Reshape(row_idx, {1}); - auto active_elem = xla::DynamicSlice(included_iou, row_idx_r1, {1}); + // TODO(b/118437727): The only reason we need an explicit vector is because + // some old GCCs can't deduce the right type for MakeConstSpan, and + // providing a single-value initializer list directly uses the wrong + // overload. Delete this once the deprecated overload is gone. + std::vector row_idx_vector = {row_idx}; + auto active_elem = xla::DynamicSlice(included_iou, row_idx_vector, {1}); active_elem = xla::Reshape(active_elem, {}); // Increment output count iff current elem is not suppressed. num_outputs_so_far = xla::Select( active_elem, num_outputs_so_far + xla::ConstantR0(builder, 1), num_outputs_so_far); // Slice out the row_idx. 
- auto starts = xla::ConcatInDim(builder, {row_idx_r1, zero_r1}, 0); - auto row_iou = xla::DynamicSlice(iou_mask, starts, {1, num_boxes}); + auto row_iou = xla::DynamicSlice(iou_mask, {row_idx, zero}, {1, num_boxes}); // Remove the diagonal from consideration. An elem cannot suppress // itself. - auto update_starts = xla::ConcatInDim(builder, {zero_r1, row_idx_r1}, 0); row_iou = xla::DynamicUpdateSlice( row_iou, xla::ConstantR2FromArray2D(builder, {{false}}), - update_starts); + {zero, row_idx}); // Create a suppression by inverting polarity. row_iou = xla::Reshape(row_iou, {num_boxes}); auto supp_mask = xla::Not(row_iou); @@ -505,9 +518,9 @@ class NonMaxSuppressionOp : public XlaOpKernel { init_values.push_back(included_iou); auto suppress_loop_result = - XlaWhileLoop(WhileCondFn(num_boxes, output_size), - SuppressBodyFn(num_boxes), init_values, "suppress_loop", - builder) + xla::WhileLoopHelper(WhileCondFn(num_boxes, output_size), + SuppressBodyFn(num_boxes), init_values, + "suppress_loop", builder) .ValueOrDie(); xla::XlaOp included_score = diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc index 0c7ca602bfacd598dada0303d3a3e77fe7f1b0fc..b96d45316f626e678a64392a4315979eeeb6e83c 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc @@ -19,7 +19,6 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/register_types.h" @@ -73,10 +72,10 @@ namespace { // from in_size to out_size. struct ResizeConvolutionDims { // Size of the kernel to use. 
- std::vector kernel_size; + std::vector kernel_size; // k // Stride of the convolution to use. - std::vector stride; + std::vector stride; // S }; ResizeConvolutionDims ComputeResizeConvolutionParameters( absl::Span in_size, absl::Span out_size, @@ -118,8 +117,10 @@ ResizeConvolutionDims ComputeResizeConvolutionParameters( // + dims.stride * (out_size - 1) int64 CalculateUpperPadding(int64 in_size, int64 out_size, int64 kernel_size, int64 stride) { - return (2 * kernel_size - 1) + (out_size - 1) * stride - (kernel_size - 1) - - 1 - (kernel_size * (in_size - 1)); + int64 padding = (2 * kernel_size - 1) + (out_size - 1) * stride - + (kernel_size - 1) - 1 - (kernel_size * (in_size - 1)); + + return padding; } // Form a 2D convolution kernel like: @@ -133,7 +134,7 @@ int64 CalculateUpperPadding(int64 in_size, int64 out_size, int64 kernel_size, // If the 2D kernel would be very large, the 1D kernel can be applied once in // each dimension due to the symmetry of the kernel along all axis to reduce the // computational intensity. -xla::XlaOp Make1DKernel(xla::XlaBuilder* builder, int64 n) { +xla::XlaOp MakeBilinear1DKernel(xla::XlaBuilder* builder, int64 n) { std::vector kernel(n * 2 - 1); for (int64 i = 0; i < n; ++i) { float v = (i + 1.0f) / n; @@ -143,43 +144,64 @@ xla::XlaOp Make1DKernel(xla::XlaBuilder* builder, int64 n) { return xla::ConstantR1(builder, kernel); } +// Unlike the bilinear kernel, which is triangular, the nearest neighbor +// kernel is a square. For example, a 1D kernel with n=3 would look like +// [0 1 1 1 0] +// and n=4 would look like +// [0 0 1 1 1 1 0]. +// Note that in the second case, the kernel is not symmetric and we default +// to the right (because an existing non TPU kernel +// for nearest neighbor resize already chose to default to the right, +// so we want to be consistent). 
+xla::XlaOp MakeNearestNeighbor1DKernel(xla::XlaBuilder* builder, int64 n) { + std::vector kernel(n * 2 - 1, 0.0f); + std::fill(&kernel[n / 2], &kernel[(3 * n) / 2], 1.0f); + + return xla::ConstantR1(builder, kernel); +} + // Kernels with more than 16 spatial elements are considered intense and the -// kernel should applied to each dimension independently. +// kernel should be applied to each dimension independently. const int64 kMax2DKernelSize = 16; -xla::XlaOp MakeBilinearResizeKernel(xla::XlaBuilder* builder, - absl::Span kernel_size, - int64 channels) { +xla::XlaOp MakeGeneralResizeKernel(xla::XlaBuilder* builder, + absl::Span kernel_size, + int64 channels, bool is_kernel_bilinear) { + auto make_kernel_func = + is_kernel_bilinear ? MakeBilinear1DKernel : MakeNearestNeighbor1DKernel; + auto depthwise_kernel = xla::Broadcast( xla::Zero(builder, xla::F32), {(2 * kernel_size[0] - 1), (2 * kernel_size[1] - 1), channels, 1}); return xla::Mul( - xla::Add(depthwise_kernel, Make1DKernel(builder, kernel_size[1]), + xla::Add(depthwise_kernel, make_kernel_func(builder, kernel_size[1]), /*broadcast_dimensions=*/{1}), - Make1DKernel(builder, kernel_size[0]), + make_kernel_func(builder, kernel_size[0]), /*broadcast_dimensions=*/{0}); } -xla::XlaOp MakeBilinearResizeKernelInDim(xla::XlaBuilder* builder, - absl::Span kernel_size, - int64 channels, int64 dim) { +xla::XlaOp MakeGeneralResizeKernelInDim(xla::XlaBuilder* builder, + absl::Span kernel_size, + int64 channels, int64 dim, + bool is_kernel_bilinear) { + auto make_kernel_func = + is_kernel_bilinear ? MakeBilinear1DKernel : MakeNearestNeighbor1DKernel; + auto depthwise_kernel = xla::Broadcast(xla::Zero(builder, xla::F32), {dim == 0 ? (2 * kernel_size[0] - 1) : 1, dim == 1 ? 
(2 * kernel_size[1] - 1) : 1, channels, 1}); - return xla::Add(depthwise_kernel, Make1DKernel(builder, kernel_size[dim]), + return xla::Add(depthwise_kernel, make_kernel_func(builder, kernel_size[dim]), /*broadcast_dimensions=*/{dim}); } -xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, - const xla::XlaOp& input, - const int num_spatial_dims, - std::vector in_size, - std::vector out_size, - const int64 channels, - const bool align_corners) { - // Picture for a 1x3 to 1x4 resize: +xla::XlaOp ResizeUsingDilationAndConvolution( + xla::XlaBuilder* builder, const xla::XlaOp& input, + const int num_spatial_dims, std::vector in_size, + std::vector out_size, const int64 channels, const bool align_corners, + bool is_kernel_bilinear) { + // Picture for a 1x3 to 1x4 bilinear resize: // stride = 2, kernel size = 3 // Input: // 3 6 9 @@ -265,8 +287,8 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, // Split convolutions into independent dimensions if they would be a very // large kernel. 
if (dims.kernel_size[0] * dims.kernel_size[1] < kMax2DKernelSize) { - xla::XlaOp kernel = - MakeBilinearResizeKernel(builder, dims.kernel_size, channels); + xla::XlaOp kernel = MakeGeneralResizeKernel(builder, dims.kernel_size, + channels, is_kernel_bilinear); output = xla::ConvGeneralDilated(input_data, kernel, dims.stride, /*padding=*/ @@ -276,8 +298,8 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, /*rhs_dilation=*/{1, 1}, dimension_numbers, /*feature_group_count=*/channels); } else { - xla::XlaOp kernel0 = - MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 0); + xla::XlaOp kernel0 = MakeGeneralResizeKernelInDim( + builder, dims.kernel_size, channels, 0, is_kernel_bilinear); output = xla::ConvGeneralDilated( input_data, kernel0, {dims.stride[0], 1}, /*padding=*/ @@ -285,8 +307,8 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, /*lhs_dilation=*/{dims.kernel_size[0], 1}, /*rhs_dilation=*/{1, 1}, dimension_numbers, /*feature_group_count=*/channels); - xla::XlaOp kernel1 = - MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 1); + xla::XlaOp kernel1 = MakeGeneralResizeKernelInDim( + builder, dims.kernel_size, channels, 1, is_kernel_bilinear); output = xla::ConvGeneralDilated( output, kernel1, {1, dims.stride[1]}, /*padding=*/ @@ -307,13 +329,11 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, return output; } -xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* builder, - const xla::XlaOp& grad, - const int num_spatial_dims, - std::vector in_size, - std::vector grad_size, - const int64 channels, - const bool align_corners) { +xla::XlaOp ResizeUsingDilationAndConvolutionGradOp( + xla::XlaBuilder* builder, const xla::XlaOp& grad, + const int num_spatial_dims, std::vector in_size, + std::vector grad_size, const int64 channels, + const bool align_corners, bool is_kernel_bilinear) { ResizeConvolutionDims dims = 
ComputeResizeConvolutionParameters(in_size, grad_size, align_corners); @@ -333,8 +353,8 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* builder, dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims); xla::XlaOp output; if (dims.kernel_size[0] * dims.kernel_size[1] < kMax2DKernelSize) { - xla::XlaOp kernel = - MakeBilinearResizeKernel(builder, dims.kernel_size, channels); + xla::XlaOp kernel = MakeGeneralResizeKernel(builder, dims.kernel_size, + channels, is_kernel_bilinear); // Broadcast the input kernel where the forward op expanded from a size == 1 // dimension to a size > 1 dimension. This has the effect of summing the @@ -356,14 +376,14 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* builder, /*rhs_dilation=*/{1, 1}, dimension_numbers, /*feature_group_count=*/channels); } else { - xla::XlaOp kernel0 = - MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 0); - xla::XlaOp kernel1 = - MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 1); - - // Broadcast the input kernel where the forward op expanded from a size == 1 - // dimension to a size > 1 dimension. This has the effect of summing the - // gradient contributions in that dimension. + xla::XlaOp kernel0 = MakeGeneralResizeKernelInDim( + builder, dims.kernel_size, channels, 0, is_kernel_bilinear); + xla::XlaOp kernel1 = MakeGeneralResizeKernelInDim( + builder, dims.kernel_size, channels, 1, is_kernel_bilinear); + + // Broadcast the input kernel where the forward op expanded from a + // size == 1 dimension to a size > 1 dimension. This has the effect of + // summing the gradient contributions in that dimension. 
if (in_size[0] == 1 && grad_size[0] > 1) { kernel0 = xla::Add(kernel0, xla::ConstantR1(builder, grad_size[0], 0), @@ -408,109 +428,139 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* builder, return output; } -class ResizeBilinearOp : public XlaOpKernel { - public: - explicit ResizeBilinearOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("align_corners", &align_corners_)); +void GeneralCompile(XlaOpKernelContext* ctx, bool align_corners_, + bool is_kernel_bilinear) { + xla::XlaBuilder* b = ctx->builder(); + + TensorShape input_shape = ctx->InputShape(0); + OP_REQUIRES(ctx, input_shape.dims() == 4, + errors::InvalidArgument("input must be 4-dimensional", + input_shape.DebugString())); + // First dimension always assumed to be batch + const int64 batch = input_shape.dim_size(0); + std::vector in_size = {input_shape.dim_size(1), + input_shape.dim_size(2)}; + // Last/4th dimension always assumed to be num channels + const int64 channels = input_shape.dim_size(3); + OP_REQUIRES(ctx, in_size[0] > 0 && in_size[1] > 0, + errors::InvalidArgument("input size must be positive, got [", + in_size[0], ",", in_size[1], "]")); + + std::vector out_size; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &out_size)); + OP_REQUIRES(ctx, out_size.size() == 2, + errors::InvalidArgument("output size must be length 2, got ", + out_size.size())); + OP_REQUIRES(ctx, out_size[0] > 0 && out_size[1] > 0, + errors::InvalidArgument("output size must be positive, got [", + out_size[0], ",", out_size[1], "]")); + + const int num_spatial_dims = 2; + + xla::XlaOp input = ctx->Input(0); + + // If in_size[i] > 1 and out_size[i] == 1, slice out the first input in + // dimension i. + bool slice_input = false; + for (int i = 0; i < num_spatial_dims; ++i) { + if (in_size[i] > 1 && out_size[i] == 1) { + // If in_size[i] > 1 but out_size[i] == 1, then we slice out the first + // entry before resizing. 
+ slice_input = true; + in_size[i] = 1; + } + } + if (slice_input) { + input = xla::Slice(input, {0, 0, 0, 0}, + {batch, in_size[0], in_size[1], channels}, {1, 1, 1, 1}); } - void Compile(XlaOpKernelContext* ctx) override { - xla::XlaBuilder* b = ctx->builder(); + // Output is always type float. + input = xla::ConvertElementType(input, xla::F32); + + // Special Case: + // Instead of doing a ResizeUsingDilationAndConvolution directly, + // while (out_size[0]-1) = c * 2^x * (in_size[0]-1) for x>1 c>1, resize the + // image to 2*(in_size[0]-1)+1 x-times and then resize by scale c(int here). + // Instead of resizing directly we resize it iteratively. + // + // Since bilinear resize can be broken down as 2 sequential linear + // operations along different dimensions. + // Given sufficient numerical stability and a cxd is same as resizing axb -> exf -> cxd. + // This does not work in the case of align_corners_=false because of special + // padding requirements that cause multiple resizes to be very different + // from a single resize. + // + // This makes the convolutions kernels smaller and the operation faster. 
+ xla::XlaOp output = input; + while (in_size != out_size) { + if (in_size[0] != 1 && in_size[1] != 1) { + std::vector k = { + (static_cast(out_size[0]) - 1) / ((in_size[0] - 1) * 2), + (static_cast(out_size[1]) - 1) / ((in_size[1] - 1) * 2)}; + if ((k[0] == std::floor(k[0])) && (k[1] == std::floor(k[1])) && + k[0] > 1 && k[1] > 1 && align_corners_) { + std::vector next_out_size = {(in_size[0] - 1) * 2 + 1, + (in_size[1] - 1) * 2 + 1}; + output = ResizeUsingDilationAndConvolution( + b, input, num_spatial_dims, in_size, next_out_size, channels, + align_corners_, is_kernel_bilinear); + input = output; + in_size = next_out_size; + } else { + output = ResizeUsingDilationAndConvolution( + b, input, num_spatial_dims, in_size, out_size, channels, + align_corners_, is_kernel_bilinear); + in_size = out_size; + } + } else { + output = ResizeUsingDilationAndConvolution( + b, input, num_spatial_dims, in_size, out_size, channels, + align_corners_, is_kernel_bilinear); + in_size = out_size; + } + } - TensorShape input_shape = ctx->InputShape(0); - OP_REQUIRES(ctx, input_shape.dims() == 4, - errors::InvalidArgument("input must be 4-dimensional", - input_shape.DebugString())); - const int64 batch = input_shape.dim_size(0); - std::vector in_size = {input_shape.dim_size(1), - input_shape.dim_size(2)}; - const int64 channels = input_shape.dim_size(3); - OP_REQUIRES(ctx, in_size[0] > 0 && in_size[1] > 0, - errors::InvalidArgument("input size must be positive, got [", - in_size[0], ",", in_size[1], "]")); + ctx->SetOutput(0, output); +} - std::vector out_size; - OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &out_size)); - OP_REQUIRES(ctx, out_size.size() == 2, - errors::InvalidArgument("output size must be length 2, got ", - out_size.size())); - OP_REQUIRES(ctx, out_size[0] > 0 && out_size[1] > 0, - errors::InvalidArgument("output size must be positive, got [", - out_size[0], ",", out_size[1], "]")); +class ResizeNearestNeighborOp : public XlaOpKernel { + public: + explicit 
ResizeNearestNeighborOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("align_corners", &align_corners_)); + OP_REQUIRES( + ctx, align_corners_ == true, + errors::Unimplemented("ResizeNearestNeighbor with align_corners=False " + "is not yet implemented")); + } - const int num_spatial_dims = 2; + void Compile(XlaOpKernelContext* ctx) override { + GeneralCompile(ctx, align_corners_, is_kernel_bilinear_); + } - xla::XlaOp input = ctx->Input(0); + private: + bool align_corners_ = true; + bool is_kernel_bilinear_ = false; +}; - // If in_size[i] > 1 and out_size[i] == 1, slice out the first input in - // dimension i. - bool slice_input = false; - for (int i = 0; i < num_spatial_dims; ++i) { - if (in_size[i] > 1 && out_size[i] == 1) { - // If in_size[i] > 1 but out_size[i] == 1, then we slice out the first - // entry before resizing. - slice_input = true; - in_size[i] = 1; - } - } - if (slice_input) { - input = - xla::Slice(input, {0, 0, 0, 0}, - {batch, in_size[0], in_size[1], channels}, {1, 1, 1, 1}); - } +REGISTER_XLA_OP(Name("ResizeNearestNeighbor").CompileTimeConstantInput("size"), + ResizeNearestNeighborOp); - // Output is always type float. - input = xla::ConvertElementType(input, xla::F32); - - // Special Case: - // Instead of doing a ResizeUsingDilationAndConvolution directly, - // while (out_size[0]-1) = c * 2^x * (in_size[0]-1) for x>1 c>1, resize the - // image to 2*(in_size[0]-1)+1 x-times and then resize by scale c(int here). - // Instead of resizing directly we resize it iteratively. - // - // Since bilinear resize can be broken down as 2 sequential linear - // operations along different dimensions. - // Given sufficient numerical stability and a cxd is same as resizing axb -> exf -> cxd. - // This does not work in the case of align_corners_=false because of special - // padding requirements that cause multiple resizes to be very different - // from a single resize. 
- // - // This makes the convolutions kernels smaller and the operation faster. - xla::XlaOp output = input; - while (in_size != out_size) { - if (in_size[0] != 1 && in_size[1] != 1) { - std::vector k = { - (static_cast(out_size[0]) - 1) / ((in_size[0] - 1) * 2), - (static_cast(out_size[1]) - 1) / ((in_size[1] - 1) * 2)}; - if ((k[0] == std::floor(k[0])) && (k[1] == std::floor(k[1])) && - k[0] > 1 && k[1] > 1 && align_corners_) { - std::vector next_out_size = {(in_size[0] - 1) * 2 + 1, - (in_size[1] - 1) * 2 + 1}; - output = ResizeUsingDilationAndConvolution(b, input, num_spatial_dims, - in_size, next_out_size, - channels, align_corners_); - input = output; - in_size = next_out_size; - } else { - output = ResizeUsingDilationAndConvolution(b, input, num_spatial_dims, - in_size, out_size, - channels, align_corners_); - in_size = out_size; - } - } else { - output = ResizeUsingDilationAndConvolution(b, input, num_spatial_dims, - in_size, out_size, channels, - align_corners_); - in_size = out_size; - } - } +class ResizeBilinearOp : public XlaOpKernel { + public: + explicit ResizeBilinearOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("align_corners", &align_corners_)); + } - ctx->SetOutput(0, output); + void Compile(XlaOpKernelContext* ctx) override { + GeneralCompile(ctx, align_corners_, is_kernel_bilinear_); } private: - bool align_corners_; + bool align_corners_ = true; + bool is_kernel_bilinear_ = true; }; REGISTER_XLA_OP(Name("ResizeBilinear").CompileTimeConstantInput("size"), @@ -582,19 +632,19 @@ class ResizeBilinearGradOp : public XlaOpKernel { (in_size[1] - 1) * 2 + 1}; output = ResizeUsingDilationAndConvolutionGradOp( b, grad, num_spatial_dims, in_size, next_grad_size, channels, - align_corners_); + align_corners_, true); grad = output; in_size = next_grad_size; } else { output = ResizeUsingDilationAndConvolutionGradOp( b, grad, num_spatial_dims, in_size, grad_size, channels, - align_corners_); + align_corners_, true); 
in_size = grad_size; } } else { output = ResizeUsingDilationAndConvolutionGradOp( b, grad, num_spatial_dims, in_size, grad_size, channels, - align_corners_); + align_corners_, true); in_size = grad_size; } } diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops.cc b/tensorflow/compiler/tf2xla/kernels/index_ops.cc index 843b6bb4e658af16fd753c1a20b35dd3d18df027..c1539f48d4f729510b2d930de91666a7c31f1ef0 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops.cc @@ -18,17 +18,16 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/index_ops.h" #include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/kernels/bounds_check.h" namespace tensorflow { XlaArgMinMaxOp::XlaArgMinMaxOp(OpKernelConstruction* ctx, bool is_min) @@ -66,9 +65,9 @@ void XlaArgMinMaxOp::Compile(XlaOpKernelContext* ctx) { xla::XlaOp input = ctx->Input(0); xla::XlaOp output; if (is_min_) { - output = XlaHelpers::ArgMin(input, index_xla_type, axis); + output = xla::ArgMin(input, index_xla_type, axis); } else { - output = XlaHelpers::ArgMax(input, index_xla_type, axis); + output = xla::ArgMax(input, index_xla_type, axis); } ctx->SetOutput(0, output); diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc index 
e2c05b648bb194b1b452c527ddb1a2c5995b1217..e4bbdef6480104a1051acfc647644deb65c80171 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc @@ -16,16 +16,16 @@ limitations under the License. // Native XLA implementations of indexing ops. #include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/kernels/bounds_check.h" namespace tensorflow { namespace { @@ -74,7 +74,7 @@ class ArgMaxCustomCallOp : public XlaOpKernel { // shape isn't supported. 
if (!ctx->compiler()->options().allow_cpu_custom_calls || (input_dims != 1 && input_dims != 2)) { - xla::XlaOp output = XlaHelpers::ArgMax(ctx->Input(0), output_type, axis); + xla::XlaOp output = xla::ArgMax(ctx->Input(0), output_type, axis); ctx->SetOutput(0, output); return; } @@ -110,8 +110,8 @@ class ArgMaxCustomCallOp : public XlaOpKernel { auto shape_status = b.GetShape(arg); OP_REQUIRES_OK(ctx, shape_status.status()); xla::Shape arg_shape = shape_status.ConsumeValueOrDie(); - *arg_shape.mutable_layout() = xla::LayoutUtil::MakeDescendingLayout( - xla::ShapeUtil::Rank(arg_shape)); + *arg_shape.mutable_layout() = + xla::LayoutUtil::MakeDescendingLayout(arg_shape.rank()); arg_shapes.push_back(std::move(arg_shape)); } diff --git a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc index 6440770c29894c951f010f6c1deb929f4fe79bbf..f36e0025250b3a196b31755a1ddf6620c415b6a3 100644 --- a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc @@ -24,8 +24,8 @@ limitations under the License. namespace tensorflow { namespace { -constexpr std::array kMatmulTypes = { - {DT_HALF, DT_BFLOAT16, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64}}; +constexpr std::array kMatmulTypes = { + {DT_HALF, DT_BFLOAT16, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128}}; class MatMulOp : public XlaOpKernel { public: diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc index f4def11d08c31513aec5aad15187016a7294c2fd..90c0ebefb24ec2c4378782e9b15d3f57c33032a4 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/compiler/tf2xla/lib/triangular_solve.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/triangular_solve.h" namespace tensorflow { namespace { @@ -29,7 +29,7 @@ class MatrixTriangularSolveOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - auto result = TriangularSolve( + auto result = xla::TriangularSolve( ctx->Input(0), ctx->Input(1), /*left_side=*/true, /*lower=*/lower_, /*transpose_a=*/adjoint_, /*conjugate_a=*/adjoint_); ctx->SetOutput(0, result); diff --git a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc index f6b8534f4d7c537e5b708ee000e00cb92123584b..656f9b898f32dfc05215014f51c2bbaf07580836 100644 --- a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc @@ -38,8 +38,7 @@ class MirrorPadOp : public XlaOpKernel { // - [1, 2, 3, 3, 2] in symmetric mode. int64 excluded_edges = mode == MirrorPadMode::REFLECT ? 1 : 0; xla::XlaOp accum = t; - for (int64 dimno = xla::ShapeUtil::Rank(original_shape) - 1; dimno >= 0; - --dimno) { + for (int64 dimno = original_shape.rank() - 1; dimno >= 0; --dimno) { auto t_rev = xla::Rev(accum, {dimno}); int64 lhs_padding = pad_literal.Get({dimno, 0}); int64 rhs_padding = pad_literal.Get({dimno, 1}); diff --git a/tensorflow/compiler/tf2xla/kernels/pack_op.cc b/tensorflow/compiler/tf2xla/kernels/pack_op.cc index a9b519d8928cc2807831fd6b4f12e60b7d58ea55..426a0941df57f19072d1cb9f3fa3d0079db465c5 100644 --- a/tensorflow/compiler/tf2xla/kernels/pack_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/pack_op.cc @@ -24,12 +24,12 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/kernels/concat_lib.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc index a259da6383d461fd11b0d79096bf66aae7ddef06..23bb050a34d9246cdf73090aa6adfca054bf8bcf 100644 --- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc @@ -26,10 +26,10 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/kernels/conv_grad_ops.h" #include "tensorflow/core/kernels/pooling_ops_common.h" @@ -152,7 +152,12 @@ class MaxPoolOp : public PoolingOp { public: MaxPoolOp(OpKernelConstruction* ctx, int num_spatial_dims) : PoolingOp(ctx, /*num_spatial_dims=*/num_spatial_dims, - /*reduction_type=*/ctx->input_type(0)) {} + /*reduction_type=*/ctx->input_type(0)) { + string data_format_str; + OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str)); + OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_), + errors::InvalidArgument("Invalid data format")); + } void Compile(XlaOpKernelContext* ctx) override { auto ksize_or_error = GetKernelSize(ctx); @@ -180,10 +185,6 @@ class MaxPool2DOp : public MaxPoolOp { public: explicit MaxPool2DOp(OpKernelConstruction* ctx) : MaxPoolOp(ctx, /*num_spatial_dims=*/2) { - string data_format_str; - OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str)); - OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_), - errors::InvalidArgument("Invalid data format")); } }; REGISTER_XLA_OP(Name("MaxPool"), MaxPool2DOp); @@ -204,7 +205,12 @@ class AvgPoolOp : public PoolingOp { AvgPoolOp(OpKernelConstruction* ctx, int num_spatial_dims) : PoolingOp(ctx, /*num_spatial_dims=*/num_spatial_dims, /*reduction_type=*/ - XlaHelpers::SumAccumulationType(ctx->input_type(0))) {} + XlaHelpers::SumAccumulationType(ctx->input_type(0))) { + string data_format_str; + OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str)); + OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_), + 
errors::InvalidArgument("Invalid data format")); + } void Compile(XlaOpKernelContext* ctx) override { auto ksize_or_error = GetKernelSize(ctx); @@ -241,10 +247,6 @@ class AvgPool2DOp : public AvgPoolOp { public: explicit AvgPool2DOp(OpKernelConstruction* ctx) : AvgPoolOp(ctx, /*num_spatial_dims=*/2) { - string data_format_str; - OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str)); - OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_), - errors::InvalidArgument("Invalid data format")); } }; REGISTER_XLA_OP(Name("AvgPool"), AvgPool2DOp); @@ -390,6 +392,11 @@ class AvgPoolGradOp : public XlaOpKernel { OP_REQUIRES(ctx, ksize_[0] == 1 && stride_[0] == 1, errors::Unimplemented( "Pooling is not yet supported on the batch dimension.")); + + string data_format; + OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); + OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_), + errors::InvalidArgument("Invalid data format")); } int num_dims() const { return num_spatial_dims_ + 2; } @@ -449,10 +456,6 @@ class AvgPool2DGradOp : public AvgPoolGradOp { public: explicit AvgPool2DGradOp(OpKernelConstruction* ctx) : AvgPoolGradOp(ctx, /*num_spatial_dims=*/2) { - string data_format; - OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); - OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_), - errors::InvalidArgument("Invalid data format")); } }; REGISTER_XLA_OP( diff --git a/tensorflow/compiler/tf2xla/kernels/qr_op.cc b/tensorflow/compiler/tf2xla/kernels/qr_op.cc index 7ea0afc1f53cbe4cfcc3f6121a4ecd55864c1b52..66ec40a946b8a063d84acd33daf81f52ea2c35ed 100644 --- a/tensorflow/compiler/tf2xla/kernels/qr_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/qr_op.cc @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/compiler/tf2xla/lib/qr.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/qr.h" namespace tensorflow { namespace { @@ -26,7 +26,7 @@ class QROp : public XlaOpKernel { OP_REQUIRES_OK(ctx, ctx->GetAttr("full_matrices", &full_matrices_)); } void Compile(XlaOpKernelContext* ctx) override { - auto result = QRDecomposition(ctx->Input(0), full_matrices_); + auto result = xla::QRDecomposition(ctx->Input(0), full_matrices_); if (!result.ok()) { ctx->SetStatus(result.status()); return; diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc index 8822e29f7e77b1cbc6fa6ca61d0062d9b1b0c36e..01b047f732f0e9fb3b45b272e7886e2f8cf4fff4 100644 --- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc @@ -20,12 +20,12 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" #include "tensorflow/compiler/tf2xla/lib/random.h" #include "tensorflow/compiler/tf2xla/lib/util.h" -#include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/loops.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" @@ -160,23 +160,30 @@ class RandomShuffleOp : public XlaOpKernel { -> xla::StatusOr> { auto swaps = loop_vars[0]; auto indices = loop_vars[1]; - i = xla::Reshape(i, {1}); + // TODO(b/118437727): The absl::Span nonsense is only necessary because + // the deprecated overload creates ambiguity for the single-element span + // case. Remove it once the deprecated overload is gone. 
// temp = indices[i] - auto temp = xla::DynamicSlice(indices, i, {1}); + auto temp = + xla::DynamicSlice(indices, absl::Span({i}), {1}); // swap_index = swaps[i] - auto swap_index = xla::DynamicSlice(swaps, i, {1}); + auto swap_index = xla::Reshape( + xla::DynamicSlice(swaps, absl::Span({i}), {1}), {}); // swap_value = indices[swaps[i]] - auto swap_value = xla::DynamicSlice(indices, swap_index, {1}); + auto swap_value = xla::DynamicSlice( + indices, absl::Span({swap_index}), {1}); // indices[i] = indices[swaps[i]] - indices = xla::DynamicUpdateSlice(indices, swap_value, i); + indices = xla::DynamicUpdateSlice(indices, swap_value, + absl::Span({i})); // indices[swaps[i]] = temp - indices = xla::DynamicUpdateSlice(indices, temp, swap_index); + indices = xla::DynamicUpdateSlice( + indices, temp, absl::Span({swap_index})); return std::vector{swaps, indices}; }; // for i in range(n): auto swap_loop_result = - XlaForEachIndex(n, xla::S32, swap_body_fn, {swaps, indices}, - "indices_swap_loop", builder) + xla::ForEachIndex(n, xla::S32, swap_body_fn, {swaps, indices}, + "indices_swap_loop", builder) .ValueOrDie(); auto swapped_indices = swap_loop_result[1]; diff --git a/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc b/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc index 769e0cd1409dd7e8099178c8d80b5a9adb0b20b3..f9985d526033ca675c701a508a3d1576e46bc5f7 100644 --- a/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc @@ -24,7 +24,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -126,7 +125,7 @@ XlaOp ConcatenateIota(xla::XlaBuilder* b, XlaOp indices, dimensions.back() = 1; auto batch_indices = - xla::Iota(b, xla::ShapeUtil::MakeShape(xla::U32, dimensions), + xla::Iota(b, xla::ShapeUtil::MakeShape(xla::S32, dimensions), /*iota_dimension=*/0); return xla::ConcatInDim(b, {batch_indices, indices}, dimensions.size() - 1); @@ -190,11 +189,53 @@ XlaOp ScatterToGradData(XlaOpKernelContext* ctx, XlaOp grad_data, XlaOp indices, scatter_dim_numbers); } +// Bounds samples to 0 if the warp image indices are out of the (-1, image_size) +// bound. +// The resulting dimension is given by 'result_dims'. +XlaOp BoundSamples(XlaOpKernelContext* ctx, XlaOp warp, + xla::PrimitiveType warp_type, TensorShape warp_shape, + std::vector result_dims, + std::vector broadcasted_dims, int64 last_warp_dim, + xla::Shape data_shape, XlaOp sample) { + auto is_gt_minus_one = + xla::Gt(warp, + xla::ConvertElementType( + xla::ConstantR1(ctx->builder(), {-1, -1}), warp_type), + /*broadcast_dimensions=*/{warp_shape.dims() - 1}); + auto is_lt_image_size = xla::Lt( + warp, + xla::ConvertElementType( + xla::ConstantR1( + ctx->builder(), + {/*width=*/static_cast(data_shape.dimensions(2)), + /*height=*/static_cast(data_shape.dimensions(1))}), + warp_type), + /*broadcast_dimensions=*/{warp_shape.dims() - 1}); + + auto is_in_bound_padded_x_y = xla::And(is_gt_minus_one, is_lt_image_size); + // Reduce along last dimension. The resulting dimension is: + // [batch, dim_0, ...dim_n]. 
+ auto is_in_bound = xla::Reduce( + is_in_bound_padded_x_y, xla::ConstantR0(ctx->builder(), true), + xla::CreateScalarAndComputation(xla::PrimitiveType::PRED, ctx->builder()), + {last_warp_dim}); + + // Broadcast 'is_in_bound' to the same dimension as 'result_dims'. + auto broadcasted_is_in_bound = + xla::BroadcastInDim(is_in_bound, result_dims, broadcasted_dims); + + // Set out of bound samples to zero. + auto zeros = + xla::Broadcast(xla::Zero(ctx->builder(), warp_type), result_dims); + return xla::Select(broadcasted_is_in_bound, sample, zeros); +} + // Build computation the backprop into input 'data'. // Where input: // grad_output is of dimension [batch, dim_0, ...dim_n, channel] // ratio is of dimension [batch, dim_0, ...dim_n, 2] // gather_indices is of dimension [batch, dim_0, ...dim_n, 3] +// data_shape is of dimension [batch, x(width), y(height), channel] // // Output: // scatter-add to each 2x2 grad_data neighbor: @@ -202,10 +243,12 @@ XlaOp ScatterToGradData(XlaOpKernelContext* ctx, XlaOp grad_data, XlaOp indices, // grad_data[cx, fy, chan] += output_grad * (1 - dx) * dy // grad_data[fx, cy, chan] += output_grad * dx * (1 - dy) // grad_data[cx, cy, chan] += output_grad * (1 - dx) * (1 - dy) -// where (dx, dy) is (1 - ratio). +// where (dx, dy) is (1 - ratio). If (dx, dy) is out of bound, then the their +// contribution is 0 to 'grad_data'. XlaOp CalculateGradData(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio, - XlaOp gather_indices, xla::PrimitiveType warp_type, - TensorShape warp_shape, int64 data_channels, + XlaOp gather_indices, XlaOp warp, + xla::PrimitiveType warp_type, TensorShape warp_shape, + int64 last_warp_dim, int64 data_channels, xla::Shape data_shape) { // Weights tensor has dimension [batch, dim_0, ... dim_n, 4]. 
auto weights = BilinearWeights(ctx, ratio, warp_shape, warp_type); @@ -230,6 +273,18 @@ XlaOp CalculateGradData(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio, std::iota(reshaped_weights_indices.begin(), reshaped_weights_indices.end(), 0); + // Set out of bound weights to 0. + // The dimension of the reshaped_weight: [batch, dim_0, ...dim_n, 2, 2]. + std::vector reshaped_result_dims(warp_dims.begin(), + warp_dims.end() - 1); + reshaped_result_dims.push_back(2); + reshaped_result_dims.push_back(2); + std::vector broadcasted_dims(warp_dims.size() - 1); + std::iota(broadcasted_dims.begin(), broadcasted_dims.end(), 0); + reshaped_weights = BoundSamples(ctx, warp, warp_type, warp_shape, + reshaped_result_dims, broadcasted_dims, + last_warp_dim, data_shape, reshaped_weights); + // The dimension is [batch, dim_0, ..., dim_n, 2, 2, data_channel]. auto broadcast_reshaped_weights = xla::BroadcastInDim( reshaped_weights, weights_with_channels_dims, reshaped_weights_indices); @@ -246,18 +301,41 @@ XlaOp CalculateGradData(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio, auto grad_data = xla::ConstantLiteral( ctx->builder(), xla::Literal::CreateFromShape(data_shape)); - return ScatterToGradData(ctx, grad_data, gather_indices, - grad_output_multiply_weights, warp_shape.dims(), - warp_type); + // Pad grad data then slice it back. + // + // After left and right column 0-padding, the new dimension of padded data + // will be [batch, x+2, y+2, channel]. 
+ auto padded_grad_data = + xla::Pad(grad_data, xla::Zero(ctx->builder(), warp_type), + xla::MakeEdgePaddingConfig({{0, 0}, {1, 1}, {1, 1}, {0, 0}})); + + auto shifting_value = xla::ConstantR1( + ctx->builder(), {/*batch=*/0, /*x(width)=*/1, /*y(height)=*/1}); + auto shifted_gather_indices = + xla::Add(gather_indices, shifting_value, {last_warp_dim}); + + auto updated_grad_data = ScatterToGradData( + ctx, padded_grad_data, shifted_gather_indices, + grad_output_multiply_weights, warp_shape.dims(), warp_type); + + const int64 batch_size = data_shape.dimensions(0); + const int64 width = data_shape.dimensions(1); + const int64 height = data_shape.dimensions(2); + // Slice out the result accounting for the padding. + return xla::Slice( + updated_grad_data, /*start_indices=*/{0, 1, 1, 0}, + /*limit_indices=*/{batch_size, width + 1, height + 1, data_channels}, + /*strides=*/{1, 1, 1, 1}); } // Build computation for the backprop into input 'warp'. // Where input: -// warp is of dimension [batch, dim_0, ...dim_n, 2] -// grad_output is of dimension [batch, dim_0, ...dim_n, channel] -// ratio is of dimension [batch, dim_0, ...dim_n, 2] -// gather_indices is of dimension [batch, dim_0, ...dim_n, 3] -// data is of dimension [batch, x, y, channel] +// warp is of dimension [batch, dim_0, ...dim_n, 2] +// grad_output is of dimension [batch, dim_0, ...dim_n, channel] +// ratio is of dimension [batch, dim_0, ...dim_n, 2] +// gather_indices is of dimension [batch, dim_0, ...dim_n, 3] where the last +// dimension of size 3 is for {batch, x(width), y(height)}. 
+// data is of dimension [batch, x, y, channel] // // Output (simplified by ignoring the batch dimensions): // Since the forward path has: @@ -276,12 +354,12 @@ XlaOp CalculateGradData(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio, // grad_warp_x = py * (img_cxcy - img_fxcy) + (1-py) * (img_cxfy-img_fxfy) // grad_warp_y = px * (img_cxcy - img_cxfy) + (1-px) * (img_fxcy-img_fxfy) // -// where (px, py) is warp, (fx, fy) is the left top corner and (cx, cy) is the +// where (px, py) is warp, (fx, fy) is the top left corner and (cx, cy) is the // bottom right corner in a 2x2 neighborhood. XlaOp CalculateGradWarp(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio, XlaOp gather_indices, XlaOp data, TensorShape warp_shape, int64 data_channels, - xla::PrimitiveType data_type) { + xla::PrimitiveType data_type, xla::Shape data_shape) { auto warp_dims = warp_shape.dim_sizes(); std::vector warp_dims_without_last_dims(warp_dims.begin(), warp_dims.end() - 1); @@ -290,12 +368,30 @@ XlaOp CalculateGradWarp(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio, std::vector neighbor_broadcast_dims = warp_dims_without_last_dims; neighbor_broadcast_dims.push_back(4); - // The dimension is [batch, dim_0, ... dim_n, 4, data_channels] - auto neighbors_data = Gather2by2Neighbors( - ctx->builder(), data, gather_indices, data_channels, warp_shape.dims()); + // With dimension [batch, dim_0, ...dim_n, 4] + auto neighbor_broadcast_shape = + xla::ShapeUtil::MakeShape(data_type, neighbor_broadcast_dims); const int64 last_warp_dim = warp_shape.dims() - 1; + // Pad data with 0, before gathering such that 0 will be returned for samples + // in the range of (-1, 0) or (image_dimension-1, image_dimension). + // After left and right column 0-padding, the new dimension of padded data + // will be [batch, x+2, y+2, channel]. 
+ auto padded_data = + xla::Pad(data, xla::Zero(ctx->builder(), data_type), + xla::MakeEdgePaddingConfig({{0, 0}, {1, 1}, {1, 1}, {0, 0}})); + + auto shifting_value = xla::ConstantR1( + ctx->builder(), {/*batch=*/0, /*x(width)=*/1, /*y(height)=*/1}); + auto shifted_gather_indices = + xla::Add(gather_indices, shifting_value, {last_warp_dim}); + + // The dimension is [batch, dim_0, ... dim_n, 4, data_channels] + auto neighbors_data = + Gather2by2Neighbors(ctx->builder(), padded_data, shifted_gather_indices, + data_channels, warp_shape.dims()); + // Since we will be creating the dot product of: // lhs: [batch, dim_0, ...dim_n, 4] // and @@ -418,7 +514,7 @@ class ResamplerOp : public XlaOpKernel { // Find the coordinates of the top left corner for the 2x2 region to be // sampled from. The dimensions are [batch, dim_0, ... dim_n, 2] where the // last dimension of size 2 in turn is [x, y]. - XlaOp top_left = xla::ConvertElementType(warp, xla::U32); + XlaOp top_left = xla::ConvertElementType(warp, xla::S32); auto gather_indices = ConcatenateIota(ctx->builder(), top_left, warp_shape); @@ -527,7 +623,8 @@ class ResamplerGradOp : public XlaOpKernel { size, "]")); } // Last dimension of warp shape must be of size 2. - OP_REQUIRES(ctx, warp_shape.dim_size(warp_shape.dims() - 1) == 2, + const int64 last_warp_dim = warp_shape.dims() - 1; + OP_REQUIRES(ctx, warp_shape.dim_size(last_warp_dim) == 2, errors::InvalidArgument( "the last dimension of warp must be exactly size 2.")); xla::PrimitiveType warp_type = ctx->input_xla_type(1); @@ -550,24 +647,32 @@ class ResamplerGradOp : public XlaOpKernel { // Find the top left corner coordinate for the region to be sampled from. // The dimensions are [batch, dim_0, ... dim_n, 2] where the last dimension // of size 2 in turn is [x, y]. - XlaOp top_left = xla::ConvertElementType(warp, xla::U32); + XlaOp top_left = xla::ConvertElementType(xla::Floor(warp), xla::S32); - // Dimensions are [batch, dim_0, ... 
dim_n, 2] + // Dimensions are [batch, dim_0, ... dim_n, 2]. XlaOp ratio = warp - xla::ConvertElementType(top_left, warp_type); // Indices for gathering neighboring pixels. auto gather_indices = ConcatenateIota(ctx->builder(), top_left, warp_shape); - auto grad_data = - CalculateGradData(ctx, grad_output, ratio, gather_indices, warp_type, - warp_shape, data_channels, data_shape); + auto grad_data = CalculateGradData( + ctx, grad_output, ratio, gather_indices, warp, warp_type, warp_shape, + last_warp_dim, data_channels, data_shape); auto grad_warp = CalculateGradWarp(ctx, grad_output, ratio, gather_indices, data, - warp_shape, data_channels, data_type); + warp_shape, data_channels, data_type, data_shape); + auto warp_dims = warp_shape.dim_sizes(); + std::vector result_dims(warp_dims.begin(), warp_dims.end() - 1); + result_dims.push_back(2); + std::vector broadcasted_dims(warp_dims.size() - 1); + std::iota(broadcasted_dims.begin(), broadcasted_dims.end(), 0); + auto grad_warp_bounded = + BoundSamples(ctx, warp, warp_type, warp_shape, result_dims, + broadcasted_dims, last_warp_dim, data_shape, grad_warp); ctx->SetOutput(0, grad_data); - ctx->SetOutput(1, grad_warp); + ctx->SetOutput(1, grad_warp_bounded); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/retval_op.cc b/tensorflow/compiler/tf2xla/kernels/retval_op.cc index e4046c795577983bff1a8053743bf4d3a258e583..1f417037284c87753b219ea5ce1d4edce0ce6336 100644 --- a/tensorflow/compiler/tf2xla/kernels/retval_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/retval_op.cc @@ -37,10 +37,14 @@ class RetvalOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { const Tensor& input = ctx->op_kernel_context()->input(0); - OP_REQUIRES(ctx, input.dtype() == dtype_, - errors::InvalidArgument( - "Type mismatch: actual ", DataTypeString(input.dtype()), - " vs. expect ", DataTypeString(dtype_))); + // DT_VARIANT types represent Tensor Lists and are wrapped in a DT_UINT8 + // tensor so we skip the check here. 
+ if (dtype_ != DT_VARIANT) { + OP_REQUIRES(ctx, input.dtype() == dtype_, + errors::InvalidArgument( + "Type mismatch: actual ", DataTypeString(input.dtype()), + " vs. expect ", DataTypeString(dtype_))); + } auto frame = ctx->call_frame(); if (frame) { // If 'frame' is non-null, this is an inner function call inside a JIT @@ -59,8 +63,9 @@ class RetvalOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(RetvalOp); }; -REGISTER_XLA_OP(Name("_Retval").AllowResourceTypes().CompilationOnly(), - RetvalOp); +REGISTER_XLA_OP( + Name("_Retval").AllowResourceTypes().AllowVariantTypes().CompilationOnly(), + RetvalOp); } // anonymous namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc index 4b9e1a578be2445091228953df7e5c5e82b42c28..daefdfc58a4957d9e685d25aa90da6218f2041ad 100644 --- a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc @@ -23,13 +23,13 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/kernels/concat_lib.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/select_op.cc b/tensorflow/compiler/tf2xla/kernels/select_op.cc index 9e4c57c9bf73369662274f6b783418e18ff860c2..aaf8c6075dd292e33e70683774a6c1bf374183e3 100644 --- a/tensorflow/compiler/tf2xla/kernels/select_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/select_op.cc @@ -20,8 +20,8 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/kernel_def_builder.h" -#include "tensorflow/core/kernels/bounds_check.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc index b1fa2915d59e4e5e2f2523e20e9a37898d087117..7a620d2a6518f8686ef570b33aac971d1dccb6c1 100644 --- a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc @@ -157,9 +157,11 @@ class LinSpaceOp : public XlaOpKernel { flat(0) = start; } else { const float step = (stop - start) / (num - 1); - for (int64 i = 0; i < num; ++i) { + for (int64 i = 0; i < num - 1; ++i) { flat(i) = start + step * i; } + // The last value in the sequence must be equal to stop. + flat(num - 1) = stop; } break; } @@ -171,9 +173,11 @@ class LinSpaceOp : public XlaOpKernel { flat(0) = start; } else { const double step = (stop - start) / (num - 1); - for (int64 i = 0; i < num; ++i) { + for (int64 i = 0; i < num - 1; ++i) { flat(i) = start + step * i; } + // The last value in the sequence must be equal to stop. + flat(num - 1) = stop; } break; } diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc index 12830816ec16c9797f0fe4d8f3f13f5a8176161d..31d4cc131600f360c764ffa02831046c85d846e5 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc @@ -20,10 +20,11 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/kernels/bounds_check.h" namespace tensorflow { namespace { @@ -91,14 +92,20 @@ class SizeOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { const TensorShape input_shape = ctx->InputShape(0); - const int64 size = input_shape.num_elements(); - OP_REQUIRES(ctx, FastBoundsCheck(size, std::numeric_limits::max()), + OP_REQUIRES(ctx, + FastBoundsCheck(input_shape.num_elements(), + std::numeric_limits::max()), errors::InvalidArgument("Size does not work for tensors > " "int32 max.")); Tensor size_constant(DT_INT32, TensorShape({})); - size_constant.scalar()() = static_cast(size); - - ctx->SetConstantOutput(0, size_constant); + const int rank = input_shape.dims(); + xla::XlaBuilder* builder = ctx->builder(); + auto size = xla::One(builder, xla::U32); + for (int64 i = 0; i < rank; ++i) { + size = xla::Mul(size, xla::GetDimensionSize(ctx->Input(0), i)); + } + size = xla::ConvertElementType(size, xla::S32); + ctx->SetOutput(0, size); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/shape_util.cc b/tensorflow/compiler/tf2xla/kernels/shape_util.cc index 76ea5f525598f511f295eb5a30f3cf603fbf57aa..b18e3f965c427aec456ce2b188dad79485df23cc 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_util.cc +++ b/tensorflow/compiler/tf2xla/kernels/shape_util.cc @@ -17,7 +17,7 @@ limitations under the License. 
#include -#include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/framework/bounds_check.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc index 622efac81766fc3ddaf538b58170f34fce06927a..52bed2670b4b8408e3b2f72b64bf370aea5325f6 100644 --- a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc @@ -39,7 +39,7 @@ void SpaceToBatch(XlaOpKernelContext* ctx, const xla::XlaOp& input, OP_REQUIRES( ctx, - xla::ShapeUtil::Rank(paddings.shape()) == 2 && + paddings.shape().rank() == 2 && block_rank == xla::ShapeUtil::GetDimension(paddings.shape(), 0) && 2 == xla::ShapeUtil::GetDimension(paddings.shape(), 1), errors::InvalidArgument("paddings should have shape [", block_rank, diff --git a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc index 8e9e4daf99d3dd3b8e149e3f3e5f6c27665c0fcb..b6c96b1f582710e1cc39e6e1e0e800ef8170743d 100644 --- a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc @@ -24,13 +24,13 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/kernels/concat_lib.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/types.h" @@ -45,7 +45,7 @@ Status GetStackShape(xla::XlaBuilder* builder, XlaResource* resource, return shape_or_status.status(); } xla::Shape shape = shape_or_status.ValueOrDie(); - TF_RET_CHECK(xla::ShapeUtil::IsTuple(shape)); + TF_RET_CHECK(shape.IsTuple()); return XLAShapeToTensorShape(xla::ShapeUtil::GetTupleElementShape(shape, 0), stack_shape); } @@ -146,9 +146,9 @@ class StackPushOp : public XlaOpKernel { xla::XlaOp value = ctx->Input(1); // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. - auto start_indices = - xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0(b, 0), - xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}})); + std::vector start_indices(elem_shape.dims() + 1, + xla::ConstantR0(b, 0)); + start_indices[0] = index; TensorShape slice_shape = elem_shape; slice_shape.InsertDim(0, 1LL); @@ -202,9 +202,9 @@ class StackPopOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, resource->SetValue(xla::Tuple(b, {ta, index}))); // start_indices of the DynamicSlice are [index, 0, 0, ..., 0]. 
- auto start_indices = - xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0(b, 0), - xla::MakeEdgePaddingConfig({{0, stack_shape.dims() - 1}})); + std::vector start_indices(stack_shape.dims(), + xla::ConstantR0(b, 0)); + start_indices[0] = index; auto slice_shape = stack_shape.dim_sizes(); slice_shape[0] = 1LL; diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc index 5db52781be473a9a1aef0adf105e3edf69ccd306..50653d7b3973b73d580cdeec5d71943b575d7cc9 100644 --- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc @@ -23,7 +23,6 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/math.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/lib/prng.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc index 10d990b3213ab882cf44a4df20a977633de3fdab..e8846fbe88fa2a75244398ef0f601fd74e80ec50 100644 --- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc @@ -288,19 +288,21 @@ class StridedSliceAssignOp : public XlaOpKernel { xla::XlaOp rhs = ctx->Input(4); absl::InlinedVector dimensions_to_reverse; - absl::InlinedVector slice_begin, slice_dims; + absl::InlinedVector slice_begin; + absl::InlinedVector slice_dims; for (int i = 0; i < begin.size(); ++i) { - // TODO(phawkins): implement strides != 1 + // TODO(b/121179231): implement strides != 1 OP_REQUIRES( ctx, strides[i] == 1 || strides[i] == -1, errors::Unimplemented("Strides != 1 or -1 are not yet implemented")); if (strides[i] > 0) { - 
slice_begin.push_back(begin[i]); + slice_begin.push_back(xla::ConstantR0(ctx->builder(), begin[i])); slice_dims.push_back(end[i] - begin[i]); } else { // Negative stride: swap begin and end, add 1 because the interval // is semi-open, and mark the dimension to be reversed. - slice_begin.push_back(end[i] + 1); + slice_begin.push_back( + xla::ConstantR0(ctx->builder(), end[i] + 1)); slice_dims.push_back(begin[i] - end[i]); dimensions_to_reverse.push_back(i); } @@ -311,14 +313,7 @@ class StridedSliceAssignOp : public XlaOpKernel { } rhs = xla::Reshape(rhs, slice_dims); - if (lhs_shape.dims() == 0) { - // TODO(b/38323843): DynamicUpdateSlice crashes on rank 0 inputs. Fix - // and remove this workaround. - lhs = rhs; - } else { - lhs = xla::DynamicUpdateSlice( - lhs, rhs, xla::ConstantR1(ctx->builder(), slice_begin)); - } + lhs = xla::DynamicUpdateSlice(lhs, rhs, slice_begin); OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, lhs)); } diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc index 939d7e19515a1cb41e3e23e9d1fa957ae09ecab7..77a3e5c001e1c715f23ae5148f94dae2faa81acf 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -27,13 +27,13 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_resource.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/kernels/concat_lib.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/types.h" @@ -123,7 +123,8 @@ Status GetTensorArrayShape(const XlaResource* resource, xla::XlaOp DynamicAddSlice(xla::XlaBuilder* builder, const xla::XlaOp& operand, const xla::XlaOp& update, absl::Span update_dims, - const xla::XlaOp& start_indices, DataType dtype) { + absl::Span start_indices, + DataType dtype) { xla::XlaOp current = xla::DynamicSlice(operand, start_indices, update_dims); xla::XlaOp sum = dtype == DT_BOOL ? xla::Or(current, update) : xla::Add(current, update); @@ -212,9 +213,9 @@ class TensorArrayWriteOp : public XlaOpKernel { xla::XlaOp flow = ctx->Input(3); // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. - auto start_indices = - xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0(b, 0), - xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}})); + std::vector start_indices(elem_shape.dims() + 1, + xla::ConstantR0(b, 0)); + start_indices[0] = index; TensorShape slice_shape = elem_shape; slice_shape.InsertDim(0, 1LL); @@ -263,9 +264,9 @@ class TensorArrayReadOp : public XlaOpKernel { xla::XlaOp index = ctx->Input(1); // start_indices of the DynamicSlice are [index, 0, 0, ..., 0]. 
- auto start_indices = - xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0(b, 0), - xla::MakeEdgePaddingConfig({{0, ta_shape.dims() - 1}})); + std::vector start_indices(ta_shape.dims(), + xla::ConstantR0(b, 0)); + start_indices[0] = index; auto slice_shape = ta_shape.dim_sizes(); slice_shape[0] = 1LL; @@ -419,10 +420,10 @@ class TensorArrayScatterOp : public XlaOpKernel { auto slice = xla::Slice(value, value_starts, value_ends, value_strides); // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. - auto index = xla::Slice(indices, {i}, {i + 1}, {1}); - auto start_indices = - xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0(b, 0), - xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}})); + auto index = xla::Reshape(xla::Slice(indices, {i}, {i + 1}, {1}), {}); + std::vector start_indices(elem_shape.dims() + 1, + xla::ConstantR0(b, 0)); + start_indices[0] = index; ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices, dtype_); } } diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc index 64a24703ae1460abfedb6d9298e1e164076a199a..65020012283d9c5f62e5e2fd11fc2bf1110e019a 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc @@ -14,6 +14,9 @@ limitations under the License. ==============================================================================*/ // XLA TensorList operators. +// Tensor lists are represented as tuple consisting of a pre-allocated list +// consisting of the tensors (and where dim 0 is the list index), along with a +// scalar telling us the current number of elements. #include #include @@ -24,13 +27,13 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/kernels/concat_lib.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/types.h" @@ -45,11 +48,27 @@ Status GetTensorListShape(xla::XlaBuilder* builder, xla::XlaOp op, return shape_or_status.status(); } xla::Shape shape = shape_or_status.ValueOrDie(); - TF_RET_CHECK(xla::ShapeUtil::IsTuple(shape)); + TF_RET_CHECK(shape.IsTuple()); return XLAShapeToTensorShape(xla::ShapeUtil::GetTupleElementShape(shape, 0), tensor_list_shape); } +class TensorListLengthOp : public XlaOpKernel { + public: + explicit TensorListLengthOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + xla::XlaOp tl = ctx->Input(0); + xla::XlaOp index = xla::GetTupleElement(tl, 1); + ctx->SetOutput(0, index); + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(TensorListLengthOp); +}; + +REGISTER_XLA_OP(Name("TensorListLength"), TensorListLengthOp); + class TensorListReserveOp : public XlaOpKernel { public: explicit TensorListReserveOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { @@ -67,9 +86,10 @@ class TensorListReserveOp : public XlaOpKernel { tensor_shape.AppendShape(element_shape); xla::XlaBuilder* b = ctx->builder(); - ctx->SetOutput(0, xla::Tuple(b, {xla::Broadcast(XlaHelpers::Zero(b, dtype_), - tensor_shape.dim_sizes()), - xla::ConstantR0(b, 0)})); + ctx->SetTensorListOutput( + 0, xla::Tuple(b, {xla::Broadcast(XlaHelpers::Zero(b, 
dtype_), + tensor_shape.dim_sizes()), + xla::ConstantR0(b, num_elements)})); } private: @@ -85,19 +105,41 @@ REGISTER_XLA_OP(Name("TensorListReserve") class EmptyTensorListOp : public XlaOpKernel { public: - explicit EmptyTensorListOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + explicit EmptyTensorListOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_)); + } void Compile(XlaOpKernelContext* ctx) override { - ctx->CtxFailure( + TensorShape element_shape; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &element_shape)); + int64 max_num_elements; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &max_num_elements)); + OP_REQUIRES( + ctx, max_num_elements >= 0, errors::InvalidArgument("XLA compilation requires a fixed tensor list " - "size. Use TensorListReserve instead.")); + "size. Set the max number of elements.")); + + TensorShape tensor_shape; + tensor_shape.AddDim(max_num_elements); + tensor_shape.AppendShape(element_shape); + + xla::XlaBuilder* b = ctx->builder(); + ctx->SetTensorListOutput( + 0, xla::Tuple(b, {xla::Broadcast(XlaHelpers::Zero(b, dtype_), + tensor_shape.dim_sizes()), + xla::ConstantR0(b, 0)})); } private: + DataType dtype_; + TF_DISALLOW_COPY_AND_ASSIGN(EmptyTensorListOp); }; -REGISTER_XLA_OP(Name("EmptyTensorList"), EmptyTensorListOp); +REGISTER_XLA_OP(Name("EmptyTensorList") + .CompileTimeConstantInput("element_shape") + .CompileTimeConstantInput("max_num_elements"), + EmptyTensorListOp); class TensorListElementShapeOp : public XlaOpKernel { public: @@ -139,6 +181,136 @@ class TensorListElementShapeOp : public XlaOpKernel { REGISTER_XLA_OP(Name("TensorListElementShape"), TensorListElementShapeOp); +class TensorListGetItemOp : public XlaOpKernel { + public: + explicit TensorListGetItemOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::XlaBuilder* 
b = ctx->builder(); + xla::XlaOp state = ctx->Input(0); + + TensorShape shape; + OP_REQUIRES_OK(ctx, GetTensorListShape(b, state, &shape)); + + xla::XlaOp ta = xla::GetTupleElement(state, 0); + xla::XlaOp index = ctx->Input(1); + + // start_indices of the DynamicSlice are [index, 0, 0, ..., 0]. + std::vector start_indices(shape.dims(), + xla::ConstantR0(b, 0)); + start_indices[0] = index; + auto slice_shape = shape.dim_sizes(); + slice_shape[0] = 1LL; + + xla::XlaOp read = xla::DynamicSlice(ta, start_indices, slice_shape); + // Remove the leading '1' dimension. + std::vector value_shape(slice_shape.begin() + 1, slice_shape.end()); + + ctx->SetOutput(0, xla::Reshape(read, value_shape)); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(TensorListGetItemOp); +}; + +REGISTER_XLA_OP(Name("TensorListGetItem"), TensorListGetItemOp); + +class TensorListStackOp : public XlaOpKernel { + public: + explicit TensorListStackOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::XlaOp state = ctx->Input(0); + xla::XlaOp ta = xla::GetTupleElement(state, 0); + ctx->SetOutput(0, ta); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(TensorListStackOp); +}; + +REGISTER_XLA_OP(Name("TensorListStack"), TensorListStackOp); + +class TensorListFromTensorOp : public XlaOpKernel { + public: + explicit TensorListFromTensorOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + TensorShape element_shape; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(1, &element_shape)); + + const TensorShape tensor_shape = ctx->InputShape(0); + OP_REQUIRES(ctx, tensor_shape.dims() > 0, + errors::InvalidArgument("Input value must be at least a " + "vector but received shape: ", + tensor_shape.DebugString())); + const int num_elements 
= tensor_shape.dim_size(0); + + xla::XlaBuilder* b = ctx->builder(); + const xla::XlaOp tensor = ctx->Input(0); + + ctx->SetTensorListOutput( + 0, xla::Tuple(b, {tensor, xla::ConstantR0(b, num_elements)})); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(TensorListFromTensorOp); +}; + +REGISTER_XLA_OP( + Name("TensorListFromTensor").CompileTimeConstantInput("element_shape"), + TensorListFromTensorOp); + +class TensorListSetItemOp : public XlaOpKernel { + public: + explicit TensorListSetItemOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::XlaBuilder* b = ctx->builder(); + xla::XlaOp tl = ctx->Input(0); + TensorShape elem_shape = ctx->InputShape(2); + + xla::XlaOp ta = xla::GetTupleElement(tl, 0); + xla::XlaOp index = ctx->Input(1); + xla::XlaOp value = ctx->Input(2); + + // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. + std::vector start_indices(elem_shape.dims() + 1, + xla::ConstantR0(b, 0)); + start_indices[0] = index; + + TensorShape slice_shape = elem_shape; + slice_shape.InsertDim(0, 1LL); + auto update = xla::Reshape(value, slice_shape.dim_sizes()); + + ctx->SetTensorListOutput( + 0, xla::Tuple(b, {xla::DynamicUpdateSlice(ta, update, start_indices), + index + xla::ConstantR0(b, 1)})); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(TensorListSetItemOp); +}; + +REGISTER_XLA_OP(Name("TensorListSetItem"), TensorListSetItemOp); + class TensorListPushBackOp : public XlaOpKernel { public: explicit TensorListPushBackOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { @@ -147,25 +319,23 @@ class TensorListPushBackOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { xla::XlaBuilder* b = ctx->builder(); - xla::XlaOp list = ctx->Input(0); + xla::XlaOp tl = ctx->Input(0); TensorShape elem_shape = ctx->InputShape(1); - xla::XlaOp ta = xla::GetTupleElement(list, 0); 
- xla::XlaOp index = xla::GetTupleElement(list, 1); + xla::XlaOp ta = xla::GetTupleElement(tl, 0); + xla::XlaOp index = xla::GetTupleElement(tl, 1); xla::XlaOp value = ctx->Input(1); // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. - auto start_indices = - xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0(b, 0), - xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}})); + std::vector start_indices(elem_shape.dims() + 1, + xla::ConstantR0(b, 0)); + start_indices[0] = index; TensorShape slice_shape = elem_shape; slice_shape.InsertDim(0, 1LL); auto update = xla::Reshape(value, slice_shape.dim_sizes()); - // TODO(phawkins): We don't check the index is in bounds --- there is no - // error mechanism in XLA. - ctx->SetOutput( + ctx->SetTensorListOutput( 0, xla::Tuple(b, {xla::DynamicUpdateSlice(ta, update, start_indices), index + xla::ConstantR0(b, 1)})); } @@ -197,20 +367,17 @@ class TensorListPopBackOp : public XlaOpKernel { index = index - xla::ConstantR0(b, 1); // start_indices of the DynamicSlice are [index, 0, 0, ..., 0]. - auto start_indices = - xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0(b, 0), - xla::MakeEdgePaddingConfig({{0, shape.dims() - 1}})); - + std::vector start_indices(shape.dims(), + xla::ConstantR0(b, 0)); + start_indices[0] = index; auto slice_shape = shape.dim_sizes(); slice_shape[0] = 1LL; - // TODO(phawkins): We don't check the index is in bounds --- there is no - // error mechanism in XLA. xla::XlaOp read = xla::DynamicSlice(ta, start_indices, slice_shape); // Remove the leading '1' dimension. 
std::vector value_shape(slice_shape.begin() + 1, slice_shape.end()); - ctx->SetOutput(0, xla::Tuple(b, {ta, index})); + ctx->SetTensorListOutput(0, xla::Tuple(b, {ta, index})); ctx->SetOutput(1, xla::Reshape(read, value_shape)); } diff --git a/tensorflow/compiler/tf2xla/kernels/topk_op.cc b/tensorflow/compiler/tf2xla/kernels/topk_op.cc index 8a0c94cfae1b298bd62a3231caf39ecf9b32880e..ee3bdf3394e37c757f31724e73e95417becaa534 100644 --- a/tensorflow/compiler/tf2xla/kernels/topk_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/topk_op.cc @@ -15,7 +15,6 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/lib/sorting.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" diff --git a/tensorflow/compiler/tf2xla/kernels/training_ops.cc b/tensorflow/compiler/tf2xla/kernels/training_ops.cc index 960c1462ceb8c00a2d6c96564f6c985fd1caef0f..26d4214099d1d07c1b2e275d783654d9cd948e28 100644 --- a/tensorflow/compiler/tf2xla/kernels/training_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/training_ops.cc @@ -172,6 +172,65 @@ class ResourceApplyMomentum : public XlaOpKernel { REGISTER_XLA_OP(Name("ResourceApplyMomentum").TypeConstraint("T", kFloatTypes), ResourceApplyMomentum); +class ResourceApplyKerasMomentum : public XlaOpKernel { + public: + explicit ResourceApplyKerasMomentum(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("use_nesterov", &use_nesterov_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + DataType type = ctx->input_type(2); + + TensorShape var_shape, accum_shape; + xla::XlaOp var, accum; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var)); + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &accum_shape, &accum)); + + OP_REQUIRES(ctx, 
var_shape.IsSameSize(accum_shape), + errors::InvalidArgument( + "var and accum do not have the same shape", + var_shape.DebugString(), " ", accum_shape.DebugString())); + + TensorShape lr_shape = ctx->InputShape(2); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape), + errors::InvalidArgument("lr is not a scalar: ", + lr_shape.DebugString())); + + TensorShape grad_shape = ctx->InputShape(3); + OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape), + errors::InvalidArgument( + "var and grad do not have the same shape", + var_shape.DebugString(), " ", grad_shape.DebugString())); + + TensorShape momentum_shape = ctx->InputShape(4); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum_shape), + errors::InvalidArgument("momentum is not a scalar: ", + momentum_shape.DebugString())); + + xla::XlaOp lr = ctx->Input(2); + xla::XlaOp grad = ctx->Input(3); + xla::XlaOp momentum = ctx->Input(4); + + accum = accum * momentum - grad * lr; + if (use_nesterov_) { + // See https://github.com/tensorflow/tensorflow/pull/2798 for an + // explanation of the reparameterization used here. + var = var + accum * momentum - grad * lr; + } else { + var = var + accum; + } + OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, var)); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, accum)); + } + + private: + bool use_nesterov_; +}; +REGISTER_XLA_OP( + Name("ResourceApplyKerasMomentum").TypeConstraint("T", kFloatTypes), + ResourceApplyKerasMomentum); + class ResourceApplyAdagrad : public XlaOpKernel { public: explicit ResourceApplyAdagrad(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} diff --git a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc index c9b324a243e4cc3ec64daa3ca0d285336a0d0154..76793d677ba45f8e863e684a149da684c8ce8787 100644 --- a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc @@ -24,9 +24,9 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/register_types.h" -#include "tensorflow/core/kernels/bounds_check.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc index a0ea6422d732b00fc1b8cf855d9c9ad603b87c82..4544e03491438d5f21cf986bc952572bd19d548c 100644 --- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc @@ -65,11 +65,8 @@ XLAJIT_MAKE_UNARY(Exp, xla::Exp(x)); XLAJIT_MAKE_UNARY(Expm1, xla::Expm1(x)); XLAJIT_MAKE_UNARY(Floor, xla::Floor(x)); XLAJIT_MAKE_UNARY(IsFinite, xla::IsFinite(x)); -XLAJIT_MAKE_UNARY( - IsInf, - xla::Eq(xla::Abs(x), - xla::ScalarLike(x, std::numeric_limits::infinity()))); -XLAJIT_MAKE_UNARY(IsNan, xla::Ne(x, x)); +XLAJIT_MAKE_UNARY(IsInf, xla::IsInf(x)); +XLAJIT_MAKE_UNARY(IsNan, xla::IsNan(x)); // Return 1/x XLAJIT_MAKE_UNARY(Inv, xla::ScalarLike(x, 1.0) / x); XLAJIT_MAKE_UNARY(Reciprocal, xla::ScalarLike(x, 1.0) / x); @@ -116,37 +113,11 @@ XLAJIT_MAKE_UNARY(Tanh, xla::Tanh(x)); XLAJIT_MAKE_UNARY(Real, xla::Real(x)); XLAJIT_MAKE_UNARY(Imag, xla::Imag(x)); +XLAJIT_MAKE_UNARY(Erf, xla::Erf(x)); +XLAJIT_MAKE_UNARY(Erfc, xla::Erfc(x)); #undef XLAJIT_MAKE_UNARY -// Erf/Erfc. For x in (-1, 1), the erf approximation is used; erfc polynomial -// is used outside of this range. 
-class ErfOp : public XlaOpKernel { - public: - explicit ErfOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} - void Compile(XlaOpKernelContext* ctx) override { - xla::XlaOp x = ctx->Input(0); - xla::XlaOp one = xla::ScalarLike(x, 1.0); - auto y = - xla::Select(xla::Gt(xla::Abs(x), one), one - xla::Erfc(x), xla::Erf(x)); - ctx->SetOutput(0, y); - } -}; -REGISTER_XLA_OP(Name("Erf"), ErfOp); - -class ErfcOp : public XlaOpKernel { - public: - explicit ErfcOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} - void Compile(XlaOpKernelContext* ctx) override { - xla::XlaOp x = ctx->Input(0); - xla::XlaOp one = xla::ScalarLike(x, 1.0); - auto y = - xla::Select(xla::Lt(xla::Abs(x), one), one - xla::Erf(x), xla::Erfc(x)); - ctx->SetOutput(0, y); - } -}; -REGISTER_XLA_OP(Name("Erfc"), ErfcOp); - class LgammaOp : public XlaOpKernel { public: explicit LgammaOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} diff --git a/tensorflow/compiler/tf2xla/kernels/unpack_op.cc b/tensorflow/compiler/tf2xla/kernels/unpack_op.cc index 8671632976023fded04c26a9780c1a67638b0916..2fc5619de737b8977e4249e4d2297a0303c339ce 100644 --- a/tensorflow/compiler/tf2xla/kernels/unpack_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/unpack_op.cc @@ -24,12 +24,12 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/kernels/concat_lib.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc index 2c92a585f5679242d672d0402e617ff199b94f17..dfa09b16081e93ba843a1858e68e6ff756de20c1 100644 --- a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc @@ -291,5 +291,19 @@ class ResourceScatterNdAddOp : public ResourceScatterOp { }; REGISTER_XLA_OP(Name("ResourceScatterNdAdd"), ResourceScatterNdAddOp); +class ResourceScatterNdSubOp : public ResourceScatterOp { + public: + explicit ResourceScatterNdSubOp(OpKernelConstruction* context) + : ResourceScatterOp(context, /*indices_are_vectors=*/true, + /*combiner=*/Combine) {} + + private: + static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y, + xla::XlaBuilder* builder) { + return xla::Sub(x, y); + } +}; +REGISTER_XLA_OP(Name("ResourceScatterNdSub"), ResourceScatterNdSubOp); + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc index ce007fc04a818869686b9936a1607cee42665e87..f49da9683b3622bdda708cc305306baafa1639df 100644 --- a/tensorflow/compiler/tf2xla/kernels/while_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc @@ -41,8 +41,7 @@ Status MakeXlaCompilerArgumentsFromInputs( 
*has_uninitialized_vars = false; *has_tensor_arrays = false; for (int i = 0; i < ctx->num_inputs(); ++i) { - VLOG(2) << " Input " << i - << " type: " << DataTypeString(ctx->input_type(i)) + VLOG(2) << " Input " << i << " type: " << DataTypeString(ctx->input_type(i)) << " shape: " << ctx->InputShape(i).DebugString(); XlaCompiler::Argument& arg = (*args)[i]; DataType type = ctx->input_type(i); @@ -71,13 +70,20 @@ Status MakeXlaCompilerArgumentsFromInputs( arg.name = resource->name(); VLOG(2) << " resource " << resource->name() << " type: " << DataTypeString(arg.type) - << " shape: " << arg.shape.DebugString() + << " shape: " << arg.ShapeHumanString() << " initialized: " << arg.initialized; } else { arg.kind = XlaCompiler::Argument::kParameter; arg.type = ctx->input_type(i); - arg.shape = ctx->InputShape(i); + + xla::XlaBuilder* builder = ctx->builder(); + xla::XlaOp handle = ctx->Input(i); + auto shape_or_status = builder->GetShape(handle); + if (!shape_or_status.ok()) { + return shape_or_status.status(); + } + arg.shape = shape_or_status.ValueOrDie(); } } return Status::OK(); @@ -207,12 +213,12 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { OP_REQUIRES(ctx, body.xla_input_shapes.size() == 1, errors::FailedPrecondition("Expected one input shape")); xla::Shape body_input_shape = body.xla_input_shapes[0]; - OP_REQUIRES(ctx, xla::ShapeUtil::IsTuple(body_input_shape), + OP_REQUIRES(ctx, body_input_shape.IsTuple(), errors::FailedPrecondition("Expected tuple shape")); OP_REQUIRES(ctx, cond.xla_input_shapes.size() == 1, errors::FailedPrecondition("Expected one input shape")); xla::Shape cond_input_shape = cond.xla_input_shapes[0]; - OP_REQUIRES(ctx, xla::ShapeUtil::IsTuple(cond_input_shape), + OP_REQUIRES(ctx, cond_input_shape.IsTuple(), errors::FailedPrecondition("Expected tuple shape")); VLOG(2) << "Body shape: " << xla::ShapeUtil::HumanString(body_input_shape) @@ -233,13 +239,22 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { 
xla::ShapeUtil::HumanString(body_input_shape), " vs. ", xla::ShapeUtil::HumanString(body.xla_output_shape))); - xla::Shape expected_cond_output_shape = xla::ShapeUtil::MakeTupleShape( - {xla::ShapeUtil::MakeShape(xla::PRED, {})}); + xla::Shape expected_cond_output_shape_without_side_effect = + xla::ShapeUtil::MakeTupleShape( + {xla::ShapeUtil::MakeShape(xla::PRED, {})}); + xla::Shape expected_cond_output_shape_with_side_effect = + xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::PRED, {}), + xla::ShapeUtil::MakeTokenShape()}); OP_REQUIRES(ctx, - xla::ShapeUtil::Compatible(cond.xla_output_shape, - expected_cond_output_shape), + xla::ShapeUtil::Compatible( + cond.xla_output_shape, + expected_cond_output_shape_without_side_effect) || + xla::ShapeUtil::Compatible( + cond.xla_output_shape, + expected_cond_output_shape_with_side_effect), errors::InvalidArgument( - "Output shape of loop condition should be (pred[]), got: ", + "Output shape of loop condition should be (pred[]) or " + "(pred[], token[]), got: ", xla::ShapeUtil::HumanString(cond.xla_output_shape))); int num_inputs = body.input_mapping.size(); @@ -283,11 +298,15 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { xla::XlaOp while_result = xla::While(cond_wrapper, *body.computation, init); - // Sets non-variable outputs. + // Sets non-variable outputs and determine when resource variables start. 
+ int resource_index = 0; for (int i = 0; i < ctx->num_outputs(); ++i) { if (ctx->input_type(i) != DT_RESOURCE) { ctx->SetOutput(body.input_mapping[i], xla::GetTupleElement(while_result, i)); + ++resource_index; + } else { + break; } } if (has_token_input_output_) { @@ -296,7 +315,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { xla::GetTupleElement(while_result, ctx->num_outputs()); auto shape_or = builder->GetShape(token_output); OP_REQUIRES_OK(ctx, shape_or.status()); - OP_REQUIRES(ctx, xla::ShapeUtil::IsToken(shape_or.ValueOrDie()), + OP_REQUIRES(ctx, shape_or.ValueOrDie().IsToken(), errors::FailedPrecondition( "Token output is not token type: ", xla::ShapeUtil::HumanString(shape_or.ValueOrDie()))); @@ -309,7 +328,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { XlaResource* resource; OP_REQUIRES_OK(ctx, ctx->GetResourceInput(update.input_index, &resource)); if (update.modified) { - int pos = body.outputs.size() + i; + int pos = resource_index + i; OP_REQUIRES_OK(ctx, resource->SetFromPack( arguments[update.input_index].tensor_array_gradients, @@ -329,8 +348,11 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { VLOG(1) << "Done building while loop"; } -REGISTER_XLA_OP(Name("While").AllowResourceTypes(), XlaWhileOp); -REGISTER_XLA_OP(Name("StatelessWhile").AllowResourceTypes(), XlaWhileOp); -REGISTER_XLA_OP(Name("XlaWhile").AllowResourceTypes(), XlaWhileOp); +REGISTER_XLA_OP(Name("While").AllowResourceTypes().AllowVariantTypes(), + XlaWhileOp); +REGISTER_XLA_OP(Name("StatelessWhile").AllowResourceTypes().AllowVariantTypes(), + XlaWhileOp); +REGISTER_XLA_OP(Name("XlaWhile").AllowResourceTypes().AllowVariantTypes(), + XlaWhileOp); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc index 4612f19971a3ce6994aef303f751748b77ccda9a..b20adc592a0d3d2129c897218ddbfc891b4cd40a 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc +++ 
b/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc @@ -78,7 +78,7 @@ class XlaConvOp : public XlaOpKernel { xla::XlaOp output = xla::ConvGeneralDilated( context->Input(0), context->Input(1), window_strides, padding, lhs_dilation, rhs_dilation, dnums_, feature_group_count, - &precision_config_); + /*batch_group_count=*/1, &precision_config_); context->SetOutput(0, output); } diff --git a/tensorflow/compiler/tf2xla/kernels/xla_dequantize_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_dequantize_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..a30b4861f6b3a964c0c874a3affab7d6198264d7 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/xla_dequantize_op.cc @@ -0,0 +1,60 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/quantize.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace { + +class XlaDequantizeOp : public XlaOpKernel { + public: + explicit XlaDequantizeOp(OpKernelConstruction* context) + : XlaOpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("min_range", &min_range_)); + OP_REQUIRES_OK(context, context->GetAttr("max_range", &max_range_)); + OP_REQUIRES_OK(context, context->GetAttr("mode", &mode_)); + OP_REQUIRES_OK(context, + context->GetAttr("transpose_output", &transpose_output_)); + } + + void Compile(XlaOpKernelContext* context) override { + const xla::XlaOp& input = context->Input(0); + + xla::QuantizedRange range(min_range_, max_range_); + + xla::XlaOp output = + xla::Dequantize(input, range, mode_, transpose_output_); + context->SetOutput(0, output); + } + + private: + float min_range_; + float max_range_; + bool transpose_output_; + string mode_; + TF_DISALLOW_COPY_AND_ASSIGN(XlaDequantizeOp); +}; + +REGISTER_XLA_OP(Name("XlaDequantize"), XlaDequantizeOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD index 1ce3930fd1cd91f8e8dfb765b49be2dc969d1bd7..3d7b0bc959f9dbf3c1b9749379e2ea0d285b302b 100644 --- a/tensorflow/compiler/tf2xla/lib/BUILD +++ b/tensorflow/compiler/tf2xla/lib/BUILD @@ -15,22 +15,6 @@ filegroup( ]), ) -load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") - -cc_library( - name = "batch_dot", - srcs = ["batch_dot.cc"], - hdrs = ["batch_dot.h"], - deps = [ 
- "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/core:lib", - ], -) - cc_library( name = "broadcast", srcs = ["broadcast.cc"], @@ -47,26 +31,6 @@ cc_library( ], ) -cc_library( - name = "cholesky", - srcs = ["cholesky.cc"], - hdrs = ["cholesky.h"], - deps = [ - ":batch_dot", - ":triangular_solve", - ":util", - ":while_loop", - "//tensorflow/compiler/xla:literal", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client/lib:constants", - "//tensorflow/core:lib", - ], -) - cc_library( name = "random", srcs = ["random.cc"], @@ -82,35 +46,12 @@ cc_library( ], ) -cc_library( - name = "qr", - srcs = ["qr.cc"], - hdrs = ["qr.h"], - deps = [ - ":batch_dot", - ":util", - ":while_loop", - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/lib:constants", - "//tensorflow/compiler/xla/client/lib:math", - "//tensorflow/compiler/xla/client/lib:numeric", - "//tensorflow/core:lib", - ], -) - cc_library( name = "scatter", srcs = ["scatter.cc"], hdrs = ["scatter.h"], deps = [ ":util", - ":while_loop", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -124,51 +65,6 @@ cc_library( ], ) -cc_library( - name = "triangular_solve", - srcs = ["triangular_solve.cc"], - hdrs = ["triangular_solve.h"], - deps = [ - ":batch_dot", - ":util", - 
"//tensorflow/compiler/xla:literal", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/lib:constants", - "//tensorflow/compiler/xla/client/lib:numeric", - "//tensorflow/core:lib", - ], -) - -xla_test( - name = "triangular_solve_test", - srcs = ["triangular_solve_test.cc"], - tags = ["noasan"], # sometimes times out, http://b/78650012 - deps = [ - ":triangular_solve", - "//tensorflow/compiler/xla:array2d", - "//tensorflow/compiler/xla:literal", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:global_data", - "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/tests:client_library_test_base", - "//tensorflow/compiler/xla/tests:literal_test_util", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:lib", - "//tensorflow/core:test", - ], -) - cc_library( name = "util", srcs = ["util.cc"], @@ -186,42 +82,3 @@ cc_library( "@com_google_absl//absl/types:span", ], ) - -xla_test( - name = "util_test", - srcs = ["util_test.cc"], - deps = [ - ":batch_dot", - ":util", - "//tensorflow/compiler/xla:array2d", - "//tensorflow/compiler/xla:literal", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:global_data", - "//tensorflow/compiler/xla/client:local_client", - 
"//tensorflow/compiler/xla/tests:client_library_test_base", - "//tensorflow/compiler/xla/tests:literal_test_util", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:lib", - "//tensorflow/core:test", - ], -) - -cc_library( - name = "while_loop", - srcs = ["while_loop.cc"], - hdrs = ["while_loop.h"], - deps = [ - ":util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client:xla_computation", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - ], -) diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.cc b/tensorflow/compiler/tf2xla/lib/batch_dot.cc deleted file mode 100644 index 5400e8834cb9807f6dd71abe7789b2672e29e905..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tf2xla/lib/batch_dot.cc +++ /dev/null @@ -1,115 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/compiler/tf2xla/lib/batch_dot.h" - -#include -#include - -#include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/core/lib/core/errors.h" - -namespace tensorflow { - -xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x, - bool transpose_y, bool conjugate_x, bool conjugate_y, - xla::PrecisionConfig::Precision precision) { - xla::XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_ASSIGN_OR_RETURN(xla::Shape x_shape, builder->GetShape(x)); - TF_ASSIGN_OR_RETURN(xla::Shape y_shape, builder->GetShape(y)); - - // Check that both tensors have the same number of dimensions. There must be - // at least two (the batch dimensions can be empty). - if (xla::ShapeUtil::Rank(x_shape) != xla::ShapeUtil::Rank(y_shape)) { - return errors::InvalidArgument( - "Arguments to BatchedDot have different ranks: ", - xla::ShapeUtil::HumanString(x_shape), " vs. ", - xla::ShapeUtil::HumanString(y_shape)); - } - const int ndims = xla::ShapeUtil::Rank(x_shape); - if (ndims < 2) { - return errors::InvalidArgument( - "Arguments to BatchedDot must have rank >= 2: ", ndims); - } - - // The batch dimensions must be equal and the matrix dimensions must be - // valid. - std::vector batch_dimension_numbers; - for (int i = 0; i < ndims - 2; ++i) { - if (x_shape.dimensions(i) != y_shape.dimensions(i)) { - return errors::InvalidArgument( - "Dimension ", i, " of inputs to BatchedDot must be equal: ", - xla::ShapeUtil::HumanString(x_shape), " vs ", - xla::ShapeUtil::HumanString(y_shape)); - } - batch_dimension_numbers.push_back(i); - } - - int x_inner_dim = transpose_x ? (ndims - 2) : (ndims - 1); - int y_inner_dim = transpose_y ? 
(ndims - 1) : (ndims - 2); - if (x_shape.dimensions(x_inner_dim) != y_shape.dimensions(y_inner_dim)) { - return errors::InvalidArgument( - "Dimensions ", x_inner_dim, " and ", y_inner_dim, - " of arguments to BatchedDot must be equal: ", - xla::ShapeUtil::HumanString(x_shape), " transpose: ", transpose_x, - " vs. ", xla::ShapeUtil::HumanString(y_shape), - " transpose: ", transpose_y); - } - - // Check for zero lhs/rhs dim size. - if (xla::ShapeUtil::IsZeroElementArray(x_shape) || - xla::ShapeUtil::IsZeroElementArray(y_shape)) { - std::vector dimensions(batch_dimension_numbers.size()); - for (int i = 0; i < batch_dimension_numbers.size(); ++i) { - dimensions[i] = x_shape.dimensions(batch_dimension_numbers[i]); - } - int x_outer_dim = transpose_x ? (ndims - 1) : (ndims - 2); - int y_outer_dim = transpose_y ? (ndims - 2) : (ndims - 1); - dimensions.push_back(x_shape.dimensions(x_outer_dim)); - dimensions.push_back(y_shape.dimensions(y_outer_dim)); - return xla::Broadcast( - xla::ConstantLiteral(builder, - xla::LiteralUtil::Zero(x_shape.element_type())), - dimensions); - } - - if (x_shape.element_type() == xla::C64 && conjugate_x) { - x = xla::Conj(x); - } - if (y_shape.element_type() == xla::C64 && conjugate_y) { - y = xla::Conj(y); - } - - xla::PrecisionConfig precision_proto; - precision_proto.add_operand_precision(precision); - precision_proto.add_operand_precision(precision); - - xla::DotDimensionNumbers dot_dnums; - dot_dnums.add_lhs_contracting_dimensions(x_inner_dim); - dot_dnums.add_rhs_contracting_dimensions(y_inner_dim); - for (auto batch_dimension_number : batch_dimension_numbers) { - dot_dnums.add_lhs_batch_dimensions(batch_dimension_number); - dot_dnums.add_rhs_batch_dimensions(batch_dimension_number); - } - - return xla::DotGeneral(x, y, dot_dnums, &precision_proto); - }); -} - -} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.h b/tensorflow/compiler/tf2xla/lib/batch_dot.h deleted file mode 100644 index 
6edd63a4d3b66c21aa4cce8c9f36eef0dc363cd8..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tf2xla/lib/batch_dot.h +++ /dev/null @@ -1,54 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_BATCH_DOT_H_ -#define TENSORFLOW_COMPILER_TF2XLA_LIB_BATCH_DOT_H_ - -#include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" - -namespace tensorflow { - -// Multiplies slices of two tensors in batches. - -// Multiplies all slices of `Tensor` `x` and `y` (each slice can be -// viewed as an element of a batch), and arranges the individual results -// in a single output tensor of the same batch size. Each of the -// individual slices can optionally be transposed before multiplication by -// setting the `transpose_x` or `transpose_y` flag to `true`. Similarly, each -// can be elementwise-complex-conjugated by setting the `conjugate_x` or -// `conjugate_y` flag to `true`. To apply a Hermitian adjoint to `x`, set both -// `transpose_x` and `conjugate_x` to `true`, and analogously for `y`. -// -// The input tensors `x` and `y` are 2-D or higher with shape `[..., r_x, c_x]` -// and `[..., r_y, c_y]`. 
-// -// The output tensor is 2-D or higher with shape `[..., r_o, c_o]`, where: -// -// r_o = c_x if transpose_x else r_x -// c_o = r_y if transpose_y else c_y -// -// It is computed as: -// -// output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :]) -xla::XlaOp BatchDot( - xla::XlaOp x, xla::XlaOp y, bool transpose_x = false, - bool transpose_y = false, bool conjugate_x = false, - bool conjugate_y = false, - xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT); - -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_BATCH_DOT_H_ diff --git a/tensorflow/compiler/tf2xla/lib/scatter.cc b/tensorflow/compiler/tf2xla/lib/scatter.cc index 2b1c2ced925d9fee7392986015a6e716a94d356f..1cd5a79171dccd57fc1b7941cdf16417301ff7f8 100644 --- a/tensorflow/compiler/tf2xla/lib/scatter.cc +++ b/tensorflow/compiler/tf2xla/lib/scatter.cc @@ -20,7 +20,6 @@ limitations under the License. #include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/lib/util.h" -#include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" @@ -49,7 +48,7 @@ xla::StatusOr XlaScatter( if (indices_are_vectors) { TF_RET_CHECK(!indices_dims.empty()); num_index_dims = indices_dims.back(); - if (num_index_dims > xla::ShapeUtil::Rank(buffer_shape)) { + if (num_index_dims > buffer_shape.rank()) { return errors::InvalidArgument( "The size of the minor dimension of the indices (shape: ", xla::ShapeUtil::HumanString(indices_shape), @@ -141,8 +140,8 @@ xla::StatusOr XlaScatter( ? 
indices_shape.dimensions_size() - 1 : indices_shape.dimensions_size()); - int64 updates_rank = xla::ShapeUtil::Rank(updates_shape); - int64 buffer_rank = xla::ShapeUtil::Rank(buffer_shape); + int64 updates_rank = updates_shape.rank(); + int64 buffer_rank = buffer_shape.rank(); int64 num_window_dims_in_updates = buffer_rank - num_index_dims; // If the rank of `updates` is 0 and does not match the expected rank of @@ -157,7 +156,7 @@ xla::StatusOr XlaScatter( if (updates_rank == 0 && expected_updates_rank != 0) { new_updates = xla::Broadcast(updates, expected_updates_dims); TF_ASSIGN_OR_RETURN(updates_shape, builder->GetShape(new_updates)); - updates_rank = xla::ShapeUtil::Rank(updates_shape); + updates_rank = updates_shape.rank(); } if (updates_rank > 0) { diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc index 804671fbc75b0a5a6e04b204822b6f084013cd8b..06eda41611861060a1f1c4d028b96405d288efdb 100644 --- a/tensorflow/compiler/tf2xla/lib/util.cc +++ b/tensorflow/compiler/tf2xla/lib/util.cc @@ -54,6 +54,9 @@ xla::XlaOp FloatLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, case xla::C64: return xla::ConstantR0(builder, value); break; + case xla::C128: + return xla::ConstantR0(builder, value); + break; default: LOG(FATAL) << "unhandled element type " << type; } @@ -90,6 +93,9 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, case xla::C64: literal = xla::LiteralUtil::CreateR0(value); break; + case xla::C128: + literal = xla::LiteralUtil::CreateR0(value); + break; case xla::PRED: LOG(FATAL) << "pred element type is not integral"; case xla::S16: @@ -113,36 +119,6 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, return xla::ConstantLiteral(builder, literal); } -xla::XlaOp SliceInMinorDims(xla::XlaOp x, absl::Span start, - absl::Span end) { - xla::XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - 
TF_RET_CHECK(start.size() == end.size()); - int64 n_minor_dims = start.size(); - - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - - const int64 n_dims = xla::ShapeUtil::Rank(shape); - TF_RET_CHECK(n_minor_dims <= n_dims); - auto major_dims = xla::AsInt64Slice(shape.dimensions()) - .subspan( - /*pos=*/0, - /*len=*/n_dims - n_minor_dims); - - // Prepends 0s in the major dim - std::vector padded_start(n_dims, 0); - std::copy(start.begin(), start.end(), - padded_start.begin() + major_dims.size()); - - // Prepends the shape of the major dims. - std::vector padded_end(n_dims); - std::copy(major_dims.begin(), major_dims.end(), padded_end.begin()); - std::copy(end.begin(), end.end(), padded_end.begin() + major_dims.size()); - - std::vector strides(n_dims, 1); - return xla::Slice(x, padded_start, padded_end, strides); - }); -} std::vector ConcatVectors(absl::Span xs, absl::Span ys) { @@ -152,100 +128,4 @@ std::vector ConcatVectors(absl::Span xs, return output; } -xla::XlaOp DynamicSliceInMinorDims(xla::XlaOp x, - absl::Span starts, - absl::Span sizes) { - xla::XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - const int64 n_dims = xla::ShapeUtil::Rank(shape); - int64 n_minor_dims = starts.size(); - TF_RET_CHECK(n_minor_dims == sizes.size()); - TF_RET_CHECK(n_minor_dims <= n_dims); - auto major_dims = xla::AsInt64Slice(shape.dimensions()) - .subspan( - /*pos=*/0, - /*len=*/n_dims - sizes.size()); - auto padded_starts = PrependZerosInMajorDims(x, starts); - auto padded_sizes = ConcatVectors(major_dims, sizes); - return xla::DynamicSlice(x, padded_starts, padded_sizes); - }); -} - -xla::XlaOp UpdateSlice(xla::XlaOp x, xla::XlaOp update, - absl::Span start) { - xla::XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - // TODO(phawkins): make int64 work on all backends, remove the int32 cast. 
- std::vector start_as_int32(start.begin(), start.end()); - auto start_constant = xla::ConstantR1(builder, start_as_int32); - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - const int64 n_dims = xla::ShapeUtil::Rank(shape); - TF_ASSIGN_OR_RETURN(xla::Shape start_constant_shape, - builder->GetShape(start_constant)); - const int64 start_length = - xla::ShapeUtil::GetDimension(start_constant_shape, -1); - TF_RET_CHECK(start_length == n_dims); - return xla::DynamicUpdateSlice(x, update, start_constant); - }); -} - -xla::XlaOp UpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update, - absl::Span start) { - xla::XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - const int64 n_dims = xla::ShapeUtil::Rank(shape); - const int64 n_minor_dims = start.size(); - TF_RET_CHECK(n_minor_dims <= n_dims); - std::vector padded_start(n_dims, 0); - std::copy(start.begin(), start.end(), - padded_start.begin() + (n_dims - n_minor_dims)); - return UpdateSlice(x, update, padded_start); - }); -} - -xla::XlaOp DynamicUpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update, - absl::Span starts) { - auto padded_starts = PrependZerosInMajorDims(x, starts); - return xla::DynamicUpdateSlice(x, update, padded_starts); -} - -xla::XlaOp PrependZerosInMajorDims(xla::XlaOp x, - absl::Span starts) { - xla::XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - const int64 n_dims = xla::ShapeUtil::Rank(shape); - auto zero = xla::Reshape(xla::ConstantR0(builder, 0), {1}); - std::vector padded_starts(n_dims, zero); - for (int i = 0; i < starts.size(); ++i) { - padded_starts[n_dims - starts.size() + i] = xla::Reshape(starts[i], {1}); - } - return xla::ConcatInDim(builder, padded_starts, 0); - }); -} - -xla::XlaOp TransposeInMinorDims(xla::XlaOp x) { - xla::XlaBuilder* builder = 
x.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - const int64 n_dims = xla::ShapeUtil::Rank(shape); - TF_RET_CHECK(n_dims >= 2); - std::vector permutation(n_dims); - std::iota(permutation.begin(), permutation.end(), 0); - std::swap(permutation[n_dims - 1], permutation[n_dims - 2]); - return xla::Transpose(x, permutation); - }); -} - -xla::XlaOp MaybeConjugate(xla::XlaOp x, bool conjugate) { - xla::XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - auto perform_conj = shape.element_type() == xla::C64 && conjugate; - return perform_conj ? xla::Conj(x) : x; - }); -} - } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/util.h b/tensorflow/compiler/tf2xla/lib/util.h index 80e9e5b002d49581209e608b98606e02709c5876..aec8061cb4322b8d315b6cdc80c7fff1e0cb4cb1 100644 --- a/tensorflow/compiler/tf2xla/lib/util.h +++ b/tensorflow/compiler/tf2xla/lib/util.h @@ -38,44 +38,10 @@ xla::XlaOp PrependZerosInMajorDims(xla::XlaOp x, xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, int64 value); -// Builds a vector of zeros of length rank(x) with the last values being -// those in `starts`. -xla::XlaOp PrependZerosInMajorDims(xla::XlaOp x, - absl::Span starts); - -// Performs a slice in the minor dimensions of a Tensor. -xla::XlaOp SliceInMinorDims(xla::XlaOp x, absl::Span start, - absl::Span end); - // Returns the concatenation of `xs` and `ys`. std::vector ConcatVectors(absl::Span xs, absl::Span ys); -// Performs a dynamic slice in the minor dimensions of a Tensor. 
-xla::XlaOp DynamicSliceInMinorDims(xla::XlaOp x, - absl::Span starts, - absl::Span sizes); - -// Updates a slice of 'x', i.e., -// x[start[0], ..., start[n]] = update -xla::XlaOp UpdateSlice(xla::XlaOp x, xla::XlaOp update, - absl::Span start); - -// Updates a slice of 'x', where 'start' contains a list of minor dimensions: -// x[..., start[0], ..., start[n]] = update -xla::XlaOp UpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update, - absl::Span start); - -xla::XlaOp DynamicUpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update, - absl::Span starts); - -// Transposes a stack of matrices `x` by swapping the last two dimensions. -xla::XlaOp TransposeInMinorDims(xla::XlaOp x); - -// Applies a complex conjugation operation if `a` is complex and `conjugate_a` -// is true, otherwise returns its argument. -xla::XlaOp MaybeConjugate(xla::XlaOp x, bool conjugate); - } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_LIB_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/literal_util.cc b/tensorflow/compiler/tf2xla/literal_util.cc index 67d08290033361f16dfff42b06af9b253e84963a..749a7c3054a65d6ec9f9dc13f6f4a713ac9d3d5a 100644 --- a/tensorflow/compiler/tf2xla/literal_util.cc +++ b/tensorflow/compiler/tf2xla/literal_util.cc @@ -77,7 +77,7 @@ Status HostTensorsToBorrowingLiteralTuple(absl::Span host_tensors, Status CopyLiteralToHostTensor(const xla::LiteralSlice& literal, Tensor* host_tensor) { - TF_RET_CHECK(xla::ShapeUtil::IsArray(literal.shape()) && + TF_RET_CHECK(literal.shape().IsArray() && xla::ShapeUtil::ElementsIn(literal.shape()) == host_tensor->NumElements()); xla::PrimitiveType primitive_type; diff --git a/tensorflow/compiler/tf2xla/literal_util_test.cc b/tensorflow/compiler/tf2xla/literal_util_test.cc index 15f4c38da29507da9e092c1d5725b5f95a81d1b9..44bccfe6474d175beda392ca17dfbcb08c0b1b11 100644 --- a/tensorflow/compiler/tf2xla/literal_util_test.cc +++ b/tensorflow/compiler/tf2xla/literal_util_test.cc @@ -49,7 +49,7 @@ using Types = std::pair, std::pair, 
std::pair>; -TYPED_TEST_CASE(LiteralUtilTest, Types); +TYPED_TEST_SUITE(LiteralUtilTest, Types); TYPED_TEST(LiteralUtilTest, LiteralToQuantizedHostTensor) { using int_type = typename TypeParam::first_type; diff --git a/tensorflow/compiler/tf2xla/ops/BUILD b/tensorflow/compiler/tf2xla/ops/BUILD index 4dce0a2102cf9c782850ccc7af4f14b59bd51e53..7140b6a1227a53290c3747892a55886a7f48513b 100644 --- a/tensorflow/compiler/tf2xla/ops/BUILD +++ b/tensorflow/compiler/tf2xla/ops/BUILD @@ -4,7 +4,11 @@ package( licenses(["notice"]) # Apache 2.0 -load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") +load( + "//tensorflow:tensorflow.bzl", + "tf_custom_op_library", + "tf_gen_op_wrapper_py", +) cc_library( name = "xla_ops", @@ -24,3 +28,14 @@ tf_gen_op_wrapper_py( ":xla_ops", ], ) + +tf_custom_op_library( + name = "_xla_ops.so", + srcs = [ + "xla_ops.cc", + ], + deps = [ + "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", + ], +) diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc index bd2c0a5ee88869ba60701c0a7ace05857452eed9..af641131ed76a8d6a7291c360302fa17c94af014 100644 --- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc +++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc @@ -369,7 +369,11 @@ REGISTER_OP("XlaKeyValueSort") .Output("sorted_values: V") .Attr("K: realnumbertype") .Attr("V: type") - .SetShapeFn(shape_inference::UnchangedShape) + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->input(0)); + c->set_output(1, c->input(1)); + return Status::OK(); + }) .Doc(R"doc( Wraps the XLA Sort operator, documented at https://www.tensorflow.org/performance/xla/operation_semantics#sort @@ -409,5 +413,29 @@ body: A function that takes a list of tensors and returns another list of tensors. Both lists have the same types as specified by T. 
)doc"); +REGISTER_OP("XlaDequantize") + .Input("input: uint32") + .Output("output: bfloat16") + .Attr("min_range: float") + .Attr("max_range: float") + .Attr("mode: string") + .Attr("transpose_output: bool") + .SetIsStateful() + .SetShapeFn(shape_inference::UnknownShape) + .Doc(R"doc( +Takes the packed uint32 input and unpacks the input to uint8 to do +Dequantization on device. + +input: Input tensor whose type is uint32, shape is [d0, ..., dn]. +output: Output tensor whose type is bfloat16. If transpose_output is true, + output shape is [dn * 4, dn-1, ..., d1, d0]. If transpose_output + is false, output shape is [d0,..., dn * 4]. +min_range: The minimum scalar value possibly produced for the input. +max_range: The maximum scalar value possibly produced for the input. +mode: String to determine the dequantize mode in {"MIN_COMBINED", "MIN_FIRST", "SCALED"}. +transpose_output: Boolean to determine if output is transposed. transpose_output + is faster when input is large and rank of input is higher than 1. 
+)doc"); + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/python/BUILD b/tensorflow/compiler/tf2xla/python/BUILD index fef97b98c376d9df8bbfd9cb6651216895e46bf4..9abdb04d7736e8ff5225688af4759a522d3e7fc7 100644 --- a/tensorflow/compiler/tf2xla/python/BUILD +++ b/tensorflow/compiler/tf2xla/python/BUILD @@ -15,6 +15,7 @@ load( "//tensorflow/core:platform/default/build_config.bzl", "tf_py_clif_cc", ) +load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") tf_py_clif_cc( name = "xla_op_registry", @@ -27,9 +28,13 @@ tf_py_clif_cc( ], ) -py_library( +tf_custom_op_py_library( name = "xla", srcs = ["xla.py"], + dso = ["//tensorflow/compiler/tf2xla/ops:_xla_ops.so"], + kernels = [ + "//tensorflow/compiler/tf2xla/ops:xla_ops", + ], deps = [ "//tensorflow/compiler/tf2xla/ops:gen_xla_ops", "//tensorflow/compiler/xla:xla_data_proto_py", diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py index 147e562658bbfc445f99268812e2c3ae1ee61e30..345193c936a885e5a9e468979c4b73b5b0c9e5c2 100644 --- a/tensorflow/compiler/tf2xla/python/xla.py +++ b/tensorflow/compiler/tf2xla/python/xla.py @@ -386,3 +386,4 @@ def slice(x, start_dims, limit_dims, strides): sort = gen_xla_ops.xla_sort key_value_sort = gen_xla_ops.xla_key_value_sort while_loop = gen_xla_ops.xla_while +dequantize = gen_xla_ops.xla_dequantize diff --git a/tensorflow/compiler/tf2xla/resource_operation_table.cc b/tensorflow/compiler/tf2xla/resource_operation_table.cc index 72b240996fb4d9dcb5f5dfd919da618cbae08c16..c20d6a5fd1f3bd7dad30cb3359d13ed4609a2250 100644 --- a/tensorflow/compiler/tf2xla/resource_operation_table.cc +++ b/tensorflow/compiler/tf2xla/resource_operation_table.cc @@ -65,6 +65,7 @@ CreateResourceOpInfoMap() { add("ResourceApplyFtrlV2" , kReadWrite, kVariable); add("ResourceApplyGradientDescent" , kReadWrite, kVariable); add("ResourceApplyMomentum" , kReadWrite, kVariable); + add("ResourceApplyKerasMomentum" , kReadWrite, kVariable); 
add("ResourceApplyPowerSign" , kReadWrite, kVariable); add("ResourceApplyProximalAdagrad" , kReadWrite, kVariable); add("ResourceApplyProximalGradientDescent" , kReadWrite, kVariable); @@ -76,6 +77,7 @@ CreateResourceOpInfoMap() { add("ResourceScatterMin" , kReadWrite, kVariable); add("ResourceScatterMul" , kReadWrite, kVariable); add("ResourceScatterNdAdd" , kReadWrite, kVariable); + add("ResourceScatterNdSub" , kReadWrite, kVariable); add("ResourceScatterNdUpdate" , kReadWrite, kVariable); add("ResourceScatterSub" , kReadWrite, kVariable); add("ResourceScatterUpdate" , kReadWrite, kVariable); diff --git a/tensorflow/compiler/tf2xla/shape_util.cc b/tensorflow/compiler/tf2xla/shape_util.cc index b589512dcdfa32050281120aba6a5ae89a980c2f..8997b2f5c68da480e9d4cb1f7ff8776690363392 100644 --- a/tensorflow/compiler/tf2xla/shape_util.cc +++ b/tensorflow/compiler/tf2xla/shape_util.cc @@ -18,21 +18,81 @@ limitations under the License. #include #include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/lib/core/status.h" namespace tensorflow { +namespace { + +Status PopulateInfeedLayoutVector(const xla::Shape& shape, + std::vector* layouts) { + if (shape.IsTuple()) { + int64 tuple_elements = xla::ShapeUtil::TupleElementCount(shape); + for (int64 i = 0; i < tuple_elements; ++i) { + const xla::Shape& subshape = + xla::ShapeUtil::GetTupleElementShape(shape, i); + TF_RETURN_IF_ERROR(PopulateInfeedLayoutVector(subshape, layouts)); + } + } else if (xla::LayoutUtil::HasLayout(shape)) { + for (auto dim : xla::LayoutUtil::MinorToMajor(shape)) { + layouts->push_back(dim); + } + } else { + layouts->insert(layouts->end(), shape.rank(), -1); + } + return Status::OK(); +} + +// Populate the output layout unless the minor_to_major array contains all -1 +// value, in which case the layout is considered missing and the API returns +// false. 
+xla::StatusOr MakeLayout(absl::Span minor_to_major, + xla::Layout* layout) { + if (std::all_of(minor_to_major.begin(), minor_to_major.end(), + [](int64 dim) { return dim == -1; })) { + return false; + } + std::vector dim_present(minor_to_major.size(), false); + for (auto dim : minor_to_major) { + if (dim < 0 || dim >= minor_to_major.size()) { + return errors::InvalidArgument("Layout dimension out of range: dim=", dim, + " rank=", minor_to_major.size()); + } + if (dim_present[dim]) { + return errors::InvalidArgument("Repeated layout dimension: dim=", dim); + } + dim_present[dim] = true; + } + *layout = xla::LayoutUtil::MakeLayout(minor_to_major); + return true; +} + +Status AssignLayout( + absl::Span minor_to_major, + const std::function& layout_func, + xla::Shape* shape) { + xla::Layout layout; + TF_ASSIGN_OR_RETURN(bool has_layout, MakeLayout(minor_to_major, &layout)); + if (!has_layout && layout_func) { + layout = layout_func(*shape); + } + *shape->mutable_layout() = layout; + return Status::OK(); +} + +} // namespace // Convert an XLA Shape into the equivalent TensorFlow shape. 
Status XLAShapeToTensorShape(const xla::Shape& shape, TensorShape* tensor_shape) { - if (xla::ShapeUtil::IsTuple(shape)) { + if (shape.IsTuple()) { return errors::InvalidArgument("XLA shape ", xla::ShapeUtil::HumanString(shape), " cannot be converted to a TensorShape"); } *tensor_shape = TensorShape(); - for (int i = 0; i < xla::ShapeUtil::Rank(shape); ++i) { + for (int i = 0; i < shape.rank(); ++i) { tensor_shape->AddDim(shape.dimensions(i)); } return Status::OK(); @@ -61,4 +121,64 @@ xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type, return xla::ShapeUtil::MakeShapeWithLayout(type, dimensions, layout); } +xla::StatusOr> GetShapeLayoutVector(const xla::Shape& shape) { + std::vector layouts; + TF_RETURN_IF_ERROR(PopulateInfeedLayoutVector(shape, &layouts)); + return layouts; +} + +Status GetShapeWithLayout( + const xla::Shape& input_shape, absl::Span minor_to_major, + const std::function& layout_func, + xla::Shape* output_shape) { + if (input_shape.IsTuple()) { + int64 tuple_elements = xla::ShapeUtil::TupleElementCount(input_shape); + std::vector shapes; + shapes.reserve(tuple_elements); + size_t position = 0; + for (int64 i = 0; i < tuple_elements; ++i) { + const xla::Shape& shape = + xla::ShapeUtil::GetTupleElementShape(input_shape, i); + if (shape.IsTuple()) { + return errors::InvalidArgument( + "Nested tuples not supported: ", + xla::ShapeUtil::HumanString(input_shape)); + } + int64 rank = shape.rank(); + if (position + rank > minor_to_major.size()) { + return errors::InvalidArgument( + "Not enough layout attribute elements: position=", position, + " rank=", rank, " elements=", minor_to_major.size()); + } + shapes.push_back(shape); + TF_RETURN_IF_ERROR(AssignLayout( + absl::Span(minor_to_major).subspan(position, rank), + layout_func, &shapes.back())); + position += rank; + + VLOG(4) << "Shape[" << i + << "] = " << xla::ShapeUtil::HumanStringWithLayout(shapes.back()); + } + if (position != minor_to_major.size()) { + return errors::InvalidArgument( + "Too 
many elements passed in the layout attribute: position=", + position, " size=", minor_to_major.size()); + } + *output_shape = xla::ShapeUtil::MakeTupleShape(shapes); + } else { + int64 rank = input_shape.rank(); + if (rank != minor_to_major.size()) { + return errors::InvalidArgument( + "Wrong number of layout attribute elements: rank=", rank, + " elements=", minor_to_major.size()); + } + *output_shape = input_shape; + TF_RETURN_IF_ERROR(AssignLayout(minor_to_major, layout_func, output_shape)); + + VLOG(4) << "Shape[] = " + << xla::ShapeUtil::HumanStringWithLayout(*output_shape); + } + return Status::OK(); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/shape_util.h b/tensorflow/compiler/tf2xla/shape_util.h index 0b231ea8e7a2d8e303e91911e2e0a36fc83e78b4..e775c4462c3dc15cf4b8d9e8d8e7d9a61e024cd0 100644 --- a/tensorflow/compiler/tf2xla/shape_util.h +++ b/tensorflow/compiler/tf2xla/shape_util.h @@ -18,7 +18,10 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_SHAPE_UTIL_H_ #define TENSORFLOW_COMPILER_TF2XLA_SHAPE_UTIL_H_ +#include + #include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.pb.h" @@ -41,6 +44,25 @@ Status TensorShapeToXLAShape(DataType dtype, const TensorShape& tensor_shape, xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type, const TensorShape& tensor_shape); +// Given an XLA shape with layouts, builds a layout vector in the form able to +// be fed to ops like InfeedEnqueue/InfeedEnqueueTuple/XRTAllocateV2/.... +// The returned vector is a linearized sequence of the minor-to-major values of +// the layouts held within the input shape. +// In case the input shape is a tuple, the minor-to-major values will be in the +// order of the tuple elements within the tuple shape. 
+// If a shape (or a subshape of a tuple shape) has missing layout, a rank long +// sequence of -1 values will be emitted. +xla::StatusOr> GetShapeLayoutVector(const xla::Shape& shape); + +// Given the input shape and a linearized sequence of the minor-to-major values +// of the layouts, create the output shape by rewriting the input shape layouts. +// If a layout is missing (has -1 values) for a matching tuple subshape, the +// layout_func will be called, if not nullptr. +Status GetShapeWithLayout( + const xla::Shape& input_shape, absl::Span minor_to_major, + const std::function& layout_func, + xla::Shape* output_shape); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_SHAPE_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/side_effect_util.cc b/tensorflow/compiler/tf2xla/side_effect_util.cc index b233e6b2c28e1968bb74901fc684e808ae45ab60..412f31adbb7df52b2d6933be054cc6d40947dc44 100644 --- a/tensorflow/compiler/tf2xla/side_effect_util.cc +++ b/tensorflow/compiler/tf2xla/side_effect_util.cc @@ -24,6 +24,51 @@ const char kXlaTokenInputNodesAttrName[] = "_xla_token_input_nodes"; const char kXlaTokenArgNodeName[] = "_xla_token_arg_node"; +const char kXlaHasHostTransferAttrName[] = "_xla_has_host_transfer"; + +Status SetDeviceOrdinalAttributeForNode(Node* node, int device_ordinal) { + if (!HasNodeAttr(node->def(), kXlaHasHostTransferAttrName)) { + return errors::InvalidArgument("Node ", node->DebugString(), + " does not have attribute ", + kXlaHasHostTransferAttrName); + } + + if (node->type_string() == "_XlaRecvAtHost" || + node->type_string() == "_XlaSendFromHost") { + node->ClearAttr("device_ordinal"); + node->AddAttr("device_ordinal", device_ordinal); + } else if (node->type_string() == "If") { + AttrValue device_ordinal_value; + device_ordinal_value.set_i(device_ordinal); + for (const string& attr_name : + std::vector{"then_branch", "else_branch"}) { + NameAttrList branch_func; + TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), attr_name, 
&branch_func)); + (*branch_func.mutable_attr())["device_ordinal"] = device_ordinal_value; + node->ClearAttr(attr_name); + node->AddAttr(attr_name, branch_func); + } + } else if (node->type_string() == "While") { + AttrValue device_ordinal_value; + device_ordinal_value.set_i(device_ordinal); + for (const string& attr_name : std::vector{"cond", "body"}) { + NameAttrList branch_func; + TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), attr_name, &branch_func)); + (*branch_func.mutable_attr())["device_ordinal"] = device_ordinal_value; + node->ClearAttr(attr_name); + node->AddAttr(attr_name, branch_func); + } + } else if (HasNodeAttr(node->def(), "device_ordinal")) { + // Function call node containing outside compilation. + node->ClearAttr("device_ordinal"); + node->AddAttr("device_ordinal", device_ordinal); + } else { + return errors::Internal("Unknown node type to set 'device_ordinal': ", + node->DebugString()); + } + return Status::OK(); +} + std::set CalculateTokenInputsForOutputToken(const Graph& g) { std::set results; Node* first_side_effecting_node_on_path = nullptr; diff --git a/tensorflow/compiler/tf2xla/side_effect_util.h b/tensorflow/compiler/tf2xla/side_effect_util.h index f22ddb2f58e1fa5c10ca0fdb956d9136942388b7..75e1f253fb08ae61b0336a8783b7449c69197dd1 100644 --- a/tensorflow/compiler/tf2xla/side_effect_util.h +++ b/tensorflow/compiler/tf2xla/side_effect_util.h @@ -35,6 +35,13 @@ extern const char kXlaTokenInputNodesAttrName[]; // node has side-effect dependency on current graph's token input. extern const char kXlaTokenArgNodeName[]; +// This node have XlaRecvAtHost/XlaSendFromHost in its associated functions. +extern const char kXlaHasHostTransferAttrName[]; + +// Sets device ordinal attribute for nodes with attribute +// `kXlaHasHostTransferAttrName`. +Status SetDeviceOrdinalAttributeForNode(Node* node, int device_ordinal); + // Calculates side-effect dependencies for the graph's token output. // Returns a set of node names representing these dependencies. 
std::set CalculateTokenInputsForOutputToken(const Graph& g); diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc index 9fac16a9700419b189bf5393c2b8bd7d76c6c1cc..cf48576ec2746fb29779633275eac4c638b91e45 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.cc +++ b/tensorflow/compiler/tf2xla/tf2xla.cc @@ -243,7 +243,9 @@ Status CreateXlaArgs(const Graph& graph, XlaCompiler::Argument arg; arg.kind = XlaCompiler::Argument::kParameter; TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &arg.type)); - TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kShapeAttr, &arg.shape)); + TensorShape shape; + TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kShapeAttr, &shape)); + arg.shape = shape; TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kDebugNameAttr, &arg.name)); xla_args->push_back(arg); } diff --git a/tensorflow/compiler/tf2xla/tf2xla_test.cc b/tensorflow/compiler/tf2xla/tf2xla_test.cc index ab26d939ccba75ce58609ffd71c7ccadbe90cfa8..24afe595b18b823818bd8fe65bc599af8bce040a 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_test.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_test.cc @@ -91,7 +91,7 @@ TEST(ConvertGraphDefToXla, Sum) { client->ExecuteAndTransfer(computation, {x_global.get(), y_global.get()}); TF_EXPECT_OK(result_or.status()); xla::Literal result = std::move(result_or.ValueOrDie()); - EXPECT_EQ("(s32[]) (\n42\n)", result.ToString()); + EXPECT_EQ("(\ns32[] 42\n)", result.ToString()); config.mutable_feed(0)->mutable_id()->set_output_index( 123); /* invalid output_index */ diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc index cc81772e8c5da710bc733f7e4f5fe820b2c2d110..c64f78e1a1bcdd40b1c885889ec5fa491cfa1f66 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util.cc @@ -265,6 +265,13 @@ Status PropagateConstIntoWhileNode(Graph* g, Node* while_node, } // Check if i-th retval's input comes from i-th arg directly. 
+ // For resource variable input of While nodes, TF2XLA convention is to place + // them at the end of all inputs (after all data inputs), and *not* return + // them. So number of While node inputs might be larger than number of its + // outputs. + if (i >= body_func->signature().output_arg_size()) { + continue; + } const OpDef_ArgDef& output_arg = body_func->signature().output_arg(i); auto output_arg_input = body_func->ret().find(output_arg.name()); if (output_arg_input == body_func->ret().end()) { @@ -364,6 +371,7 @@ Status AddPlaceholdersForFeeds( GraphDef gd; *gd.mutable_versions() = graph_def->versions(); *gd.add_node() = *existing; + MergeDebugInfo(NodeDebugInfo(*existing), gd.mutable_node(0)); TF_RETURN_IF_ERROR( AddDefaultAttrsToGraphDef(&gd, *op_registry, 0 /*node_offset*/)); @@ -390,6 +398,7 @@ Status AddPlaceholdersForFeeds( // in this code. for (auto it = placeholder_info.begin(); it != placeholder_info.end(); ++it) { const PlaceholderInfo& info = it->second; + // TODO(shikharagarwal): Add original node information. NodeDef* d = graph_def->add_node(); d->set_name(info.placeholder_name); d->set_op("PlaceholderV2"); @@ -557,6 +566,12 @@ bool HasAssociatedFunction(const NodeDef& node_def, return true; } + if (node_def.op() == "XlaHostCompute") { + // XlaHostCompute has "shape_inference_graph" func attr, but that's not + // related to graph execution. + return false; + } + for (const auto& iter : node_def.attr()) { if (iter.second.has_func()) { return true; @@ -578,6 +593,9 @@ std::vector GetAssociatedFunctions( // This is a SymbolicGradient op. AttrValueMap attrs(node.attrs().begin(), node.attrs().end()); results.emplace_back(AssociatedFunctionInfo::SymbolicGradient(op, attrs)); + } else if (node.type_string() == "XlaHostCompute") { + // XlaHostCompute has "shape_inference_graph" func attr, but that's not + // related to graph execution. } else { // Collect all function attrs for the node. 
for (auto& iter : node.attrs()) { @@ -599,7 +617,9 @@ Status RewriteAssociatedFunction( switch (associated_function.type()) { case AssociatedFunctionInfo::kFunctionCallNode: { // Change this node to call the new function. - NodeDefBuilder builder(node->name(), rewritten_function_name, fld); + NodeDebugInfo debug_info(*node); + NodeDefBuilder builder(node->name(), rewritten_function_name, fld, + &debug_info); for (auto attr : node->attrs()) { builder.Attr(attr.first, attr.second); } diff --git a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc index 202e929315cacd4d6cdfc69d50639d8a427ec6c2..9e9c3cecee68aee856141a620f7292f771978acb 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc @@ -21,11 +21,13 @@ limitations under the License. #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/data_flow_ops.h" #include "tensorflow/cc/ops/function_ops.h" +#include "tensorflow/cc/ops/functional_ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/compiler/tf2xla/sharding_util.h" #include "tensorflow/core/common_runtime/graph_optimizer.h" #include "tensorflow/core/common_runtime/process_function_library_runtime.h" #include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/core/status.h" @@ -329,5 +331,37 @@ TEST(CachedFunctionHandles, Basic) { TF_EXPECT_OK(cached_function_handles.ReleaseAllHandles()); } +TEST(PropagateConstIntoFunctionalNodes, WhileLoopWithResourceInput) { + FunctionLibraryDefinition fld(OpRegistry::Global(), {}); + { + // Cond graph & body graph. 
+ Scope scope = Scope::NewRootScope().ExitOnError(); + auto pred = ops::_Arg(scope.WithOpName("pred"), DT_BOOL, 0); + auto input = ops::_Arg(scope.WithOpName("input"), DT_RESOURCE, 1); + auto ret = ops::_Retval(scope.WithOpName("ret"), pred, 0); + Graph graph(OpRegistry::Global()); + TF_ASSERT_OK(scope.ToGraph(&graph)); + FunctionDef cond_fdef; + TF_ASSERT_OK(GraphToFunctionDef(graph, "cond", &cond_fdef)); + TF_ASSERT_OK(fld.AddFunctionDef(cond_fdef)); + FunctionDef body_fdef; + TF_ASSERT_OK(GraphToFunctionDef(graph, "body", &body_fdef)); + TF_ASSERT_OK(fld.AddFunctionDef(body_fdef)); + } + Scope scope = Scope::NewRootScope().ExitOnError(); + auto pred = ops::Const(scope.WithOpName("pred"), false, TensorShape({})); + auto input = ops::Const(scope.WithOpName("input"), 0, TensorShape({})); + NameAttrList cond_fn, body_fn; + cond_fn.set_name("cond"); + body_fn.set_name("body"); + auto while_op = + ops::While(scope.WithOpName("while"), + std::initializer_list{pred, input}, cond_fn, body_fn); + Graph graph(OpRegistry::Global()); + TF_ASSERT_OK(scope.ToGraph(&graph)); + + TF_EXPECT_OK(PropagateConstIntoFunctionalNodes(&graph, &fld, &fld)); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/type_util.cc b/tensorflow/compiler/tf2xla/type_util.cc index d00b1376620c0c9d112c7d7426758f6d3f25e86f..732f957d7329c93ad104dacf5190948fbfd7974b 100644 --- a/tensorflow/compiler/tf2xla/type_util.cc +++ b/tensorflow/compiler/tf2xla/type_util.cc @@ -69,6 +69,9 @@ Status DataTypeToPrimitiveType(DataType data_type, xla::PrimitiveType* type) { case tensorflow::DT_COMPLEX64: *type = xla::C64; return Status::OK(); + case tensorflow::DT_COMPLEX128: + *type = xla::C128; + return Status::OK(); default: return errors::InvalidArgument( "Unsupported type in DataTypeToPrimitiveType ", diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h index 
a1d359e97c4fad3ca74d44a358cba0e8190cdc22..de2e485a47c18ae8e58a06aba408dbb61a30d00a 100644 --- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h +++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h @@ -59,45 +59,8 @@ class XlaCompiledCpuFunction { // AOT this is backed by data compiled into the object file. // // The contents of StaticData are XLA-internal implementation details and - // should not be relied on by clients. - // - // TODO(sanjoy): Come up with a cleaner way to express the contraint we want - // here: generated XlaCompiledCpuFunction subclasses should be able to create - // instances of StaticData but only XlaCompiledCpuFunction should be able to - // read from StaticData instances. + // should not be relied on by clients (and therefore are private). class StaticData { - public: - void set_raw_function(RawFunction raw_function) { - raw_function_ = raw_function; - } - void set_buffer_infos( - const cpu_function_runtime::BufferInfo* buffer_infos) { - buffer_infos_ = buffer_infos; - } - void set_num_buffers(size_t num_buffers) { num_buffers_ = num_buffers; } - void set_arg_index_table(const int32* arg_index_table) { - arg_index_table_ = arg_index_table; - } - void set_num_args(int64 num_args) { num_args_ = num_args; } - void set_result_index(size_t result_index) { result_index_ = result_index; } - void set_arg_names(const char** arg_names) { arg_names_ = arg_names; } - void set_result_names(const char** result_names) { - result_names_ = result_names; - } - void set_program_shape(const xla::ProgramShapeProto* program_shape) { - program_shape_ = program_shape; - } - const xla::HloProfilePrinterData* hlo_profile_printer_data() const { - return hlo_profile_printer_data_; - } - void set_hlo_profile_printer_data( - const xla::HloProfilePrinterData* hlo_profile_printer_data) { - hlo_profile_printer_data_ = hlo_profile_printer_data; - } - void set_profile_counters_size(int64 profile_counters_size) { - profile_counters_size_ = 
profile_counters_size; - } - private: // The raw function to call. RawFunction raw_function_; @@ -134,7 +97,8 @@ class XlaCompiledCpuFunction { // declared so we don't have access to that information here. int64 profile_counters_size_ = 0; - // Only XlaCompiledCpuFunction is allowed to read the above fields. + // Only XlaCompiledCpuFunction is allowed to read and write the above + // fields. friend class XlaCompiledCpuFunction; }; @@ -148,7 +112,7 @@ class XlaCompiledCpuFunction { RESULTS_PROFILES_AND_TEMPS_ONLY, }; - XlaCompiledCpuFunction( + explicit XlaCompiledCpuFunction( const StaticData& static_data, AllocMode alloc_mode = AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS); virtual ~XlaCompiledCpuFunction(); @@ -206,8 +170,14 @@ class XlaCompiledCpuFunction { // // Aliasing of argument and result buffers is not allowed, and results in // undefined behavior. - void set_arg_data(size_t index, void* data) { - buffer_table_[arg_index_table_[index]] = data; + void set_arg_data(size_t index, const void* data) { + // The const_cast is safe because the generated code does not write to arg + // buffers. + // + // buffer_table_ contains pointers to buffers that _will_ be written to by + // generated code so it would be misleading to make buffer_table_ a `const + // void**`. + buffer_table_[arg_index_table_[index]] = const_cast(data); } // ------------------------------ @@ -274,6 +244,76 @@ class XlaCompiledCpuFunction { return *hlo_profile_printer_data_; } + protected: + // --------------------------------------------------------------------------- + // Accessors for reading from and writing to instances of `StaticData`. + // + // Classes generated by tfcompile can call these because the generated classes + // inherit from `XlaCompiledCpuFunction`. `XlaJitCompiledCpuFunction` can + // call these because it is explicitly added as a friend. 
+ + static void set_static_data_raw_function(StaticData* static_data, + RawFunction raw_function) { + static_data->raw_function_ = raw_function; + } + + static void set_static_data_buffer_infos( + StaticData* static_data, + const cpu_function_runtime::BufferInfo* buffer_infos) { + static_data->buffer_infos_ = buffer_infos; + } + + static void set_static_data_num_buffers(StaticData* static_data, + size_t num_buffers) { + static_data->num_buffers_ = num_buffers; + } + + static void set_static_data_arg_index_table(StaticData* static_data, + const int32* arg_index_table) { + static_data->arg_index_table_ = arg_index_table; + } + + static void set_static_data_num_args(StaticData* static_data, + int64 num_args) { + static_data->num_args_ = num_args; + } + + static void set_static_data_result_index(StaticData* static_data, + size_t result_index) { + static_data->result_index_ = result_index; + } + + static void set_static_data_arg_names(StaticData* static_data, + const char** arg_names) { + static_data->arg_names_ = arg_names; + } + + static void set_static_data_result_names(StaticData* static_data, + const char** result_names) { + static_data->result_names_ = result_names; + } + + static void set_static_data_program_shape( + StaticData* static_data, const xla::ProgramShapeProto* program_shape) { + static_data->program_shape_ = program_shape; + } + + static void set_static_data_hlo_profile_printer_data( + StaticData* static_data, + const xla::HloProfilePrinterData* hlo_profile_printer_data) { + static_data->hlo_profile_printer_data_ = hlo_profile_printer_data; + } + + static const xla::HloProfilePrinterData* + get_static_data_hlo_profile_printer_data(StaticData* static_data) { + return static_data->hlo_profile_printer_data_; + } + + static void set_static_data_profile_counters_size( + StaticData* static_data, int64 profile_counters_size) { + static_data->profile_counters_size_ = profile_counters_size; + } + private: const RawFunction raw_function_; const size_t 
result_index_; @@ -307,6 +347,10 @@ class XlaCompiledCpuFunction { const char** result_names_ = nullptr; const xla::ProgramShapeProto* program_shape_ = nullptr; const xla::HloProfilePrinterData* hlo_profile_printer_data_ = nullptr; + + // Add `XlaJitCompiledCpuFunction` as a friend so that it can access the + // `set_static_data_*` static methods above. + friend class XlaJitCompiledCpuFunction; }; } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index ee461a3c07d4db514c7697e005a9371be4b54dd0..514b156deb9f350813237c31b7657a5b09c800dd 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -31,6 +31,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/executor.h" #include "tensorflow/core/common_runtime/function.h" @@ -57,7 +58,11 @@ Status CheckSignature(const DataTypeVector& types, " elements while function has ", types.size()); } for (int i = 0; i < types.size(); ++i) { - if (types[i] != args[i].type && types[i] != DT_RESOURCE) { + // Don't perform type checks on resource variables and tensor + // lists (DT_VARIANT) as we have to trick the type system in order to + // plumb them through. DT_VARIANTS are wrapped in a DT_UINT8 tensor. 
+ if (types[i] != args[i].type && types[i] != DT_RESOURCE && + types[i] != DT_VARIANT) { return errors::Internal( "Argument ", i, " has declared type ", DataTypeString(args[i].type), " but function parameter has type ", DataTypeString(types[i])); @@ -192,6 +197,8 @@ Status BuildComputation( output.shape = output.constant_value.shape(); break; + case XlaExpression::Kind::kTensorList: + TF_FALLTHROUGH_INTENDED; case XlaExpression::Kind::kXlaOp: { output.is_constant = false; TF_ASSIGN_OR_RETURN(output.shape, retval.GetShape()); @@ -333,8 +340,21 @@ bool XlaCompiler::Argument::operator==( other.tensor_array_gradients)) { return false; } - if (shape != other.shape) { - return false; + if (absl::holds_alternative(shape)) { + if (!absl::holds_alternative(other.shape)) { + return false; + } + if (!xla::Shape::Equal()(absl::get(shape), + absl::get(other.shape))) { + return false; + } + } else { + if (!absl::holds_alternative(other.shape)) { + return false; + } + if (absl::get(shape) != absl::get(other.shape)) { + return false; + } } if (constant_value.shape() != other.constant_value.shape()) { return false; @@ -348,7 +368,7 @@ string XlaCompiler::Argument::HumanString() const { common = absl::StrCat(" name=", name); } absl::StrAppend(&common, " type=", DataTypeString(type), - " shape=", shape.DebugString()); + " shape=", ShapeHumanString()); switch (kind) { case kInvalid: return "invalid"; @@ -375,6 +395,23 @@ string XlaCompiler::Argument::HumanString() const { } } +std::vector XlaCompiler::Argument::DimensionSizes() const { + if (absl::holds_alternative(shape)) { + return xla::InlinedVectorToVector( + absl::get(shape).dim_sizes()); + } else { + return absl::get(shape).dimensions(); + } +} + +string XlaCompiler::Argument::ShapeHumanString() const { + if (absl::holds_alternative(shape)) { + return absl::get(shape).DebugString(); + } else { + return absl::get(shape).DebugString(); + } +} + XlaCompiler::XlaCompiler(XlaCompiler::Options options) : options_(options), 
initialization_status_(Status::OK()), @@ -462,8 +499,34 @@ std::unique_ptr XlaCompiler::GetGraph(const FunctionBody* fbody) { opts.set_do_function_inlining(true); opts.set_do_constant_folding(true); GraphOptimizer optimizer(opts); + // Do not constant fold nodes that output DT_VARIANT type tensors. + // XLA does not support Const nodes of Variant type since it needs + // to know the original ops to be able to compile them to the relevant + // XLA form. + // TODO(srbs): This filter is a little conservative. E.g. a subgraph of + // the form: + // Const + // | + // EmptyTensorList -> TensorListPushBack -> TensorListPopBack -> Op + // | + // (Discard popped list) + // + // Would have been reduced to "Const -> Op" without this filter. + // However since we are only allowed to specify the filter at the "Node" + // level there is no good way to allow the above behavior. So we + // disallow any sort of constant folding on Variant nodes for now. + auto cf_consider_fn = [](const Node* n) { + for (const auto& output_arg : n->op_def().output_arg()) { + if (output_arg.type() == DT_VARIANT) { + return false; + } + } + return true; + }; + GraphOptimizer::Options graph_optimizer_options; + graph_optimizer_options.cf_consider_fn = cf_consider_fn; optimizer.Optimize(flib_runtime_, flib_runtime_->env(), - /*device=*/nullptr, &graph, /*shape_map=*/nullptr); + /*device=*/nullptr, &graph, graph_optimizer_options); return graph; } @@ -548,11 +611,22 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg, LOG(FATAL) << "Unreachable case"; case XlaCompiler::Argument::kParameter: { if (is_entry_computation) { - TF_ASSIGN_OR_RETURN( - *xla_shape, options_.shape_representation_fn(arg.shape, arg.type)); + TensorShape shape; + if (absl::holds_alternative(arg.shape)) { + shape = absl::get(arg.shape); + } else { + TF_RETURN_IF_ERROR( + XLAShapeToTensorShape(absl::get(arg.shape), &shape)); + } + TF_ASSIGN_OR_RETURN(*xla_shape, + options_.shape_representation_fn(shape, 
arg.type)); } else { - TF_RETURN_IF_ERROR( - TensorShapeToXLAShape(arg.type, arg.shape, xla_shape)); + if (absl::holds_alternative(arg.shape)) { + *xla_shape = absl::get(arg.shape); + } else { + TF_RETURN_IF_ERROR(TensorShapeToXLAShape( + arg.type, absl::get(arg.shape), xla_shape)); + } } return Status::OK(); } @@ -561,8 +635,10 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg, switch (arg.resource_kind) { case XlaResource::kVariable: { - TF_ASSIGN_OR_RETURN(*xla_shape, options_.shape_representation_fn( - arg.shape, arg.type)); + TF_RET_CHECK(absl::holds_alternative(arg.shape)); + TF_ASSIGN_OR_RETURN(*xla_shape, + options_.shape_representation_fn( + absl::get(arg.shape), arg.type)); return Status::OK(); } @@ -571,9 +647,10 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg, return errors::InvalidArgument( "Negative max_array_size in XLAShapeForArgument"); } + TF_RET_CHECK(absl::holds_alternative(arg.shape)); TensorShape shape; shape.AddDim(arg.max_array_size); - shape.AppendShape(arg.shape); + shape.AppendShape(absl::get(arg.shape)); TF_RETURN_IF_ERROR(TensorShapeToXLAShape(arg.type, shape, xla_shape)); if (!arg.tensor_array_gradients.empty()) { @@ -588,9 +665,10 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg, return errors::InvalidArgument( "Negative max_array_size in XLAShapeForArgument"); } + TF_RET_CHECK(absl::holds_alternative(arg.shape)); TensorShape shape; shape.AddDim(arg.max_array_size); - shape.AppendShape(arg.shape); + shape.AppendShape(absl::get(arg.shape)); xla::Shape buffer_shape; TF_RETURN_IF_ERROR( TensorShapeToXLAShape(arg.type, shape, &buffer_shape)); @@ -620,14 +698,15 @@ Status XlaCompiler::BuildArguments( bool use_tuple_arg, xla::XlaBuilder* builder, XlaContext* context, const std::map& arg_cores, std::vector* arg_expressions, - std::vector* input_mapping, std::vector* input_shapes, + std::vector* input_to_args, std::vector* input_shapes, bool is_entry_computation) 
{ arg_expressions->resize(args.size()); // Argument numbers of arguments and resources that are to be passed to the - // XLA computation as runtime parameters. - input_mapping->clear(); - input_mapping->reserve(args.size()); + // XLA computation as runtime parameters. `input_to_args[a] = b` means that + // the a'th XLA input corresponds to the b'th original arg indexes. + input_to_args->clear(); + input_to_args->reserve(args.size()); // Fills in constant arguments, and computes non-constant argument order. for (std::vector::size_type i = 0; i < args.size(); @@ -637,24 +716,25 @@ Status XlaCompiler::BuildArguments( switch (arg.kind) { case XlaCompiler::Argument::kResource: { TF_RET_CHECK(arg.resource_kind != XlaResource::kInvalid); + TF_RET_CHECK(absl::holds_alternative(arg.shape)); // TODO(phawkins): this code assumes that resource arguments do not // alias. XlaResource* resource = context->AddResource(absl::make_unique( - arg.resource_kind, i, arg.name, arg.type, arg.shape, - xla::XlaOp(), + arg.resource_kind, i, arg.name, arg.type, + absl::get(arg.shape), xla::XlaOp(), /*max_array_size=*/arg.max_array_size, /*tensor_array_gradients=*/arg.tensor_array_gradients, /*tensor_array_multiple_writes_aggregate=*/true)); arg_expression = XlaExpression::Resource(resource); if (arg.initialized) { - input_mapping->push_back(i); + input_to_args->push_back(i); } break; } case XlaCompiler::Argument::kParameter: case XlaCompiler::Argument::kToken: { - input_mapping->push_back(i); + input_to_args->push_back(i); break; } case XlaCompiler::Argument::kConstant: @@ -666,15 +746,23 @@ Status XlaCompiler::BuildArguments( } } - if (input_mapping->empty()) { + if (input_to_args->empty()) { return Status::OK(); } - std::vector arg_shapes(input_mapping->size()); - for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { + // `arg_to_inputs[c] = d` means that the c'th original arg index corresponds + // to the d'th XLA input. 
Note that the value -1 corresponds to constants, or + // other args that don't correspond to an input. + std::vector arg_to_inputs(args.size(), -1); + for (int i = 0; i < input_to_args->size(); i++) { + arg_to_inputs[input_to_args->at(i)] = i; + } + + std::vector arg_shapes(input_to_args->size()); + for (std::vector::size_type i = 0; i < input_to_args->size(); ++i) { // Computes the shapes of non-constant arguments. TF_RETURN_IF_ERROR(XLAShapeForArgument( - args[(*input_mapping)[i]], is_entry_computation, &arg_shapes[i])); + args[(*input_to_args)[i]], is_entry_computation, &arg_shapes[i])); } if (use_tuple_arg) { @@ -691,13 +779,13 @@ Status XlaCompiler::BuildArguments( builder->SetOpMetadata(arg_metadata); // Build parameter handles for non-constant arguments. - std::vector arg_handles(input_mapping->size()); + std::vector arg_handles(input_to_args->size()); if (use_tuple_arg) { xla::XlaOp tuple; if (is_entry_computation) { xla::OpSharding tuple_sharding; tuple_sharding.set_type(xla::OpSharding::Type::OpSharding_Type_TUPLE); - for (int64 parameter : *input_mapping) { + for (int64 parameter : *input_to_args) { auto it = arg_cores.find(parameter); const int core = it == arg_cores.end() ? 
0 : it->second; *tuple_sharding.add_tuple_shardings() = @@ -709,7 +797,19 @@ Status XlaCompiler::BuildArguments( } else { tuple = xla::Parameter(builder, 0, (*input_shapes)[0], "arg_tuple"); } - for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { + + for (int i = 0; i < input_to_args->size(); ++i) { + const XlaCompiler::Argument& arg = args[input_to_args->at(i)]; + for (const auto& dim_and_arg_num : arg.dynamic_dim_to_arg_num_map) { + int dynamic_size_param_index = arg_to_inputs.at(dim_and_arg_num.second); + TF_RETURN_IF_ERROR(builder->SetDynamicBinding( + /*dynamic_size_param_num=*/0, {dynamic_size_param_index}, + /*target_param_num=*/0, /*target_param_index=*/{i}, + dim_and_arg_num.first)); + } + } + + for (std::vector::size_type i = 0; i < input_to_args->size(); ++i) { auto it = arg_cores.find(i); const int core = it == arg_cores.end() ? -1 : it->second; xla::XlaScopedShardingAssignment assign_sharding( @@ -718,7 +818,7 @@ Status XlaCompiler::BuildArguments( arg_handles[i] = xla::GetTupleElement(tuple, i); } } else { - for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { + for (std::vector::size_type i = 0; i < input_to_args->size(); ++i) { auto it = arg_cores.find(i); const int core = it == arg_cores.end() ? 
-1 : it->second; xla::XlaScopedShardingAssignment assign_sharding( @@ -727,6 +827,17 @@ Status XlaCompiler::BuildArguments( arg_handles[i] = xla::Parameter(builder, i, (*input_shapes)[i], absl::StrCat("arg", i)); } + + for (int i = 0; i < input_to_args->size(); ++i) { + const XlaCompiler::Argument& arg = args[input_to_args->at(i)]; + for (const auto& dim_and_arg_num : arg.dynamic_dim_to_arg_num_map) { + int dynamic_size_param_index = arg_to_inputs.at(dim_and_arg_num.second); + TF_RETURN_IF_ERROR(builder->SetDynamicBinding( + /*dynamic_size_param_num=*/dynamic_size_param_index, {}, + /*target_param_num=*/i, /*target_param_index=*/{}, + dim_and_arg_num.first)); + } + } } builder->ClearOpMetadata(); @@ -734,12 +845,12 @@ Status XlaCompiler::BuildArguments( // Fill in the handles in non-constant arguments, and reshape parameters // back to their correct shapes. VLOG(2) << "XLA computation inputs:"; - for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { - const XlaCompiler::Argument& arg = args[input_mapping->at(i)]; + for (std::vector::size_type i = 0; i < input_to_args->size(); ++i) { + const XlaCompiler::Argument& arg = args[input_to_args->at(i)]; VLOG(2) << " XLA arg " << i << " shape: " << xla::ShapeUtil::HumanString(arg_shapes[i]) - << " name: " << arg.name << " TF arg " << input_mapping->at(i); - XlaExpression& arg_expression = (*arg_expressions)[input_mapping->at(i)]; + << " name: " << arg.name << " TF arg " << input_to_args->at(i); + XlaExpression& arg_expression = (*arg_expressions)[input_to_args->at(i)]; switch (arg.kind) { case XlaCompiler::Argument::kResource: { TF_RET_CHECK(arg.initialized); @@ -756,7 +867,7 @@ Status XlaCompiler::BuildArguments( // return values of functions, and then reshape unconditionally. 
if (is_entry_computation) { arg_expression = XlaExpression::XlaOp( - xla::Reshape(arg_handles[i], arg.shape.dim_sizes()), arg.type); + xla::Reshape(arg_handles[i], arg.DimensionSizes()), arg.type); } else { arg_expression = XlaExpression::XlaOp(arg_handles[i], arg.type); } diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index 0d801b73a8c2651305328384377751254ecaa41d..ad3144b41bdf3fc8b75ab5230e8e128df2962884 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -19,6 +19,7 @@ limitations under the License. #include #include "absl/types/span.h" +#include "absl/types/variant.h" #include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h" #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_expression.h" @@ -124,7 +125,8 @@ class XlaCompiler { DataType type = DT_INVALID; // The shape of the argument. For: - // * a parameter: the shape of the parameter. + // * a parameter: the shape of the parameter. We allow setting the xla shape + // if known. This helps avoid conversions to and from TensorShape. // * a constant: ignored; the shape given by constant_value is used // instead. // * an uninitialized resource: ignored. We don't yet know the shape of an @@ -133,7 +135,7 @@ class XlaCompiler { // * an initialized TensorArray or Stack resource: the shape of an entry in // the TensorArray/Stack. Note this is the size of a single entry, not the // XLA data structure that represents the complete stack/array. - TensorShape shape; + absl::variant shape; // The value of the argument, if it is a compile-time constant. Must be a // host-memory tensor. @@ -157,10 +159,20 @@ class XlaCompiler { // as `tensor_array_gradients`. std::set tensor_array_gradients; + // dynamic dims to arg number map. Empty if no dynamic shapes. 
+ std::map dynamic_dim_to_arg_num_map; + bool is_pad_arg = false; + bool operator==(const Argument& other) const; // Returns a human-readable summary of the argument. string HumanString() const; + + // Returns the dimension sizes for either TensorShape or xla::Shape. + std::vector DimensionSizes() const; + + // Returns the human-readable string for either TensorShape or xla::Shape. + string ShapeHumanString() const; }; // Options pertaining to an individual call to CompileGraph() or @@ -420,7 +432,7 @@ class XlaCompiler { XlaContext* context, const std::map& arg_cores, std::vector* arg_expressions, - std::vector* input_mapping, + std::vector* input_to_args, std::vector* input_shapes, bool is_entry_computation); diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index fe2a5f5b0c9ea6b5f2bb71df836fdcabf9a0cf23..492010f7317d32a8a620147cd2cd9356d4f13fde 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -82,7 +82,7 @@ namespace { // compiled kernels. 
class DummyResourceForTest : public ResourceBase { public: - string DebugString() override { return "dummy"; } + string DebugString() const override { return "dummy"; } void Increment() { ++value_; } int Get() { return value_; } @@ -1362,7 +1362,7 @@ TEST_F(XlaCompilerTest, TokenInputAndOutput) { TF_ASSERT_OK(compiler.CompileGraph(options, "NoOp", std::move(graph_copy), args, &result)); EXPECT_EQ(result.xla_input_shapes.size(), 1); - EXPECT_TRUE(xla::ShapeUtil::IsTuple(result.xla_output_shape)); + EXPECT_TRUE(result.xla_output_shape.IsTuple()); EXPECT_EQ(xla::ShapeUtil::TupleElementCount(result.xla_output_shape), 1); } { @@ -1380,11 +1380,11 @@ TEST_F(XlaCompilerTest, TokenInputAndOutput) { TF_ASSERT_OK(compiler.CompileGraph(options, "NoOp", std::move(graph_copy), args, &result)); EXPECT_EQ(result.xla_input_shapes.size(), 2); - EXPECT_TRUE(xla::ShapeUtil::IsToken(result.xla_input_shapes[1])); - EXPECT_TRUE(xla::ShapeUtil::IsTuple(result.xla_output_shape)); + EXPECT_TRUE(result.xla_input_shapes[1].IsToken()); + EXPECT_TRUE(result.xla_output_shape.IsTuple()); EXPECT_EQ(xla::ShapeUtil::TupleElementCount(result.xla_output_shape), 2); - EXPECT_TRUE(xla::ShapeUtil::IsToken( - xla::ShapeUtil::GetTupleElementShape(result.xla_output_shape, 1))); + EXPECT_TRUE(xla::ShapeUtil::GetTupleElementShape(result.xla_output_shape, 1) + .IsToken()); } } diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc index a69af70503376b6c0905deb8980abdc3254a6e47..6139bf3cea0790c2697130a993e92be96c81848b 100644 --- a/tensorflow/compiler/tf2xla/xla_context.cc +++ b/tensorflow/compiler/tf2xla/xla_context.cc @@ -61,7 +61,7 @@ void XlaContext::set_args(std::vector args) { XlaContext::XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder) : compiler_(compiler), builder_(builder) {} -string XlaContext::DebugString() { return "XLA JIT context"; } +string XlaContext::DebugString() const { return "XLA JIT context"; } void XlaContext::SetRetval(int index, 
const XlaExpression& expression) { if (retvals_.size() <= index) { diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h index 0767d1faac14cedb8666f6cc37175eb7b55f6158..eb4ad3fe6a14b42a4df2c73c71cb6df1331fd796 100644 --- a/tensorflow/compiler/tf2xla/xla_context.h +++ b/tensorflow/compiler/tf2xla/xla_context.h @@ -47,7 +47,7 @@ class XlaContext : public ResourceBase { XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder); // Virtual method defined by ResourceBase. - string DebugString() override; + string DebugString() const override; XlaCompiler* compiler() const { return compiler_; } diff --git a/tensorflow/compiler/tf2xla/xla_expression.cc b/tensorflow/compiler/tf2xla/xla_expression.cc index ca0309166b7c73d1a5a818091e2a30fa112a4de4..3d228c92adcbe3d093a4fe70d157e57ab3e80c80 100644 --- a/tensorflow/compiler/tf2xla/xla_expression.cc +++ b/tensorflow/compiler/tf2xla/xla_expression.cc @@ -46,6 +46,14 @@ XlaExpression XlaExpression::XlaOp(xla::XlaOp value, DataType dtype) { return e; } +XlaExpression XlaExpression::TensorList(xla::XlaOp tensor_list) { + XlaExpression e; + e.kind_ = Kind::kTensorList; + e.dtype_ = DT_VARIANT; + e.handle_ = tensor_list; + return e; +} + XlaExpression XlaExpression::Resource(XlaResource* resource) { XlaExpression e; e.kind_ = Kind::kResource; @@ -64,6 +72,8 @@ string XlaExpression::HumanString() const { return "xla_op"; case Kind::kResource: return "resource"; + case Kind::kTensorList: + return "tensor_list"; } } @@ -76,6 +86,8 @@ xla::XlaOp XlaExpression::AsXlaOp(xla::XlaBuilder* builder) const { HostTensorToBorrowingLiteral(constant_value_, &literal)); return xla::ConstantLiteral(builder, literal); } + case Kind::kTensorList: + TF_FALLTHROUGH_INTENDED; case Kind::kXlaOp: if (builder != handle_.builder()) { return errors::InvalidArgument( @@ -96,7 +108,10 @@ xla::StatusOr> XlaExpression::ResolveConstant( return {constant_value()}; case Kind::kXlaOp: break; + case Kind::kTensorList: + 
TF_FALLTHROUGH_INTENDED; case Kind::kResource: + TF_FALLTHROUGH_INTENDED; case Kind::kInvalid: return errors::InvalidArgument( "ResolveConstant called on XlaExpression: ", HumanString()); @@ -134,6 +149,8 @@ xla::StatusOr XlaExpression::GetShape() const { TF_RETURN_IF_ERROR(XLAShapeToTensorShape(xla_shape, &shape)); return shape; } + case Kind::kTensorList: + return TensorShape({}); case Kind::kResource: return TensorShape({}); case Kind::kInvalid: diff --git a/tensorflow/compiler/tf2xla/xla_expression.h b/tensorflow/compiler/tf2xla/xla_expression.h index bed6761d362a98d344003c1edea342e68c31ef07..ac0232d8924cf2c9e35ad3f0772a3a2adc18af87 100644 --- a/tensorflow/compiler/tf2xla/xla_expression.h +++ b/tensorflow/compiler/tf2xla/xla_expression.h @@ -32,11 +32,16 @@ namespace tensorflow { // * a constant tensor. // * an xla::XlaOp, representing a symbolic XLA value. // * a resource, e.g., a variable, represented as an XlaResource pointer. +// * a tensor list, represented by a tuple of tensors and the list length. // // Constant tensors are mostly an optimization to avoid passing large constants // to XLA, but are also sometimes used to represent tensors that have no XLA // representation, for example, DT_STRING tensors. A canonical use case might be // an error message string. +// +// Tensor lists are very similar to xla::XlaOp, however they require some +// specific logic around shape management since the tuples are not supported by +// TensorFlow. class XlaExpression { public: enum class Kind { @@ -44,6 +49,7 @@ class XlaExpression { kConstant, kXlaOp, kResource, + kTensorList, }; XlaExpression(); @@ -62,6 +68,9 @@ class XlaExpression { // be derived from the XLA type. static XlaExpression XlaOp(xla::XlaOp value, DataType dtype); + // Builds a tensor list expression. + static XlaExpression TensorList(xla::XlaOp tensor_list); + // Builds a resource expression. 
static XlaExpression Resource(XlaResource* resource); @@ -100,7 +109,8 @@ class XlaExpression { DataType dtype_ = DT_INVALID; - // The XLA handle of the expression's computation, if kind_ == kXlaOp. + // The XLA handle of the expression's computation, if kind_ == kXlaOp or + // a tuple expression if kind_ == kTensorList. xla::XlaOp handle_; // The value of the constant, if kind_ == kConstant. diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc index c2c0751211180c3715a19d6c78e34659fd18914e..04a5d934064a9083a41cc210b48df65bbc862fff 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.cc +++ b/tensorflow/compiler/tf2xla/xla_helpers.cc @@ -34,63 +34,6 @@ limitations under the License. namespace tensorflow { -namespace { - -xla::XlaOp ArgMinMax(xla::XlaOp input, xla::PrimitiveType output_type, int axis, - bool is_min) { - xla::XlaBuilder* builder = input.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_ASSIGN_OR_RETURN(xla::Shape input_shape, builder->GetShape(input)); - xla::XlaOp init_value; - xla::XlaComputation reducer; - if (is_min) { - init_value = xla::MaxValue(builder, input_shape.element_type()); - reducer = - xla::CreateScalarMinComputation(input_shape.element_type(), builder); - } else { - init_value = xla::MinValue(builder, input_shape.element_type()); - reducer = - xla::CreateScalarMaxComputation(input_shape.element_type(), builder); - } - - xla::XlaOp input_max = xla::Reduce(input, init_value, reducer, - /*dimensions_to_reduce=*/{axis}); - std::vector broadcast_dims(xla::ShapeUtil::Rank(input_shape) - 1); - std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0); - std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1); - // Compute a mask that has 1s for elements equal to the maximum. 
- xla::XlaOp partial_mask = xla::ConvertElementType( - xla::Eq(input, input_max, broadcast_dims), output_type); - - // In order to make identity elements for a bitwise And, we: - // Left shift the 1 to the leftmost bit, yielding 0x10...0 - // Arithmetic right shift the 1 back to the rightmost bit, yielding - // 0xFF...F - int32 bits_in_type = - xla::ShapeUtil::ByteSizeOfPrimitiveType(output_type) * 8 - 1; - xla::XlaOp shift_amount = - xla::ConstantR0WithType(builder, output_type, bits_in_type); - xla::XlaOp full_mask = xla::ShiftRightArithmetic( - xla::ShiftLeft(partial_mask, shift_amount), shift_amount); - - // And with the vector [0, 1, 2, ...] to convert each 0xFF...F into its - // index. - - const int64 axis_size = xla::ShapeUtil::GetDimension(input_shape, axis); - xla::XlaOp iota = xla::Iota(builder, output_type, axis_size); - xla::XlaOp product = - xla::And(full_mask, iota, /*broadcast_dimensions=*/{axis}); - - // If there are multiple maximum elements, choose the one with the highest - // index. 
- return xla::Reduce(product, xla::MinValue(builder, output_type), - xla::CreateScalarMaxComputation(output_type, builder), - /*dimensions_to_reduce=*/{axis}); - }); -} - -} // namespace - xla::XlaOp XlaHelpers::Zero(xla::XlaBuilder* b, DataType data_type) { xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); @@ -120,7 +63,7 @@ xla::XlaOp XlaHelpers::FloatLiteral(xla::XlaBuilder* b, DataType data_type, /* static */ Status XlaHelpers::ReshapeLiteral( const xla::Literal& input, absl::Span dimensions, xla::Literal* output) { - if (xla::ShapeUtil::IsTuple(input.shape())) { + if (input.shape().IsTuple()) { return errors::InvalidArgument("ReshapeLiteral does not support tuples."); } xla::Shape shape = @@ -148,16 +91,6 @@ static Tensor MakeLinspaceTensor(const TensorShape& shape, int64 depth) { return linspace; } -xla::XlaOp XlaHelpers::ArgMax(xla::XlaOp input, xla::PrimitiveType output_type, - int axis) { - return ArgMinMax(input, output_type, axis, /*is_min=*/false); -} - -xla::XlaOp XlaHelpers::ArgMin(xla::XlaOp input, xla::PrimitiveType output_type, - int axis) { - return ArgMinMax(input, output_type, axis, /*is_min=*/true); -} - Status XlaHelpers::OneHot(xla::XlaBuilder* builder, int64 depth, int axis, DataType index_type, const TensorShape& indices_shape, const xla::XlaOp& indices, const xla::XlaOp& on_value, diff --git a/tensorflow/compiler/tf2xla/xla_helpers.h b/tensorflow/compiler/tf2xla/xla_helpers.h index 4858dfee55a393d04cd2af83916eeb40820ee368..490923526bd3acd4b167ccb3faff1d6c9e631131 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.h +++ b/tensorflow/compiler/tf2xla/xla_helpers.h @@ -53,16 +53,6 @@ class XlaHelpers { absl::Span shape, xla::Literal* output); - // Returns the argmax of `input` along `axis`. `output_type` is the type to - // use for the output. - static xla::XlaOp ArgMax(xla::XlaOp input, xla::PrimitiveType output_type, - int axis); - - // Returns the argmin of `input` along `axis`. 
`output_type` is the type to - // use for the output. - static xla::XlaOp ArgMin(xla::XlaOp input, xla::PrimitiveType output_type, - int axis); - // Converts `indices` into a one-hot representation. `depth` is the size // of the new axis to add. `axis` is the position at which to add the new // axis. `indices_shape` is the shape of `indices`. `on_value` and diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc index fabbcd04fed96ad814d04c2df9394f43bfe0cf99..884dc45cb11b18ae557c3da3f4192b3805cb7980 100644 --- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc @@ -135,24 +135,34 @@ XlaJitCompiledCpuFunction::Compile( jit->arg_index_table_ = std::move(arg_index_table); jit->program_shape_ = absl::make_unique(program_shape->ToProto()); - jit->static_data_.set_raw_function(raw_function); - jit->static_data_.set_buffer_infos(jit->buffer_infos_.data()); - jit->static_data_.set_num_buffers(jit->buffer_infos_.size()); - jit->static_data_.set_arg_index_table(jit->arg_index_table_.data()); - jit->static_data_.set_num_args(jit->arg_index_table_.size()); - jit->static_data_.set_result_index(result_index); + XlaCompiledCpuFunction::set_static_data_raw_function(&jit->static_data_, + raw_function); + XlaCompiledCpuFunction::set_static_data_buffer_infos( + &jit->static_data_, jit->buffer_infos_.data()); + XlaCompiledCpuFunction::set_static_data_num_buffers( + &jit->static_data_, jit->buffer_infos_.size()); + XlaCompiledCpuFunction::set_static_data_arg_index_table( + &jit->static_data_, jit->arg_index_table_.data()); + XlaCompiledCpuFunction::set_static_data_num_args( + &jit->static_data_, jit->arg_index_table_.size()); + XlaCompiledCpuFunction::set_static_data_result_index(&jit->static_data_, + result_index); // Optional metadata is collected and set below. 
CollectNames(config.feed(), &jit->nonempty_arg_names_, &jit->arg_names_); CollectNames(config.fetch(), &jit->nonempty_result_names_, &jit->result_names_); - jit->static_data_.set_arg_names(jit->arg_names_.data()); - jit->static_data_.set_result_names(jit->result_names_.data()); - jit->static_data_.set_program_shape(jit->program_shape_.get()); + XlaCompiledCpuFunction::set_static_data_arg_names(&jit->static_data_, + jit->arg_names_.data()); + XlaCompiledCpuFunction::set_static_data_result_names( + &jit->static_data_, jit->result_names_.data()); + XlaCompiledCpuFunction::set_static_data_program_shape( + &jit->static_data_, jit->program_shape_.get()); if (cpu_executable->hlo_profiling_enabled()) { - jit->static_data_.set_hlo_profile_printer_data( - &cpu_executable->hlo_profile_printer_data()); - jit->static_data_.set_profile_counters_size( + XlaCompiledCpuFunction::set_static_data_hlo_profile_printer_data( + &jit->static_data_, &cpu_executable->hlo_profile_printer_data()); + XlaCompiledCpuFunction::set_static_data_profile_counters_size( + &jit->static_data_, cpu_executable->hlo_profile_printer_data().profile_counters_size()); } diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index 58808c76de6330a6b28e21dbdead03dea25847f6..78bc2c94425e00c2b26058daf609d71f1853664e 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -93,7 +93,7 @@ TensorShape XlaOpKernelContext::InputShape(absl::string_view name) { } DataType XlaOpKernelContext::input_type(int index) const { - return context_->input(index).dtype(); + return context_->input_dtype(index); } DataType XlaOpKernelContext::InputType(absl::string_view name) { @@ -178,7 +178,7 @@ Status XlaOpKernelContext::ConstantInputReshaped( // Converts an int32 or int64 scalar literal to an int64. 
static Status LiteralToInt64Scalar(const xla::LiteralSlice& literal, int64* out) { - if (xla::ShapeUtil::Rank(literal.shape()) != 0) { + if (literal.shape().rank() != 0) { return errors::InvalidArgument("value is not a scalar"); } if (literal.shape().element_type() == xla::S32) { @@ -194,7 +194,7 @@ static Status LiteralToInt64Scalar(const xla::LiteralSlice& literal, // Converts an float32 or float64 scalar literal to a float64. static Status LiteralToFloat64Scalar(const xla::LiteralSlice& literal, double* out) { - if (xla::ShapeUtil::Rank(literal.shape()) != 0) { + if (literal.shape().rank() != 0) { return errors::InvalidArgument("value is not a scalar"); } if (literal.shape().element_type() == xla::F32) { @@ -228,8 +228,9 @@ Status XlaOpKernelContext::ConstantInputAsFloatScalar(int index, double* out) { // Converts an int32 or int64 1D literal to an int64 vector. static Status LiteralToInt64Vector(const xla::LiteralSlice& literal, std::vector* out) { - if (xla::ShapeUtil::Rank(literal.shape()) != 1) { - return errors::InvalidArgument("value is not 1D"); + if (literal.shape().rank() != 1) { + return errors::InvalidArgument("value is not 1D, rank: ", + literal.shape().rank()); } int64 size = xla::ShapeUtil::ElementsIn(literal.shape()); if (literal.shape().element_type() == xla::S32) { @@ -353,8 +354,8 @@ Status ReadVariableInputTensor(const Tensor& tensor, DataType type, TF_RET_CHECK(variable != nullptr); TF_RET_CHECK(variable->kind() == XlaResource::kVariable); if (!variable->initialized()) { - return errors::InvalidArgument("Read of uninitialized variable ", - variable->name()); + return errors::FailedPrecondition("Read of uninitialized variable ", + variable->name()); } if (variable->type() != type) { return errors::InvalidArgument( @@ -456,6 +457,11 @@ void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) { SetOutputExpression(index, XlaExpression::Constant(constant)); } +void XlaOpKernelContext::SetTensorListOutput(int index, + const 
xla::XlaOp& handle) { + SetOutputExpression(index, XlaExpression::TensorList(handle)); +} + void XlaOpKernelContext::SetResourceOutput(int index, XlaResource* resource) { SetOutputExpression(index, XlaExpression::Resource(resource)); } diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index 1858844bc05a6e12abbf07af83cad816590ddd03..e44415f60bff82fb92d0cf4ec81935564a2f083a 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -168,6 +168,9 @@ class XlaOpKernelContext { // Returns an XlaExpression describing the value of 'index'. void SetOutputExpression(int index, const XlaExpression& expression); + // Sets output `index` to the Tensor List `handle`. + void SetTensorListOutput(int index, const xla::XlaOp& handle); + // Status handling. void SetStatus(const Status& status) { context_->SetStatus(status); } Status status() { return context_->status(); } diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc index 14237df69081016817fbd1a5332f22996e7f264d..26314034a18b2a77a3529f0c1af242e29ec69902 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.cc +++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc @@ -73,6 +73,11 @@ XlaOpRegistry::~XlaOpRegistry() = default; << " have incompatible allow_resource_types settings."; return false; } + if (x.allow_variant_types != y.allow_variant_types) { + LOG(WARNING) << "Registrations of " << x.name + << " have incompatible allow_variant_types settings."; + return false; + } if (!x.has_device_whitelist && !y.has_device_whitelist) { LOG(WARNING) << "Duplicate registrations of " << x.name << "with no device whitelists."; @@ -289,6 +294,9 @@ void XlaOpRegistry::RegisterCompilationKernels() { if (op_registration->allow_resource_types) { allowed_values->add_type(DT_RESOURCE); } + if (op_registration->allow_variant_types) { + allowed_values->add_type(DT_VARIANT); + } // Don't build 
KernelDefs that have unsatisfiable type constraints. if (allowed_values->type().empty()) { unsatisfiable_type_constraint = true; @@ -485,6 +493,11 @@ XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::AllowResourceTypes() { return *this; } +XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::AllowVariantTypes() { + registration_->allow_variant_types = true; + return *this; +} + XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint( absl::string_view attr_name, DataType allowed) { std::set& types = diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h index 0bdd4a1085445420a5147756daac4a54f4725f11..c5e078a02d1ca6fdd8405ae6556a5205e387421e 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.h +++ b/tensorflow/compiler/tf2xla/xla_op_registry.h @@ -47,13 +47,14 @@ extern const char* const DEVICE_XLA_GPU; constexpr std::array kFloatTypes = { {DT_HALF, DT_FLOAT, DT_DOUBLE, DT_BFLOAT16}}; -constexpr std::array kNumericTypes = { +constexpr std::array kNumericTypes = { {DT_UINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_INT32, DT_INT64, DT_HALF, - DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BFLOAT16}}; + DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128, DT_BFLOAT16}}; -constexpr std::array kCpuAllTypes = { +constexpr std::array kCpuAllTypes = { {DT_UINT8, DT_QUINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_QINT8, DT_INT32, - DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}}; + DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, + DT_COMPLEX128, DT_BOOL}}; constexpr std::array kGpuAllTypes = { {DT_UINT8, DT_QUINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_QINT8, DT_INT32, @@ -211,6 +212,10 @@ class XlaOpRegistry { // allow DT_RESOURCE. bool allow_resource_types = false; + // Should we allow variant types for type attributes? Used by While to + // allow TensorList which is of type DT_VARIANT. + bool allow_variant_types = false; + // Mapping from attribute name to a list of supported types. 
std::unordered_map> type_constraints; @@ -232,9 +237,9 @@ class XlaOpRegistry { // Returns true if registrations x and y can both be added to the registry. // This is always the case if they refer to different ops. If they refer to - // the same op name, they must: have the same values for compilation_only and - // allow_resource_types; use a device_whitelist; and their - // whitelists must not intersect. + // the same op name, they must: have the same values for compilation_only, + // allow_resource_types and allow_variant_types; use a device_whitelist; and + // their whitelists must not intersect. static bool IsCompatible(const OpRegistration& x, const OpRegistration& y); static Status CompileTimeConstantInputs(const NodeDef& node_def, @@ -292,6 +297,9 @@ class XlaOpRegistrationBuilder { // Allow DT_RESOURCE types for type parameters. XlaOpRegistrationBuilder& AllowResourceTypes(); + // Allow DT_VARIANT types for type parameters. + XlaOpRegistrationBuilder& AllowVariantTypes(); + // Mark 'input_name' as an argument whose value must be known at compile-time. 
XlaOpRegistrationBuilder& CompileTimeConstantInput( absl::string_view input_name); diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 4360e0857964b0ac63fc887e269b04a4b00d854a..636e5ef721f58c009566c10a653d09a7667619c0 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -109,7 +109,7 @@ cc_library( name = "status_macros", srcs = ["status_macros.cc"], hdrs = ["status_macros.h"], - visibility = [":friends"], + visibility = ["//visibility:public"], deps = [ ":statusor", ":types", @@ -152,7 +152,7 @@ cc_library( ":status", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", - "//tensorflow/stream_executor", + "//tensorflow/stream_executor/lib", ], ) @@ -224,6 +224,7 @@ cc_library( name = "shape_util", srcs = [ "index_util.cc", + "layout.cc", "layout_util.cc", "primitive_util.cc", "shape.cc", @@ -231,6 +232,7 @@ cc_library( ], hdrs = [ "index_util.h", + "layout.h", "layout_util.h", "primitive_util.h", "shape.h", @@ -290,6 +292,22 @@ tf_cc_test( ], ) +tf_cc_test( + name = "primitive_util_test", + srcs = ["primitive_util_test.cc"], + deps = [ + ":shape_util", + ":status_macros", + ":test", + ":test_helpers", + ":types", + ":util", + ":xla_data_proto", + "//tensorflow/core:lib", + "//tensorflow/core:test_main", + ], +) + tf_cc_test( name = "layout_util_test", srcs = ["layout_util_test.cc"], @@ -301,6 +319,22 @@ tf_cc_test( ], ) +tf_cc_test( + name = "layout_test", + srcs = ["layout_test.cc"], + deps = [ + ":shape_util", + ":status_macros", + ":test", + ":test_helpers", + ":types", + ":util", + ":xla_data_proto", + "//tensorflow/core:test_main", + "@com_google_absl//absl/strings", + ], +) + tf_cc_test( name = "index_util_test", srcs = ["index_util_test.cc"], @@ -575,6 +609,7 @@ cc_library( ":types", ":util", ":xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "@com_google_absl//absl/memory", @@ -682,6 +717,7 @@ cc_library( ":types", 
":xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", ], @@ -705,8 +741,8 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_evaluator", "//tensorflow/compiler/xla/service:shape_inference", - "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:span", ], @@ -790,6 +826,7 @@ cc_library( "debug_options_parsers.h", ], hdrs = ["debug_options_flags.h"], + visibility = [":friends"], deps = [ ":parse_flags_from_env", diff --git a/tensorflow/compiler/xla/array.h b/tensorflow/compiler/xla/array.h index 58cc1575858201b4508d7340cb47e59c4f4c5783..529e7f77cec43f3158fcb59a53efa9a085d7422a 100644 --- a/tensorflow/compiler/xla/array.h +++ b/tensorflow/compiler/xla/array.h @@ -272,6 +272,15 @@ class Array { std::iota(&values_[0], &values_[0] + num_elements(), value); } + // Fills the array with a repeating sequence: + // [value, value + 1, ..., value + length - 1, value, ... ] + void FillRepeatedIota(const T& value, int64 length) { + for (int64 i = 0; i < num_elements(); i += length) { + std::iota(&values_[i], &values_[std::min(i + length, num_elements())], + value); + } + } + // Fills the array with the sequence i*multiplier for i=0,1,... void FillWithMultiples(const T& multiplier) { for (int64 i = 0; i < num_elements(); ++i) { @@ -280,11 +289,11 @@ class Array { } // Fills the array with random normal variables with the specified mean. 
- void FillRandom(const T& value, const double mean = 0.0, + void FillRandom(const T& stddev, const double mean = 0.0, const int seed = 12345) { std::mt19937 g(seed); std::normal_distribution distribution(mean, - static_cast(value)); + static_cast(stddev)); for (int64 i = 0; i < num_elements(); ++i) { values_[i] = static_cast(distribution(g)); } diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD index fe99564d3c671cd7890e1fa26fcd2e3384972983..f5d56e8a9e1f3a05e1039f7cc90194407200f1ab 100644 --- a/tensorflow/compiler/xla/client/BUILD +++ b/tensorflow/compiler/xla/client/BUILD @@ -3,7 +3,7 @@ licenses(["notice"]) # Apache 2.0 -package(default_visibility = [":friends"]) +package(default_visibility = ["//visibility:public"]) package_group( name = "friends", @@ -170,6 +170,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:optional", ], ) @@ -245,6 +246,7 @@ tf_cc_test( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_matchers", diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc index 74b76f929949d3300a5d0ff45d5fa4cd9f162642..4f020bcec2756a328755d86ab04154d54f532465 100644 --- a/tensorflow/compiler/xla/client/client.cc +++ b/tensorflow/compiler/xla/client/client.cc @@ -186,7 +186,7 @@ StatusOr Client::ComputeConstant(const XlaComputation& computation, ComputeConstantGraphRequest request; *request.mutable_computation() = computation.proto(); if (output_layout != nullptr) { - *request.mutable_output_layout() = *output_layout; + *request.mutable_output_layout() = output_layout->ToProto(); } ComputeConstantResponse response; @@ -278,53 +278,51 @@ StatusOr> 
Client::Execute( const XlaComputation& computation, absl::Span arguments, const ExecutionOptions* execution_options, ExecutionProfile* execution_profile) { - if (execution_options != nullptr && - execution_options->device_handles_size() > 1) { - std::vector computation_instances = { - XlaComputationInstance{ - computation, - std::vector(arguments.begin(), arguments.end()), - *execution_options, execution_profile}}; - TF_ASSIGN_OR_RETURN(auto results, ExecuteParallel(computation_instances)); - // The result selection is a bit hacky, but better than assuming it is - // device 0. - // - // TODO(b/118493728): Allow Execute to return one result per computation. - for (int64 i = 0; i < results.size(); i++) { - TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(*results[i])); - if (!ShapeUtil::IsEmptyTuple(shape)) { - VLOG(3) << "Fetching result from device " << i << ": " - << ShapeUtil::HumanString(shape); - return std::move(results[i]); - } + // Create an ExecutionOptions if necessary, or set its DeviceHandles. + absl::optional options_storage; + if (!execution_options || execution_options->device_handles().empty()) { + if (execution_options) { + options_storage.emplace(*execution_options); + } else { + options_storage.emplace(CreateDefaultExecutionOptions()); } - TF_RET_CHECK(!results.empty()); - VLOG(1) << "Defaulting to device 0 result"; - return std::move(results[0]); - } - - // The argument shapes affect how the computation is compiled. 
- std::vector arg_shapes(arguments.size()); - for (int i = 0; i < arguments.size(); i++) { - TF_ASSIGN_OR_RETURN(arg_shapes[i], GetShape(*arguments[i])); - } - - TF_ASSIGN_OR_RETURN(auto handle, - Compile(computation, arg_shapes, execution_options)); - - TF_ASSIGN_OR_RETURN(auto result, - Execute(handle, arguments, execution_profile)); - - if (execution_profile != nullptr) { - if (VLOG_IS_ON(1)) { - TF_ASSIGN_OR_RETURN( - auto execution_stats, - ExecutionStatsAsString(computation, *execution_profile)); - VLOG(1) << execution_stats; + execution_options = &*options_storage; + + TF_ASSIGN_OR_RETURN(auto device_handles, + GetDeviceHandles(/*device_count=*/1)); + TF_RET_CHECK(!device_handles.empty()); + *options_storage->add_device_handles() = std::move(device_handles[0]); + } + + std::vector computation_instances = { + XlaComputationInstance{ + computation, + std::vector(arguments.begin(), arguments.end()), + *execution_options, execution_profile}}; + + // Instead of invoking Compile() and Execute(), invoke + // Service::ExecuteParallel() to execute our one computation. Compile() + // caches the executable forever, which isn't what we want. + VLOG(1) << "Making ExecuteParallel request: " + << execution_options->DebugString(); + TF_ASSIGN_OR_RETURN(auto results, ExecuteParallel(computation_instances)); + VLOG(1) << "ExecuteParallel request done."; + + // The result selection is a bit hacky, but better than assuming it is + // device 0. + // + // TODO(b/118493728): Allow Execute to return one result per computation. 
+ for (int64 i = 0; i < results.size(); i++) { + TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(*results[i])); + if (!ShapeUtil::IsEmptyTuple(shape)) { + VLOG(3) << "Fetching result from device " << i << ": " + << ShapeUtil::HumanString(shape); + return std::move(results[i]); } } - - return std::move(result); + TF_RET_CHECK(!results.empty()); + VLOG(1) << "Defaulting to device 0 result"; + return std::move(results[0]); } StatusOr>> Client::ExecuteParallel( diff --git a/tensorflow/compiler/xla/client/client.h b/tensorflow/compiler/xla/client/client.h index d0ac4703c632e0e01d3c8911594b46fedf28930d..eff8713ac340e82ee7633f1f078334ba73b67b2f 100644 --- a/tensorflow/compiler/xla/client/client.h +++ b/tensorflow/compiler/xla/client/client.h @@ -52,6 +52,12 @@ class Client { // need to live beyond this call.) // * If execution_options.device_handles should be empty. If you need // non-empty device handles, call 'Execute' instead. + // + // TODO(b/122731460): This call caches the resulting Executable in the Service + // *forever*. If you're only going to run the computation once, you may want + // to call the Execute(const XlaComputation&) overload. If you're going to + // run the computation more than once but you want control over when the + // Executable is unloaded, use the LocalClient API. StatusOr Compile( const XlaComputation& computation, absl::Span argument_shapes, @@ -76,6 +82,10 @@ class Client { // device is chosen by the service. // * If execution_profile is not nullptr then the pointed-to ExecutionProfile // will be filled with profile data from the execution. + // + // TODO(b/122731460): The given computation is compiled and then thrown away + // immediately after it's run. If you want control over how long the + // resulting Executable lives, use the LocalClient API. 
StatusOr> Execute( const XlaComputation& computation, absl::Span arguments, diff --git a/tensorflow/compiler/xla/client/client_library.cc b/tensorflow/compiler/xla/client/client_library.cc index 27b7fa7b29206affa9f9c2e4becd9e4ea66484ab..42aae026229a49fd801cc90562fa51f604336148 100644 --- a/tensorflow/compiler/xla/client/client_library.cc +++ b/tensorflow/compiler/xla/client/client_library.cc @@ -24,12 +24,14 @@ limitations under the License. namespace xla { -LocalClientOptions::LocalClientOptions(se::Platform* platform, - int number_of_replicas, - int intra_op_parallelism_threads) +LocalClientOptions::LocalClientOptions( + se::Platform* platform, int number_of_replicas, + int intra_op_parallelism_threads, + const absl::optional>& allowed_devices) : platform_(platform), number_of_replicas_(number_of_replicas), - intra_op_parallelism_threads_(intra_op_parallelism_threads) {} + intra_op_parallelism_threads_(intra_op_parallelism_threads), + allowed_devices_(allowed_devices) {} LocalClientOptions& LocalClientOptions::set_platform(se::Platform* platform) { platform_ = platform; @@ -58,6 +60,17 @@ int LocalClientOptions::intra_op_parallelism_threads() const { return intra_op_parallelism_threads_; } +LocalClientOptions& LocalClientOptions::set_allowed_devices( + const absl::optional>& allowed_devices) { + allowed_devices_ = allowed_devices; + return *this; +} + +const absl::optional>& LocalClientOptions::allowed_devices() + const { + return allowed_devices_; +} + /* static */ ClientLibrary& ClientLibrary::Singleton() { static ClientLibrary* c = new ClientLibrary; return *c; @@ -67,9 +80,10 @@ ClientLibrary::ClientLibrary() = default; ClientLibrary::~ClientLibrary() = default; /* static */ StatusOr ClientLibrary::GetOrCreateLocalClient( - se::Platform* platform) { + se::Platform* platform, const absl::optional>& device_set) { LocalClientOptions default_options; default_options.set_platform(platform); + default_options.set_allowed_devices(device_set); return 
GetOrCreateLocalClient(default_options); } @@ -94,7 +108,7 @@ ClientLibrary::~ClientLibrary() = default; service_options.set_number_of_replicas(replica_count); service_options.set_intra_op_parallelism_threads( options.intra_op_parallelism_threads()); - + service_options.set_allowed_devices(options.allowed_devices()); auto instance = absl::make_unique(); TF_ASSIGN_OR_RETURN(instance->service, LocalService::NewService(service_options)); diff --git a/tensorflow/compiler/xla/client/client_library.h b/tensorflow/compiler/xla/client/client_library.h index 3ad558fa532931937fab898f7b855f0a3370eaec..62d225c6c298b26bbbd248fc1f4be64fc8efcf6b 100644 --- a/tensorflow/compiler/xla/client/client_library.h +++ b/tensorflow/compiler/xla/client/client_library.h @@ -23,9 +23,11 @@ limitations under the License. #include #include +#include #include #include +#include "absl/types/optional.h" #include "tensorflow/compiler/xla/client/compile_only_client.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/service/compile_only_service.h" @@ -43,9 +45,10 @@ namespace xla { // Options to configure the local client when it is created. class LocalClientOptions { public: - LocalClientOptions(se::Platform* platform = nullptr, - int number_of_replicas = 1, - int intra_op_parallelism_threads = -1); + LocalClientOptions( + se::Platform* platform = nullptr, int number_of_replicas = 1, + int intra_op_parallelism_threads = -1, + const absl::optional>& allowed_devices = absl::nullopt); // Set the platform backing the service, or nullptr for the default platform. LocalClientOptions& set_platform(se::Platform* platform); @@ -60,10 +63,17 @@ class LocalClientOptions { LocalClientOptions& set_intra_op_parallelism_threads(int num_threads); int intra_op_parallelism_threads() const; + // Sets the allowed_devices set for selectively constructing stream executors + // on the platform. 
+ LocalClientOptions& set_allowed_devices( + const absl::optional>& allowed_devices); + const absl::optional>& allowed_devices() const; + private: se::Platform* platform_; int number_of_replicas_; int intra_op_parallelism_threads_; + absl::optional> allowed_devices_; }; class ClientLibrary { @@ -73,8 +83,11 @@ class ClientLibrary { // // platform : The platform the underlying XLA service should target. If // null then default platform is used. + // device_set: Set of device IDs for which the stream executor will be + // created, for the given platform. static StatusOr GetOrCreateLocalClient( - se::Platform* platform = nullptr); + se::Platform* platform = nullptr, + const absl::optional>& allowed_devices = absl::nullopt); static StatusOr GetOrCreateLocalClient( const LocalClientOptions& options); diff --git a/tensorflow/compiler/xla/client/executable_build_options.cc b/tensorflow/compiler/xla/client/executable_build_options.cc index 1f594e551af381d7537e947892cbf7e0b5b3b861..ec0e08975926f36c36c854f83a40b374b12a09a4 100644 --- a/tensorflow/compiler/xla/client/executable_build_options.cc +++ b/tensorflow/compiler/xla/client/executable_build_options.cc @@ -58,6 +58,12 @@ const Shape* ExecutableBuildOptions::result_layout() const { return result_layout_set_ ? 
&result_layout_ : nullptr; } +ExecutableBuildOptions& ExecutableBuildOptions::set_num_replicas( + int num_replicas) { + num_replicas_ = num_replicas; + return *this; +} + string ExecutableBuildOptions::ToString() const { string result_layout = "nullopt"; if (result_layout_set_) { @@ -65,8 +71,9 @@ string ExecutableBuildOptions::ToString() const { } return absl::StrFormat( "ExecutableBuildOptions{device_ordinal=%d, result_layout=%s, " - "generate_hlo_graph=%s}", - device_ordinal_, result_layout, debug_options().xla_generate_hlo_graph()); + "generate_hlo_graph=%s, num_replicas=%d}", + device_ordinal_, result_layout, debug_options().xla_generate_hlo_graph(), + num_replicas_); } } // namespace xla diff --git a/tensorflow/compiler/xla/client/executable_build_options.h b/tensorflow/compiler/xla/client/executable_build_options.h index a58090253bfac7779e4b61bc7231a0f0d945cc00..1d85fb34304b95d1fccdb0b0d6a7a65e739fae18 100644 --- a/tensorflow/compiler/xla/client/executable_build_options.h +++ b/tensorflow/compiler/xla/client/executable_build_options.h @@ -67,12 +67,18 @@ class ExecutableBuildOptions { // debugging. string ToString() const; + // The number of replicas of this computation that are to be executed. + // Defaults to 1. + int num_replicas() const { return num_replicas_; } + ExecutableBuildOptions& set_num_replicas(int num_replicas); + private: int device_ordinal_ = -1; Shape result_layout_; bool result_layout_set_ = false; absl::optional debug_options_; DeviceMemoryAllocator* device_allocator_ = nullptr; + int num_replicas_ = 1; }; } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index c5733bc66deb8d55a9186ad1893abaf17ed6909e..b30ab84240286fe4eb145fc893ba3f3f7ab26d00 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -1,5 +1,7 @@ # Common computation builders for XLA. 
+load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites", "xla_test") + licenses(["notice"]) # Apache 2.0 package(default_visibility = ["//tensorflow/compiler/xla/client:friends"]) @@ -13,9 +15,6 @@ filegroup( ]), ) -load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") -load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites") - # Generate test_suites for all backends, named "${backend}_tests". generate_backend_suites() @@ -35,6 +34,96 @@ cc_library( ], ) +xla_test( + name = "arithmetic_test", + srcs = ["arithmetic_test.cc"], + deps = [ + ":arithmetic", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + +cc_library( + name = "cholesky", + srcs = ["cholesky.cc"], + hdrs = ["cholesky.h"], + deps = [ + ":math", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client/lib:constants", + "//tensorflow/compiler/xla/client/lib:loops", + "//tensorflow/compiler/xla/client/lib:matrix", + "//tensorflow/compiler/xla/client/lib:slicing", + "//tensorflow/compiler/xla/client/lib:triangular_solve", + "//tensorflow/core:lib", + ], +) + +xla_test( + name = "cholesky_test", + srcs = ["cholesky_test.cc"], + tags = ["optonly"], + deps = [ + ":arithmetic", + ":cholesky", + ":matrix", + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + 
"//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + ], +) + +cc_library( + name = "comparators", + srcs = ["comparators.cc"], + hdrs = ["comparators.h"], + deps = [ + ":constants", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + +xla_test( + name = "comparators_test", + srcs = ["comparators_test.cc"], + deps = [ + ":comparators", + ":constants", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "@com_google_absl//absl/container:inlined_vector", + ], +) + cc_library( name = "constants", srcs = ["constants.cc"], @@ -52,7 +141,6 @@ cc_library( xla_test( name = "constants_test", srcs = ["constants_test.cc"], - tags = ["enable_for_xla_interpreter"], deps = [ ":constants", "//tensorflow/compiler/xla:test", @@ -75,11 +163,28 @@ cc_library( ], ) +cc_library( + name = "loops", + srcs = ["loops.cc"], + hdrs = ["loops.h"], + deps = [ + ":constants", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + cc_library( name = "math", srcs = ["math.cc"], hdrs = ["math.h"], deps = [ + 
":arithmetic", ":constants", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -90,7 +195,23 @@ cc_library( xla_test( name = "math_test", srcs = ["math_test.cc"], - tags = ["enable_for_xla_interpreter"], + deps = [ + ":math", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + +xla_test( + name = "math_exhaustive_test", + srcs = ["math_exhaustive_test.cc"], + shard_count = 16, deps = [ ":math", "//tensorflow/compiler/xla:literal_util", @@ -104,31 +225,43 @@ xla_test( ) cc_library( - name = "numeric", - srcs = ["numeric.cc"], - hdrs = ["numeric.h"], + name = "matrix", + srcs = ["matrix.cc"], + hdrs = ["matrix.h"], deps = [ ":arithmetic", ":constants", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", ], ) xla_test( - name = "numeric_test", - srcs = ["numeric_test.cc"], - tags = ["enable_for_xla_interpreter"], + name = "matrix_test", + srcs = ["matrix_test.cc"], deps = [ - ":numeric", + ":matrix", + ":slicing", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", 
"//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "@com_google_absl//absl/strings", ], ) @@ -167,11 +300,77 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/core:lib", "@com_google_absl//absl/base", ], ) +cc_library( + name = "qr", + srcs = ["qr.cc"], + hdrs = ["qr.h"], + deps = [ + ":arithmetic", + ":constants", + ":loops", + ":math", + ":matrix", + ":slicing", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/core:lib", + ], +) + +xla_test( + name = "qr_test", + srcs = ["qr_test.cc"], + tags = ["optonly"], + deps = [ + ":matrix", + ":qr", + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:array3d", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + ], +) + +cc_library( + name = "slicing", + srcs = ["slicing.cc"], + hdrs = ["slicing.h"], + deps = [ + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/client:xla_builder", + "@com_google_absl//absl/types:span", + ], +) + +xla_test( + name = "slicing_test", + srcs = ["slicing_test.cc"], + deps = [ + ":slicing", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/client:xla_builder", + 
"//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + cc_library( name = "sorting", srcs = ["sorting.cc"], @@ -188,13 +387,42 @@ cc_library( xla_test( name = "sorting_test", srcs = ["sorting_test.cc"], - tags = ["enable_for_xla_interpreter"], deps = [ ":sorting", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + +cc_library( + name = "quantize", + hdrs = ["quantize.h"], + deps = [ + ":constants", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/core:lib", + ], +) + +xla_test( + name = "quantize_test", + srcs = ["quantize_test.cc"], + # TODO(b/122119490): re-enable TAP after fixing. + tags = [ + "notap", + ], + deps = [ + ":quantize", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], @@ -221,3 +449,48 @@ cc_library( "@com_google_absl//absl/strings", ], ) + +cc_library( + name = "triangular_solve", + srcs = ["triangular_solve.cc"], + hdrs = ["triangular_solve.h"], + deps = [ + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/client/lib:constants", + "//tensorflow/compiler/xla/client/lib:math", + 
"//tensorflow/compiler/xla/client/lib:matrix", + "//tensorflow/compiler/xla/client/lib:slicing", + "//tensorflow/core:lib", + ], +) + +xla_test( + name = "triangular_solve_test", + srcs = ["triangular_solve_test.cc"], + tags = [ + "enable_for_xla_interpreter", + "noasan", # sometimes times out, http://b/78650012 + ], + deps = [ + ":math", + ":matrix", + ":triangular_solve", + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + ], +) diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.cc b/tensorflow/compiler/xla/client/lib/arithmetic.cc index e86c10f030f3990d67e5a6638100640f73c82307..3b875135af29f142463ffd783bfeaadc61ada1af 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.cc +++ b/tensorflow/compiler/xla/client/lib/arithmetic.cc @@ -117,10 +117,70 @@ XlaOp Any(XlaOp predicates) { XlaComputation logical_or = CreateScalarOrComputation(PRED, builder); TF_ASSIGN_OR_RETURN(const Shape& predicates_shape, builder->GetShape(predicates)); - std::vector all_dimensions(ShapeUtil::Rank(predicates_shape)); + std::vector all_dimensions(predicates_shape.rank()); std::iota(all_dimensions.begin(), all_dimensions.end(), 0); return Reduce(predicates, f, logical_or, all_dimensions); }); } +namespace { + +XlaOp ArgMinMax(XlaOp input, PrimitiveType output_type, int axis, bool is_min) { + XlaBuilder* builder = input.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input)); + XlaOp init_value; + XlaComputation reducer; + if (is_min) { + init_value = MaxValue(builder, input_shape.element_type()); + reducer = 
CreateScalarMinComputation(input_shape.element_type(), builder); + } else { + init_value = MinValue(builder, input_shape.element_type()); + reducer = CreateScalarMaxComputation(input_shape.element_type(), builder); + } + + XlaOp input_max = Reduce(input, init_value, reducer, + /*dimensions_to_reduce=*/{axis}); + std::vector broadcast_dims(input_shape.rank() - 1); + std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0); + std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1); + // Compute a mask that has 1s for elements equal to the maximum. + XlaOp partial_mask = + ConvertElementType(Eq(input, input_max, broadcast_dims), output_type); + + // In order to make identity elements for a bitwise And, we: + // Left shift the 1 to the leftmost bit, yielding 0x10...0 + // Arithmetic right shift the 1 back to the rightmost bit, yielding + // 0xFF...F + int32 bits_in_type = + ShapeUtil::ByteSizeOfPrimitiveType(output_type) * 8 - 1; + XlaOp shift_amount = ConstantR0WithType(builder, output_type, bits_in_type); + XlaOp full_mask = ShiftRightArithmetic( + ShiftLeft(partial_mask, shift_amount), shift_amount); + + // And with the vector [0, 1, 2, ...] to convert each 0xFF...F into its + // index. + + const int64 axis_size = ShapeUtil::GetDimension(input_shape, axis); + XlaOp iota = Iota(builder, output_type, axis_size); + XlaOp product = And(full_mask, iota, /*broadcast_dimensions=*/{axis}); + + // If there are multiple maximum elements, choose the one with the highest + // index. 
+ return Reduce(product, MinValue(builder, output_type), + CreateScalarMaxComputation(output_type, builder), + /*dimensions_to_reduce=*/{axis}); + }); +} + +} // namespace + +XlaOp ArgMax(XlaOp input, PrimitiveType output_type, int axis) { + return ArgMinMax(input, output_type, axis, /*is_min=*/false); +} + +XlaOp ArgMin(XlaOp input, PrimitiveType output_type, int axis) { + return ArgMinMax(input, output_type, axis, /*is_min=*/true); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.h b/tensorflow/compiler/xla/client/lib/arithmetic.h index 632e8cc8bc64fad236a0226c6e93079aadde7050..d4a7812c441c351b121e5d72faf9642b06728b18 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.h +++ b/tensorflow/compiler/xla/client/lib/arithmetic.h @@ -57,6 +57,14 @@ XlaComputation CreateScalarOrComputation(PrimitiveType type, // Note: if predicates is zero-sized, Any() vacuously returns false. XlaOp Any(XlaOp predicates); +// Returns the argmax of `input` along `axis`. `output_type` is the type to +// use for the output. +XlaOp ArgMax(XlaOp input, PrimitiveType output_type, int axis); + +// Returns the argmin of `input` along `axis`. `output_type` is the type to +// use for the output. +XlaOp ArgMin(XlaOp input, PrimitiveType output_type, int axis); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_ARITHMETIC_H_ diff --git a/tensorflow/compiler/xla/client/lib/arithmetic_test.cc b/tensorflow/compiler/xla/client/lib/arithmetic_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..a13839f9db89b9c07f2465867a503ef2193f8160 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/arithmetic_test.cc @@ -0,0 +1,67 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace { + +using ArithmeticTest = ClientLibraryTestBase; + +XLA_TEST_F(ArithmeticTest, ArgMinR2Axis0) { + XlaBuilder builder(TestName()); + auto x = ConstantR2(&builder, {{1, 7, 4}, {6, 3, 5}, {8, 3, 3}}); + ArgMin(x, S32, /*axis=*/0); + + std::vector expected = {0, 2, 2}; + ComputeAndCompareR1(&builder, expected, {}); +} + +XLA_TEST_F(ArithmeticTest, ArgMinR2Axis1) { + XlaBuilder builder(TestName()); + auto x = ConstantR2(&builder, {{1, 7, 4}, {6, 3, 5}, {8, 3, 3}}); + ArgMin(x, S32, /*axis=*/1); + + std::vector expected = {0, 1, 2}; + ComputeAndCompareR1(&builder, expected, {}); +} + +XLA_TEST_F(ArithmeticTest, ArgMaxR2Axis0) { + XlaBuilder builder(TestName()); + auto x = ConstantR2(&builder, {{1, 7, 4}, {6, 3, 5}, {8, 3, 3}}); + ArgMax(x, S32, /*axis=*/0); + + std::vector expected = {2, 0, 1}; + ComputeAndCompareR1(&builder, expected, {}); +} + +XLA_TEST_F(ArithmeticTest, ArgMaxR2Axis1) { + XlaBuilder builder(TestName()); + auto x = ConstantR2(&builder, {{1, 7, 4}, {6, 3, 5}, {8, 3, 3}}); + ArgMax(x, S32, /*axis=*/1); + + std::vector 
expected = {1, 0, 0}; + ComputeAndCompareR1(&builder, expected, {}); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.cc b/tensorflow/compiler/xla/client/lib/cholesky.cc similarity index 55% rename from tensorflow/compiler/tf2xla/lib/cholesky.cc rename to tensorflow/compiler/xla/client/lib/cholesky.cc index ab3d0a566839343828d176d9a46672824e425613..414bd1494cd32f32a5c37e84119de930678a776b 100644 --- a/tensorflow/compiler/tf2xla/lib/cholesky.cc +++ b/tensorflow/compiler/xla/client/lib/cholesky.cc @@ -13,24 +13,26 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/tf2xla/lib/cholesky.h" +#include "tensorflow/compiler/xla/client/lib/cholesky.h" #include #include -#include "tensorflow/compiler/tf2xla/lib/batch_dot.h" -#include "tensorflow/compiler/tf2xla/lib/triangular_solve.h" -#include "tensorflow/compiler/tf2xla/lib/util.h" -#include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/loops.h" +#include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" +#include "tensorflow/compiler/xla/client/lib/slicing.h" +#include "tensorflow/compiler/xla/client/lib/triangular_solve.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/core/errors.h" -namespace tensorflow { +namespace xla { namespace { @@ -49,86 +51,72 @@ namespace { // l[..., j+1:, j] = (a[..., j+1:, j] - np.dot(l[..., j+1:, :j], row_t)) / // l[..., j, j] // return l -xla::XlaOp 
CholeskyUnblocked(xla::XlaOp a, - xla::PrecisionConfig::Precision precision) { - xla::XlaBuilder* builder = a.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); - const int n_dims = xla::ShapeUtil::Rank(a_shape); - const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1); - auto major_dims = xla::AsInt64Slice(a_shape.dimensions()) +XlaOp CholeskyUnblocked(XlaOp a, PrecisionConfig::Precision precision) { + XlaBuilder* builder = a.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); + const int n_dims = a_shape.rank(); + const int64 n = ShapeUtil::GetDimension(a_shape, -1); + auto major_dims = AsInt64Slice(a_shape.dimensions()) .subspan( /*pos=*/0, /*len=*/n_dims - 2); - xla::XlaOp l = xla::ZerosLike(a); + XlaOp l = ZerosLike(a); // Construct the for loop body to iterate over rows. - auto body_fn = [&](xla::XlaOp i, absl::Span loop_vars, - xla::XlaBuilder* body_builder) - -> xla::StatusOr> { - xla::Shape col_shape; - xla::Shape row_shape; - for (int64 d : major_dims) { - row_shape.add_dimensions(d); - col_shape.add_dimensions(d); - } - row_shape.add_dimensions(1); - row_shape.add_dimensions(n); - row_shape.set_element_type(a_shape.element_type()); - auto mask_zeros_row = xla::Zeros(body_builder, row_shape); - - col_shape.add_dimensions(n); - col_shape.add_dimensions(1); - col_shape.set_element_type(a_shape.element_type()); - auto mask_zeros_col = xla::Zeros(body_builder, col_shape); - - std::vector mask_vector(n); - std::iota(mask_vector.begin(), mask_vector.end(), 0); - auto mask_range = xla::ConstantR1(body_builder, mask_vector); + auto body_fn = + [&](XlaOp i, absl::Span loop_vars, + XlaBuilder* body_builder) -> StatusOr> { + std::vector row_shape_dims(major_dims.begin(), major_dims.end()); + std::vector col_shape_dims(major_dims.begin(), major_dims.end()); + row_shape_dims.push_back(1); + 
row_shape_dims.push_back(n); + auto mask_zeros_row = + Zeros(body_builder, + ShapeUtil::MakeShape(a_shape.element_type(), row_shape_dims)); + + col_shape_dims.push_back(n); + col_shape_dims.push_back(1); + auto mask_zeros_col = + Zeros(body_builder, + ShapeUtil::MakeShape(a_shape.element_type(), col_shape_dims)); + auto mask_range_row = - xla::Broadcast(xla::Reshape(mask_range, {0}, {1, n}), major_dims); + Iota(body_builder, ShapeUtil::MakeShape(S32, row_shape_dims), + /*iota_dimension=*/n_dims - 1); auto mask_range_col = - xla::Broadcast(xla::Reshape(mask_range, {0}, {n, 1}), major_dims); + Iota(body_builder, ShapeUtil::MakeShape(S32, col_shape_dims), + /*iota_dimension=*/n_dims - 2); auto body_a = loop_vars[0]; auto body_l = loop_vars[1]; // row = l[..., i, :i] // select the whole i-th row, then mask out all columns past i-1 - auto zero = xla::ConstantR0(body_builder, 0); + auto zero = ConstantR0(body_builder, 0); auto l_i = DynamicSliceInMinorDims(body_l, {i, zero}, {1, n}); - auto row = xla::Select(xla::Ge(mask_range_row, i), mask_zeros_row, l_i); + auto row = Select(Ge(mask_range_row, i), mask_zeros_row, l_i); // a[..., i, i] auto a_ii = DynamicSliceInMinorDims(body_a, {i, i}, {1, 1}); // np.dot(row, np.swapaxes(row, -1, -2)) - auto diag_dot = BatchDot(row, row, - /*transpose_x=*/false, - /*transpose_y=*/true, /*conjugate_x=*/false, - /*conjugate_y=*/false, precision); + auto diag_dot = BatchDot(row, TransposeInMinorDims(row), precision); // l[..., i, i] = np.sqrt(a[..., i, i] - np.dot(row, // np.swapaxes(row, -1, -2))) - auto l_ii = - xla::Pow(a_ii - diag_dot, - FloatLiteral(body_builder, a_shape.element_type(), 0.5)); + auto l_ii = Sqrt(a_ii - diag_dot); // a[..., i+1:, i] // select the whole i-th column, then mask out all rows above i+1 auto a_0i = DynamicSliceInMinorDims(body_a, {i}, {1}); - auto a_ip1i = - xla::Select(xla::Le(mask_range_col, i), mask_zeros_col, a_0i); + auto a_ip1i = Select(Le(mask_range_col, i), mask_zeros_col, a_0i); // l[..., i+1:, i] 
= (a[..., i+1:, i] - np.dot(l[..., i+1:, :i], r.T)) / // l[..., i, i] // The columns in [i, n] are zeroed out in `row`, so we just have to // zero out rows above i+1 after the BatchDot. np.dot(l[..., :, :i], // r.T) - auto dot = BatchDot(body_l, row, - /*transpose_x=*/false, - /*transpose_y=*/true, /*conjugate_x=*/false, - /*conjugate_y=*/false, precision); + auto dot = BatchDot(body_l, TransposeInMinorDims(row), precision); // np.dot(l[..., i+1:, :i], r.T) - auto dot_ip1 = - xla::Select(xla::Le(mask_range_col, i), mask_zeros_col, dot); + auto dot_ip1 = Select(Le(mask_range_col, i), mask_zeros_col, dot); body_l = DynamicUpdateSliceInMinorDims(body_l, (a_ip1i - dot_ip1) / l_ii, {i}); @@ -136,12 +124,12 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a, // column assign will wrap around and overwrite the diagonal assign. body_l = DynamicUpdateSliceInMinorDims(body_l, l_ii, {i, i}); - return std::vector{body_a, body_l}; + return std::vector{body_a, body_l}; }; TF_ASSIGN_OR_RETURN( auto cholesky_while, - XlaForEachIndex(n, xla::S32, body_fn, {a, l}, "unblocked", builder)); + ForEachIndex(n, S32, body_fn, {a, l}, "unblocked", builder)); return cholesky_while[1]; }); @@ -149,34 +137,41 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a, } // namespace -xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size, - xla::PrecisionConfig::Precision precision) { - xla::XlaBuilder* builder = a.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); - const int ndims = xla::ShapeUtil::Rank(a_shape); +XlaOp Cholesky(XlaOp a, int64 block_size, + PrecisionConfig::Precision precision) { + XlaBuilder* builder = a.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); + const int ndims = a_shape.rank(); if (ndims < 2) { - return errors::InvalidArgument( - "Arguments to Cholesky must have rank >= 2: ", ndims); + return InvalidArgument( + "Argument 
to Cholesky must have rank >= 2; shape was %s", + a_shape.ToString()); + } + + const int64 n = ShapeUtil::GetDimension(a_shape, -1); + if (n != ShapeUtil::GetDimension(a_shape, -2)) { + return InvalidArgument( + "Argument to Cholesky must be batched square matrices; got shape %s", + ShapeUtil::HumanString(a_shape)); } - const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1); - if (n != xla::ShapeUtil::GetDimension(a_shape, -2)) { - return errors::InvalidArgument( - "Arguments to Cholesky must be square matrices: ", - xla::ShapeUtil::HumanString(a_shape)); + if (primitive_util::IsComplexType(a_shape.element_type())) { + return Unimplemented( + "Complex types are not implemented in Cholesky; got shape %s", + ShapeUtil::HumanString(a_shape)); } if (block_size < 1) { - return errors::InvalidArgument( - "block_size argument to Cholesky must be >= 1; got ", block_size); + return InvalidArgument( + "block_size argument to Cholesky must be >= 1; got %d", block_size); } // Blocked left-looking Cholesky factorization. // Algorithm 1 from // Haidar, Azzam, et al. "High-performance Cholesky factorization for // GPU-only execution." Proceedings of General Purpose GPUs. ACM, 2017. 
- xla::XlaOp l = xla::ZerosLike(a); + XlaOp l = ZerosLike(a); for (int64 i = 0; i < n; i += block_size) { int64 k = std::min(block_size, n - i); if (i > 0) { @@ -185,9 +180,7 @@ xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size, // a[i:, i:i+k] -= np.dot(l[i:, :i], np.transpose(l[i:i+k, :i])) auto lhs = SliceInMinorDims(l, {i, 0}, {n, i}); auto rhs = SliceInMinorDims(l, {i, 0}, {i + k, i}); - auto delta = BatchDot(lhs, rhs, /*transpose_x=*/false, - /*transpose_y=*/true, /*conjugate_x=*/false, - /*conjugate_y=*/false, precision); + auto delta = BatchDot(lhs, TransposeInMinorDims(rhs), precision); auto before = SliceInMinorDims(a, {i, i}, {n, i + k}); a = UpdateSliceInMinorDims(a, before - delta, {i, i}); } @@ -214,4 +207,4 @@ xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size, }); } -} // namespace tensorflow +} // namespace xla diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.h b/tensorflow/compiler/xla/client/lib/cholesky.h similarity index 87% rename from tensorflow/compiler/tf2xla/lib/cholesky.h rename to tensorflow/compiler/xla/client/lib/cholesky.h index 9a561c34b92ee45059f2a05336e682838f8e36e2..0bae26837c0f14dd0cfab82cf426becc787ec11c 100644 --- a/tensorflow/compiler/tf2xla/lib/cholesky.h +++ b/tensorflow/compiler/xla/client/lib/cholesky.h @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_ -#define TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_ +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CHOLESKY_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CHOLESKY_H_ #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -namespace tensorflow { +namespace xla { // Computes the Cholesky decompositions of a batch of symmetric positive // definite matrices. 
@@ -34,6 +34,6 @@ xla::XlaOp Cholesky( xla::XlaOp a, int64 block_size = 256, xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::HIGHEST); -} // namespace tensorflow +} // namespace xla -#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_ +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CHOLESKY_H_ diff --git a/tensorflow/compiler/xla/client/lib/cholesky_test.cc b/tensorflow/compiler/xla/client/lib/cholesky_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..095dd4fbf8b7c90047c4428b50c626c16e9c1e94 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/cholesky_test.cc @@ -0,0 +1,166 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/cholesky.h" + +#include +#include +#include + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace { + +using xla::int64; + +using CholeskyTest = xla::ClientLibraryTestBase; + +XLA_TEST_F(CholeskyTest, Simple) { + xla::XlaBuilder builder(TestName()); + + xla::Array2D a_vals({ + {4, 6, 8, 10}, + {6, 45, 54, 63}, + {8, 54, 146, 166}, + {10, 63, 166, 310}, + }); + + xla::XlaOp a; + auto a_data = CreateR2Parameter(a_vals, 0, "a", &builder, &a); + xla::Cholesky(a, /*block_size=*/2); + + xla::Array2D expected({ + {2, 0, 0, 0}, + {3, 6, 0, 0}, + {4, 7, 9, 0}, + {5, 8, 10, 11}, + }); + + ComputeAndCompareR2(&builder, expected, {a_data.get()}, + xla::ErrorSpec(1e-4, 1e-4)); +} + +XLA_TEST_F(CholeskyTest, Simple2) { + xla::XlaBuilder builder(TestName()); + + xla::Array2D a_vals({ + {16, 24, 8, 12}, + {24, 61, 82, 48}, + {8, 82, 456, 106}, + {12, 48, 106, 62}, + }); + + xla::XlaOp a; + auto a_data = CreateR2Parameter(a_vals, 0, "a", &builder, &a); + xla::Cholesky(a); + + xla::Array2D expected( + {{4, 0, 0, 0}, {6, 5, 0, 0}, {2, 14, 16, 0}, {3, 6, 1, 4}}); + + ComputeAndCompareR2(&builder, expected, {a_data.get()}, + xla::ErrorSpec(1e-4, 1e-4)); +} + +XLA_TEST_F(CholeskyTest, SimpleBatched) { + xla::XlaBuilder builder(TestName()); + + xla::Array3D a_vals({ + 
{ + {4, 6, 8, 10}, + {6, 45, 54, 63}, + {8, 54, 146, 166}, + {10, 63, 166, 310}, + }, + { + {16, 24, 8, 12}, + {24, 61, 82, 48}, + {8, 82, 456, 106}, + {12, 48, 106, 62}, + }, + }); + + xla::XlaOp a; + auto a_data = CreateR3Parameter(a_vals, 0, "a", &builder, &a); + xla::Cholesky(a); + + xla::Array3D expected({ + { + {2, 0, 0, 0}, + {3, 6, 0, 0}, + {4, 7, 9, 0}, + {5, 8, 10, 11}, + }, + {{4, 0, 0, 0}, {6, 5, 0, 0}, {2, 14, 16, 0}, {3, 6, 1, 4}}, + }); + + ComputeAndCompareR3(&builder, expected, {a_data.get()}, + xla::ErrorSpec(1e-4, 1e-4)); +} + +using CholeskyTestCase = std::tuple; + +class RandomCholeskyTest + : public xla::ClientLibraryTestBase, + public ::testing::WithParamInterface {}; + +XLA_TEST_P(RandomCholeskyTest, Random) { + xla::XlaBuilder builder(TestName()); + + auto test_params = GetParam(); + std::vector dimensions = {std::get<0>(test_params), + std::get<1>(test_params), + std::get<1>(test_params)}; + xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, dimensions); + TF_ASSERT_OK_AND_ASSIGN( + auto literal, + xla::LiteralUtil::CreateRandomLiteral(shape, 0.0, 1.0)); + + auto input = xla::Parameter(&builder, 0, shape, "input"); + // Form a random positive definite matrix. 
+ auto matrix = xla::BatchDot(input, TransposeInMinorDims(input), + xla::PrecisionConfig::HIGHEST); + + auto cholesky = xla::Cholesky(matrix, /*block_size=*/4); + + // Verify that ||matrix - cholesky * cholesky_t||_2 ~= 0 + auto verification = xla::BatchDot(cholesky, TransposeInMinorDims(cholesky), + xla::PrecisionConfig::HIGHEST); + auto delta = matrix - verification; + xla::Reduce(delta * delta, xla::ConstantR0(&builder, 0.0), + CreateScalarAddComputation(xla::F32, &builder), {0, 1, 2}); + + TF_ASSERT_OK_AND_ASSIGN(auto input_data, client_->TransferToServer(literal)); + ComputeAndCompareR0(&builder, 0.0, {input_data.get()}, + xla::ErrorSpec(1e-4, 1e-4)); +} + +INSTANTIATE_TEST_SUITE_P(RandomCholeskyTestInstance, RandomCholeskyTest, + ::testing::Values(CholeskyTestCase{1, 1}, + CholeskyTestCase{1, 2}, + CholeskyTestCase{10, 5}, + CholeskyTestCase{2, 20})); + +} // namespace diff --git a/tensorflow/compiler/xla/client/lib/comparators.cc b/tensorflow/compiler/xla/client/lib/comparators.cc new file mode 100644 index 0000000000000000000000000000000000000000..c620c9841a5146618e3a142adeb3fe2da525950a --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/comparators.cc @@ -0,0 +1,159 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/comparators.h" + +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace { + +using XlaOpGenerator = XlaOp (*)(const XlaOp&, const XlaOp&, + absl::Span); + +XlaOp BitcastConvertFloatingPointToIntegral(const XlaOp& value, + int64 bit_width) { + PrimitiveType signed_type; + PrimitiveType unsigned_type; + XlaOp max_value; + switch (bit_width) { + case 16: + max_value = + ConstantR0(value.builder(), + static_cast(std::numeric_limits::max())); + signed_type = S16; + unsigned_type = U16; + break; + case 32: + max_value = + ConstantR0(value.builder(), + static_cast(std::numeric_limits::max())); + signed_type = S32; + unsigned_type = U32; + break; + case 64: + max_value = + ConstantR0(value.builder(), + static_cast(std::numeric_limits::max())); + signed_type = S64; + unsigned_type = U64; + break; + default: + return value.builder()->ReportError( + InvalidArgument("Invalid bit width %lld for Comparator floating " + "point parameter.", + bit_width)); + } + // Switch from a floating point value to a integer value in such a way that + // when using the integer value to compare, we get the same result for normal + // values, and -Nan is treated as the smallest value, and Nan is treated as + // the largest value. + // If f is a float, and + // x = bit_cast(f); + // y = x < 0 ? 
numeric_limits::max() - x : x; + // then y is ordered as an int32 such that finite values have the obvious + // order, -0 is ordered before 0, and -NaN and NaN appear at the beginning + // and end of the ordering. + // Note that in order to avoid -x to overflow, we calculate + // numeric_limits::max() - x as unsigned, and then convert back to + // signed. + auto signed_value = BitcastConvertType(value, signed_type); + auto unsigned_value = BitcastConvertType(value, unsigned_type); + auto flipped_value = + BitcastConvertType(Sub(max_value, unsigned_value), signed_type); + auto is_negative = Lt(signed_value, Zero(value.builder(), signed_type)); + return Select(is_negative, flipped_value, signed_value); +} + +XlaComputation CreateScalarComparisonComputation( + const string& name, const std::vector& operand_types, + XlaBuilder* builder, XlaOpGenerator generator) { + // Create a default computation where we compare only the first two + // parameters of type 'operand_types[0]'. + auto b = builder->CreateSubBuilder(name); + if (operand_types.empty()) { + b->ReportError(InvalidArgument("operand_types should not be empty")); + return b->BuildAndNoteError(); + } + + int64 parameter_count = 0; + XlaOp first_lhs_param; + XlaOp first_rhs_param; + + // For each type in 'operand_types' we create two parameters of this type. The + // idea is that this computation can be used by n-ary Sort, and potentially + // should support comparing also the other operands of sort. In this default + // computation, however, we will not actually use any parameters except the + // first two. 
+ for (auto operand_type : operand_types) { + auto scalar_shape = ShapeUtil::MakeShape(operand_type, {}); + auto lhs_param = Parameter(b.get(), parameter_count * 2, scalar_shape, + absl::StrCat("p.", parameter_count, ".lhs")); + auto rhs_param = Parameter(b.get(), parameter_count * 2 + 1, scalar_shape, + absl::StrCat("p.", parameter_count, ".rhs")); + if (parameter_count == 0) { + first_lhs_param = lhs_param; + first_rhs_param = rhs_param; + } + ++parameter_count; + } + if (primitive_util::IsFloatingPointType(operand_types[0])) { + PrimitiveType compare_type = operand_types[0]; + // Special-case handling for BF16. We currently do not support direct + // comparisons with BF16, so we convert to F32 and then use the F32 + // comparison logic. + if (compare_type == BF16) { + compare_type = F32; + first_lhs_param = ConvertElementType(first_lhs_param, F32); + first_rhs_param = ConvertElementType(first_rhs_param, F32); + } + int64 bit_width = primitive_util::BitWidth(compare_type); + first_lhs_param = + BitcastConvertFloatingPointToIntegral(first_lhs_param, bit_width); + first_rhs_param = + BitcastConvertFloatingPointToIntegral(first_rhs_param, bit_width); + } + generator(first_lhs_param, first_rhs_param, {}); + return b->BuildAndNoteError(); +} +} // namespace + +// Creates a scalar less-than computation and returns it. +XlaComputation CreateScalarLtComputation( + const std::vector& operand_types, XlaBuilder* builder) { + return CreateScalarComparisonComputation("compare-less-than", operand_types, + builder, Lt); +} + +// Creates a scalar greater-than computation and returns it. 
+XlaComputation CreateScalarGtComputation( + const std::vector& operand_types, XlaBuilder* builder) { + return CreateScalarComparisonComputation("compare-greater-than", + operand_types, builder, Gt); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/comparators.h b/tensorflow/compiler/xla/client/lib/comparators.h new file mode 100644 index 0000000000000000000000000000000000000000..cbcfc227dd495537f59bf0a9090bad8ade15da62 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/comparators.h @@ -0,0 +1,47 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_COMPARATORS_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_COMPARATORS_H_ + +#include + +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { + +// Creates a scalar less-than computation and returns it. The created +// computation has 2 * 'operand_types.size()' many parameters, where parameters +// 2 * i and 2 * i + 1 are a scalar with primitive type 'operand_types[i]'. The +// computation compares the first two parameters. For floating point types, a +// total order is created where +// -NaN < -infinity < ... < -0 < 0 < ... 
< infinity < NaN +XlaComputation CreateScalarLtComputation( + const std::vector& operand_types, XlaBuilder* builder); + +// Creates a scalar greater-than computation and returns it. The created +// computation has 2 * 'operand_types.size()' many parameters, where parameters +// 2 * i and 2 * i + 1 are a scalar with primitive type 'operand_types[i]'. The +// computation compares the first two parameters. For floating point types, a +// total order is created where +// NaN > infinity > ... > 0 > -0 > ... > -infinity > -NaN +XlaComputation CreateScalarGtComputation( + const std::vector& operand_types, XlaBuilder* builder); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_COMPARATORS_H_ diff --git a/tensorflow/compiler/xla/client/lib/comparators_test.cc b/tensorflow/compiler/xla/client/lib/comparators_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..598956803b34702b1e095a342648d348fa350b29 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/comparators_test.cc @@ -0,0 +1,149 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/comparators.h" + +#include +#include + +#include "absl/container/inlined_vector.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace { + +class ComparatorsTest : public ClientLibraryTestBase { + public: + ComparatorsTest() : builder_(TestName()) {} + XlaBuilder* builder() { return &builder_; } + + private: + XlaBuilder builder_; +}; + +template < + PrimitiveType type, + typename T = typename primitive_util::PrimitiveTypeToNative::type> +void BuildComparatorAndComparisons(ComparatorsTest* test, + bool compare_less_than, + absl::InlinedVector* expected) { + auto compare = compare_less_than + ? CreateScalarLtComputation({type}, test->builder()) + : CreateScalarGtComputation({type}, test->builder()); + + auto negative_nan = ConstantR0( + test->builder(), -T(std::numeric_limits::quiet_NaN())); + auto positive_nan = ConstantR0(test->builder(), + T(std::numeric_limits::quiet_NaN())); + auto negative_zero = ConstantR0(test->builder(), T(-0.)); + auto positive_zero = ConstantR0(test->builder(), T(0.)); + auto negative_infinity = MinValue(test->builder(), type); + auto positive_infinity = MaxValue(test->builder(), type); + + // List the values in the expected sorting order from smallest to largest. + std::vector all_constants{negative_nan, negative_infinity, + negative_zero, positive_zero, + positive_infinity, positive_nan}; + + // Do pairwise comparisons. 
+ std::vector all_comparisons; + for (const XlaOp& lhs_constant : all_constants) { + for (const XlaOp& rhs_constant : all_constants) { + all_comparisons.push_back(Broadcast( + Call(test->builder(), compare, {lhs_constant, rhs_constant}), {1})); + } + } + + // Concantenate the comparison results. + ConcatInDim(test->builder(), all_comparisons, 0); + + // If we use less-than comparisons, we expect the comparison to result in true + // if the lhs value to be compared appears earlier in 'all_constants' than the + // rhs value. Likewise, if we use greater-than comparisons, we expect the + // comparison to return true if the rhs value appears earlier in + // 'all_constants' than the lhs value. + expected->clear(); + for (int i = 0; i < all_constants.size(); ++i) { + for (int j = 0; j < all_constants.size(); ++j) { + expected->push_back(compare_less_than ? i < j : i > j); + } + } +} + +XLA_TEST_F(ComparatorsTest, CompareLtBF16) { + absl::InlinedVector expected; + BuildComparatorAndComparisons(this, /*compare_less_than=*/true, + &expected); + ComputeAndCompareR1(builder(), expected, {}); +} + +XLA_TEST_F(ComparatorsTest, CompareGtBF16) { + absl::InlinedVector expected; + BuildComparatorAndComparisons(this, /*compare_less_than=*/false, + &expected); + ComputeAndCompareR1(builder(), expected, {}); +} + +XLA_TEST_F(ComparatorsTest, CompareLtF16) { + absl::InlinedVector expected; + BuildComparatorAndComparisons(this, /*compare_less_than=*/true, + &expected); + ComputeAndCompareR1(builder(), expected, {}); +} + +XLA_TEST_F(ComparatorsTest, CompareGtF16) { + absl::InlinedVector expected; + BuildComparatorAndComparisons(this, /*compare_less_than=*/false, + &expected); + ComputeAndCompareR1(builder(), expected, {}); +} + +XLA_TEST_F(ComparatorsTest, CompareLtF32) { + absl::InlinedVector expected; + BuildComparatorAndComparisons(this, /*compare_less_than=*/true, + &expected); + ComputeAndCompareR1(builder(), expected, {}); +} + +XLA_TEST_F(ComparatorsTest, CompareGtF32) { + 
absl::InlinedVector expected; + BuildComparatorAndComparisons(this, /*compare_less_than=*/false, + &expected); + ComputeAndCompareR1(builder(), expected, {}); +} + +XLA_TEST_F(ComparatorsTest, CompareLtF64) { + absl::InlinedVector expected; + BuildComparatorAndComparisons(this, /*compare_less_than=*/true, + &expected); + ComputeAndCompareR1(builder(), expected, {}); +} + +XLA_TEST_F(ComparatorsTest, CompareGtF64) { + absl::InlinedVector expected; + BuildComparatorAndComparisons(this, /*compare_less_than=*/false, + &expected); + ComputeAndCompareR1(builder(), expected, {}); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/constants.h b/tensorflow/compiler/xla/client/lib/constants.h index 81624614c1e3599dfe116eb61d9e2edcd5230684..4e5310a380e8bda15348dae2cbb0ea9e2c381bcb 100644 --- a/tensorflow/compiler/xla/client/lib/constants.h +++ b/tensorflow/compiler/xla/client/lib/constants.h @@ -56,6 +56,8 @@ XlaOp ConstantR0WithType(XlaBuilder* builder, PrimitiveType type, T value) { return ConstantR0(builder, static_cast(value)); case C64: return ConstantR0(builder, static_cast(value)); + case C128: + return ConstantR0(builder, static_cast(value)); case U8: return ConstantR0(builder, static_cast(value)); case U32: @@ -88,6 +90,27 @@ XlaOp ScalarLike(XlaOp prototype, T value) { }); } +// Returns an array or scalar containing copies of `value` cast to the same +// run-type type as `prototype` and broadcast to the same dimensions as +// `prototype`. +// +// If `prototype` is not a scalar or array, returns an error. 
+template +XlaOp FullLike(XlaOp prototype, T value) { + XlaBuilder* builder = prototype.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(prototype)); + if (ShapeUtil::IsScalar(shape) || shape.IsArray()) { + return Broadcast(ScalarLike(prototype, value), shape.dimensions()); + } else { + return InvalidArgument( + "Prototype shape for BroadcastConstantLike must be a scalar or " + "array, but was %s", + shape.ToString()); + } + }); +} + // Returns a scalar with value '0' of 'type'. XlaOp Zero(XlaBuilder* builder, PrimitiveType type); diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.cc b/tensorflow/compiler/xla/client/lib/loops.cc similarity index 50% rename from tensorflow/compiler/tf2xla/lib/while_loop.cc rename to tensorflow/compiler/xla/client/lib/loops.cc index 594ab1dfd0700f47501712183f6efe62d17e15e7..721f987628a8ac7da3f3f872939c3f0457d6bbe2 100644 --- a/tensorflow/compiler/tf2xla/lib/while_loop.cc +++ b/tensorflow/compiler/xla/client/lib/loops.cc @@ -13,44 +13,43 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/compiler/tf2xla/lib/while_loop.h" -#include "tensorflow/compiler/tf2xla/lib/util.h" +#include "tensorflow/compiler/xla/client/lib/loops.h" + +#include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" -namespace tensorflow { +namespace xla { -xla::StatusOr> XlaWhileLoop( - const LoopConditionFunction& condition_function, - const LoopBodyFunction& body_function, - absl::Span initial_values, absl::string_view name, - xla::XlaBuilder* builder) { +StatusOr> WhileLoopHelper( + const WhileLoopHelperConditionFunction& condition_function, + const WhileLoopHelperBodyFunction& body_function, + absl::Span initial_values, absl::string_view name, + XlaBuilder* builder) { int arity = initial_values.size(); - std::vector var_shapes; + std::vector var_shapes; var_shapes.reserve(arity); - for (const xla::XlaOp& input : initial_values) { + for (const XlaOp& input : initial_values) { TF_ASSIGN_OR_RETURN(auto shape, builder->GetShape(input)); var_shapes.push_back(std::move(shape)); } - xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(var_shapes); + Shape tuple_shape = ShapeUtil::MakeTupleShape(var_shapes); // Unpacks a tuple into its component parts. - auto unpack_tuple = [](xla::XlaOp tuple, int arity, - xla::XlaBuilder* builder) { - std::vector elements(arity); + auto unpack_tuple = [](XlaOp tuple, int arity, XlaBuilder* builder) { + std::vector elements(arity); for (int i = 0; i < arity; ++i) { - elements[i] = xla::GetTupleElement(tuple, i); + elements[i] = GetTupleElement(tuple, i); } return elements; }; // Build the condition. 
- std::unique_ptr cond_builder = + std::unique_ptr cond_builder = builder->CreateSubBuilder(absl::StrCat(name, "_condition")); { - auto parameter = - xla::Parameter(cond_builder.get(), 0, tuple_shape, "parameter"); + auto parameter = Parameter(cond_builder.get(), 0, tuple_shape, "parameter"); TF_RETURN_IF_ERROR( condition_function(unpack_tuple(parameter, arity, cond_builder.get()), @@ -60,11 +59,10 @@ xla::StatusOr> XlaWhileLoop( TF_ASSIGN_OR_RETURN(auto cond, cond_builder->Build()); // Build the body. - std::unique_ptr body_builder = + std::unique_ptr body_builder = builder->CreateSubBuilder(absl::StrCat(name, "_body")); { - auto parameter = - xla::Parameter(body_builder.get(), 0, tuple_shape, "parameter"); + auto parameter = Parameter(body_builder.get(), 0, tuple_shape, "parameter"); TF_ASSIGN_OR_RETURN( auto result, @@ -72,56 +70,54 @@ xla::StatusOr> XlaWhileLoop( body_builder.get())); TF_RET_CHECK(result.size() == initial_values.size()); - xla::Tuple(body_builder.get(), result); + Tuple(body_builder.get(), result); } TF_ASSIGN_OR_RETURN(auto body, body_builder->Build()); - auto outputs = xla::While(cond, body, xla::Tuple(builder, initial_values)); + auto outputs = While(cond, body, Tuple(builder, initial_values)); return unpack_tuple(outputs, arity, builder); } -xla::StatusOr> XlaForEachIndex( - int64 num_iterations, xla::PrimitiveType num_iterations_type, +StatusOr> ForEachIndex( + int64 num_iterations, PrimitiveType num_iterations_type, const ForEachIndexBodyFunction& body_function, - absl::Span initial_values, absl::string_view name, - xla::XlaBuilder* builder) { - auto while_cond_fn = - [&](absl::Span values, - xla::XlaBuilder* cond_builder) -> xla::StatusOr { - return xla::Lt(values[0], IntegerLiteral(cond_builder, num_iterations_type, - num_iterations)); + absl::Span initial_values, absl::string_view name, + XlaBuilder* builder) { + auto while_cond_fn = [&](absl::Span values, + XlaBuilder* cond_builder) -> StatusOr { + return Lt(values[0], 
ConstantR0WithType(cond_builder, num_iterations_type, + num_iterations)); }; - auto while_body_fn = [&](absl::Span values, - xla::XlaBuilder* body_builder) - -> xla::StatusOr> { - xla::XlaOp iteration = values[0]; + auto while_body_fn = + [&](absl::Span values, + XlaBuilder* body_builder) -> StatusOr> { + XlaOp iteration = values[0]; - std::vector updated_values; + std::vector updated_values; updated_values.reserve(values.size()); - updated_values.push_back(xla::Add( + updated_values.push_back(Add( iteration, - xla::ConstantLiteral(body_builder, - xla::LiteralUtil::One(num_iterations_type)))); + ConstantLiteral(body_builder, LiteralUtil::One(num_iterations_type)))); values.remove_prefix(1); - TF_ASSIGN_OR_RETURN(std::vector body_outputs, + TF_ASSIGN_OR_RETURN(std::vector body_outputs, body_function(iteration, values, body_builder)); updated_values.insert(updated_values.end(), body_outputs.begin(), body_outputs.end()); return updated_values; }; - std::vector values; + std::vector values; values.reserve(initial_values.size() + 1); - values.push_back(xla::ConstantLiteral( - builder, xla::LiteralUtil::Zero(num_iterations_type))); + values.push_back( + ConstantLiteral(builder, LiteralUtil::Zero(num_iterations_type))); values.insert(values.end(), initial_values.begin(), initial_values.end()); - TF_ASSIGN_OR_RETURN(values, XlaWhileLoop(while_cond_fn, while_body_fn, values, - name, builder)); + TF_ASSIGN_OR_RETURN(values, WhileLoopHelper(while_cond_fn, while_body_fn, + values, name, builder)); values.erase(values.begin(), values.begin() + 1); return values; } -} // namespace tensorflow +} // namespace xla diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.h b/tensorflow/compiler/xla/client/lib/loops.h similarity index 62% rename from tensorflow/compiler/tf2xla/lib/while_loop.h rename to tensorflow/compiler/xla/client/lib/loops.h index f2134bb4495a12b8342961d96f70e7737f816c7d..e11de59493e9c1de51fbdb6c45dab6d82b85a62a 100644 --- 
a/tensorflow/compiler/tf2xla/lib/while_loop.h +++ b/tensorflow/compiler/xla/client/lib/loops.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_WHILE_LOOP_H_ -#define TENSORFLOW_COMPILER_TF2XLA_LIB_WHILE_LOOP_H_ +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_LOOPS_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_LOOPS_H_ #include #include @@ -25,19 +25,18 @@ limitations under the License. #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/statusor.h" -namespace tensorflow { +namespace xla { // Function that builds a loop condition. Takes as input a sequence of input // values, and returns a boolean value representing if the condition succeeds. -typedef std::function(absl::Span, - xla::XlaBuilder*)> - LoopConditionFunction; +typedef std::function(absl::Span, XlaBuilder*)> + WhileLoopHelperConditionFunction; // Function that builds a loop body. Takes as input a sequence of input values // and returns a sequence of output values. -typedef std::function>( - absl::Span, xla::XlaBuilder*)> - LoopBodyFunction; +typedef std::function>(absl::Span, + XlaBuilder*)> + WhileLoopHelperBodyFunction; // Helper function for building an XLA while loop, where the values carried by // the loop are a tuple of values, e.g., (a, b, c): @@ -47,27 +46,27 @@ typedef std::function>( // init: (a, b, c) // ) // 'name' is a descriptive name for the loop. 
-xla::StatusOr> XlaWhileLoop( - const LoopConditionFunction& condition_function, - const LoopBodyFunction& body_function, - absl::Span initial_values, absl::string_view name, - xla::XlaBuilder* builder); +StatusOr> WhileLoopHelper( + const WhileLoopHelperConditionFunction& condition_function, + const WhileLoopHelperBodyFunction& body_function, + absl::Span initial_values, absl::string_view name, + XlaBuilder* builder); // Builds an XLA loop that repeats a computation `num_iterations` times. // // The body function (ForEachIndexBodyFunction) takes as input a pair of // (current iteration number, loop-carried values), and returns an updated // vector of the loop-carried values. -typedef std::function>( - xla::XlaOp, absl::Span, xla::XlaBuilder*)> +typedef std::function>( + XlaOp, absl::Span, XlaBuilder*)> ForEachIndexBodyFunction; -xla::StatusOr> XlaForEachIndex( - int64 num_iterations, xla::PrimitiveType num_iterations_type, +StatusOr> ForEachIndex( + int64 num_iterations, PrimitiveType num_iterations_type, const ForEachIndexBodyFunction& body_function, - absl::Span initial_values, absl::string_view name, - xla::XlaBuilder* builder); + absl::Span initial_values, absl::string_view name, + XlaBuilder* builder); -} // namespace tensorflow +} // namespace xla -#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_WHILE_LOOP_H_ +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_LOOPS_H_ diff --git a/tensorflow/compiler/xla/client/lib/math.cc b/tensorflow/compiler/xla/client/lib/math.cc index 08a887a6e4660cb2528f0ec7244b7ccc540808d2..14891206855725f1ba71bda9f92134d7c7eb9217 100644 --- a/tensorflow/compiler/xla/client/lib/math.cc +++ b/tensorflow/compiler/xla/client/lib/math.cc @@ -13,14 +13,72 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +// This macro is required to make MSVC defines math constants in math.h +#define _USE_MATH_DEFINES +#include + #include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" namespace xla { +// TODO(jlebar): Use this function in more places in this file to restrict the +// domain of other functions. +static Status EnsureOperandIsRealFp(absl::string_view op_name, XlaOp operand) { + auto& b = *operand.builder(); + TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(operand)); + auto elem_ty = shape.element_type(); + if (!primitive_util::IsFloatingPointType(elem_ty)) { + return InvalidArgument( + "Operands to %s must be real-valued floating-point, but got %s", + op_name, PrimitiveType_Name(elem_ty)); + } + return Status::OK(); +} + +XlaOp IsPosInf(XlaOp operand) { + auto& b = *operand.builder(); + return b.ReportErrorOrReturn([&]() -> StatusOr { + TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IsPosInf", operand)); + TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(operand)); + // Note that this is only correct for floating-point types. If we wanted it + // to be correct for all types, we'd need to Gt(MaxFiniteValue). + return Eq(operand, MaxValue(&b, shape.element_type())); + }); +} + +XlaOp IsNegInf(XlaOp operand) { + auto& b = *operand.builder(); + return b.ReportErrorOrReturn([&]() -> StatusOr { + TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IsNegInf", operand)); + TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(operand)); + // Note that this is only correct for floating-point types. If we wanted it + // to be correct for all types, we'd need to Lt(MinFiniteValue). 
+ return Eq(operand, MinValue(&b, shape.element_type())); + }); +} + +XlaOp IsInf(XlaOp operand) { + auto& b = *operand.builder(); + return b.ReportErrorOrReturn([&]() -> StatusOr { + TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IsInf", operand)); + return IsPosInf(Abs(operand)); + }); +} + +XlaOp IsNan(XlaOp operand) { + auto& b = *operand.builder(); + return b.ReportErrorOrReturn([&]() -> StatusOr { + TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IsNan", operand)); + return Ne(operand, operand); + }); +} + XlaOp Sqrt(XlaOp operand) { return Pow(operand, ScalarLike(operand, 0.5)); } XlaOp Rsqrt(XlaOp operand) { return Pow(operand, ScalarLike(operand, -0.5)); } @@ -29,44 +87,6 @@ XlaOp Square(XlaOp operand) { return operand * operand; } XlaOp Reciprocal(XlaOp operand) { return ScalarLike(operand, 1.0) / operand; } -namespace { - -// Polynomials for computing erf/erfc. Originally from cephes. -// Note we use float for compatibility across devices, at the cost of some -// precision for 64 bit computations. -// -// Coefficients are in descending order. 
-std::array kErfcPCoefficient = { - 2.46196981473530512524E-10, 5.64189564831068821977E-1, - 7.46321056442269912687E0, 4.86371970985681366614E1, - 1.96520832956077098242E2, 5.26445194995477358631E2, - 9.34528527171957607540E2, 1.02755188689515710272E3, - 5.57535335369399327526E2}; -std::array kErfcQCoefficient = { - 1.00000000000000000000E0, 1.32281951154744992508E1, - 8.67072140885989742329E1, 3.54937778887819891062E2, - 9.75708501743205489753E2, 1.82390916687909736289E3, - 2.24633760818710981792E3, 1.65666309194161350182E3, - 5.57535340817727675546E2}; -std::array kErfcRCoefficient = { - 5.64189583547755073984E-1, 1.27536670759978104416E0, - 5.01905042251180477414E0, 6.16021097993053585195E0, - 7.40974269950448939160E0, 2.97886665372100240670E0}; -std::array kErfcSCoefficient = { - 1.00000000000000000000E0, 2.26052863220117276590E0, - 9.39603524938001434673E0, 1.20489539808096656605E1, - 1.70814450747565897222E1, 9.60896809063285878198E0, - 3.36907645100081516050E0}; -std::array kErfTCoefficient = { - 9.60497373987051638749E0, 9.00260197203842689217E1, - 2.23200534594684319226E3, 7.00332514112805075473E3, - 5.55923013010394962768E4}; -std::array kErfUCoefficient = { - 1.00000000000000000000E0, 3.35617141647503099647E1, - 5.21357949780152679795E2, 4.59432382970980127987E3, - 2.26290000613890934246E4, 4.92673942608635921086E4}; -} // namespace - // Evaluate the polynomial given coefficients and `x`. // N.B. Coefficients should be supplied in decreasing order. XlaOp EvaluatePolynomial(XlaOp x, absl::Span coefficients) { @@ -77,27 +97,86 @@ XlaOp EvaluatePolynomial(XlaOp x, absl::Span coefficients) { return poly; } -// Compute an approximation of the error function complement (1 - erf(x)). -XlaOp Erfc(XlaOp x) { +// Computes an approximation of the error function complement (1 - erf(x)). +// +// Precondition: abs(x) >= 1. Otherwise, use ErfImpl. +// +// This follows Cephes's f32 implementation of erfc, and so it may have errors +// for double precision. 
+// +// See also these alternate implementations of erf and erfc: +// +// https://stackoverflow.com/questions/35148198 +// https://stackoverflow.com/questions/35966695 +// +static XlaOp ErfcImpl(XlaOp x) { + // Coefficients for erfc(f32), from Cephes. + // + // erfc(x) = exp(-x^2) P(1/x), 1 < x < 2 + static std::array kErfcPCoefficient{ + +2.326819970068386E-2, -1.387039388740657E-1, +3.687424674597105E-1, + -5.824733027278666E-1, +6.210004621745983E-1, -4.944515323274145E-1, + +3.404879937665872E-1, -2.741127028184656E-1, +5.638259427386472E-1, + }; + // erfc(x) = exp(-x^2) 1/x P(1/x^2), 2 < x < 14 + static std::array kErfcRCoefficient{ + -1.047766399936249E+1, +1.297719955372516E+1, -7.495518717768503E+0, + +2.921019019210786E+0, -1.015265279202700E+0, +4.218463358204948E-1, + -2.820767439740514E-1, +5.641895067754075E-1, + }; + XlaOp abs_x = Abs(x); XlaOp z = Exp(-x * x); + XlaOp q = ScalarLike(x, 1) / abs_x; + XlaOp y = q * q; + XlaOp p = Select(Lt(abs_x, ScalarLike(x, 2.0)), + EvaluatePolynomial(y, kErfcPCoefficient), + EvaluatePolynomial(y, kErfcRCoefficient)); + y = z * q * p; + return Select(Lt(x, ScalarLike(x, 0)), ScalarLike(x, 2.0) - y, y); +} - XlaOp pp = EvaluatePolynomial(abs_x, kErfcPCoefficient); - XlaOp pq = EvaluatePolynomial(abs_x, kErfcQCoefficient); - XlaOp pr = EvaluatePolynomial(abs_x, kErfcRCoefficient); - XlaOp ps = EvaluatePolynomial(abs_x, kErfcSCoefficient); - - XlaOp y = Select(Lt(abs_x, ScalarLike(x, 8.0)), z * pp / pq, z * pr / ps); +// Compute a polynomial approximation of the error function. +// +// Precondition: abs(x) <= 1. Otherwise, use ErfcImpl. +// +// This follows Cephes's f32 implementation of erf, so it may have errors for +// double precision. +static XlaOp ErfImpl(XlaOp x) { + // Coefficients for by erf(f32), from Cephes. 
+ // + // erf(x) = x P(x^2), 0 < x < 1 + static std::array kErfTCoefficient{ + +7.853861353153693E-5, -8.010193625184903E-4, +5.188327685732524E-3, + -2.685381193529856E-2, +1.128358514861418E-1, -3.761262582423300E-1, + +1.128379165726710E+0, + }; + + return x * EvaluatePolynomial(x * x, kErfTCoefficient); +} - return Select(Lt(x, ScalarLike(x, 0.0)), ScalarLike(x, 2.0) - y, y); +XlaOp Erfc(XlaOp x) { + auto& b = *x.builder(); + return b.ReportErrorOrReturn([&]() -> StatusOr { + TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Erfc", x)); + // erfc(x) = + // erfc_impl(x) if x > 1 + // 1 - erf_impl(x) otherwise + return Select(Gt(Abs(x), ScalarLike(x, 1)), ErfcImpl(x), + ScalarLike(x, 1) - ErfImpl(x)); + }); } -// Compute a polynomial approximation of the error function. XlaOp Erf(XlaOp x) { - XlaOp z = x * x; - XlaOp pt = EvaluatePolynomial(z, kErfTCoefficient); - XlaOp pu = EvaluatePolynomial(z, kErfUCoefficient); - return x * pt / pu; + auto& b = *x.builder(); + return b.ReportErrorOrReturn([&]() -> StatusOr { + TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Erf", x)); + // erf(x) = + // erf_impl(x) if x < 1 + // 1 - erfc_impl(x) otherwise + return Select(Lt(Abs(x), ScalarLike(x, 1)), ErfImpl(x), + ScalarLike(x, 1) - ErfcImpl(x)); + }); } // Approximation for the inverse error function from @@ -113,37 +192,30 @@ XlaOp Erf(XlaOp x) { // } // return p*x XlaOp ErfInv(XlaOp x) { - XlaBuilder* b = x.builder(); - return b->ReportErrorOrReturn([&]() -> StatusOr { - TF_ASSIGN_OR_RETURN(Shape shape, b->GetShape(x)); - constexpr int kDegree = 9; - constexpr std::array w_less_than_5_constants = { - 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f, - -4.39150654e-06f, 0.00021858087f, -0.00125372503f, - -0.00417768164f, 0.246640727f, 1.50140941f}; - constexpr std::array w_greater_than_5_constants = { - -0.000200214257f, 0.000100950558f, 0.00134934322f, - -0.00367342844f, 0.00573950773f, -0.0076224613f, - 0.00943887047f, 1.00167406f, 2.83297682f}; - - auto one = ScalarLike(x, 1.0); - 
auto w = -Log((one - x) * (one + x)); - - auto lt = Lt(w, ScalarLike(x, 5.0)); - auto coefficient = [&](int i) { - return Select(lt, - Broadcast(ScalarLike(x, w_less_than_5_constants[i]), - AsInt64Slice(shape.dimensions())), - Broadcast(ScalarLike(x, w_greater_than_5_constants[i]), - AsInt64Slice(shape.dimensions()))); - }; - w = Select(lt, w - ScalarLike(x, 2.5), Sqrt(w) - ScalarLike(x, 3.0)); - auto p = coefficient(0); - for (int i = 1; i < kDegree; ++i) { - p = coefficient(i) + p * w; - } - return p * x; - }); + constexpr int kDegree = 9; + constexpr std::array w_less_than_5_constants = { + 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f, + -4.39150654e-06f, 0.00021858087f, -0.00125372503f, + -0.00417768164f, 0.246640727f, 1.50140941f}; + constexpr std::array w_greater_than_5_constants = { + -0.000200214257f, 0.000100950558f, 0.00134934322f, + -0.00367342844f, 0.00573950773f, -0.0076224613f, + 0.00943887047f, 1.00167406f, 2.83297682f}; + + auto one = ScalarLike(x, 1.0); + auto w = -Log((one - x) * (one + x)); + + auto lt = Lt(w, ScalarLike(x, 5.0)); + auto coefficient = [&](int i) { + return Select(lt, FullLike(x, w_less_than_5_constants[i]), + FullLike(x, w_greater_than_5_constants[i])); + }; + w = Select(lt, w - ScalarLike(x, 2.5), Sqrt(w) - ScalarLike(x, 3.0)); + auto p = coefficient(0); + for (int i = 1; i < kDegree; ++i) { + p = coefficient(i) + p * w; + } + return p * x; } namespace { @@ -170,49 +242,86 @@ static constexpr std::array kLanczosCoefficients = { // t(z) = z + kLanczosGamma + 1/2 // A(z) = kBaseLanczosCoeff + sigma(k = 1, n, kLanczosCoefficients[i] / (z + k)) XlaOp Lgamma(XlaOp input) { - XlaOp one_half = ScalarLike(input, 0.5); - XlaOp one = ScalarLike(input, 1); - - XlaOp pi = ScalarLike(input, M_PI); - XlaOp log_pi = ScalarLike(input, std::log(M_PI)); - XlaOp log_sqrt_two_pi = ScalarLike(input, (std::log(2) + std::log(M_PI)) / 2); - - XlaOp lanczos_gamma_plus_one_half = ScalarLike(input, kLanczosGamma + 0.5); - XlaOp 
log_lanczos_gamma_plus_one_half = - ScalarLike(input, std::log(kLanczosGamma + 0.5)); - - XlaOp base_lanczos_coeff = ScalarLike(input, kBaseLanczosCoeff); - - // If the input is less than 0.5 use Gauss's reflection formula: - // gamma(x) = pi / sin(pi * x) * gamma(1 - x) - XlaOp need_to_reflect = Lt(Real(input), one_half); - XlaOp z = Select(need_to_reflect, -input, input - one); - - XlaOp x = base_lanczos_coeff; - for (int i = 0; i < kLanczosCoefficients.size(); ++i) { - XlaOp lanczos_coefficient = ScalarLike(input, kLanczosCoefficients[i]); - XlaOp index = ScalarLike(input, i); - x = x + lanczos_coefficient / (z + index + one); - } + auto& b = *input.builder(); + return b.ReportErrorOrReturn([&]() -> StatusOr { + TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Lgamma", input)); + + XlaOp one_half = ScalarLike(input, 0.5); + XlaOp one = ScalarLike(input, 1); + + XlaOp pi = ScalarLike(input, M_PI); + XlaOp log_pi = ScalarLike(input, std::log(M_PI)); + XlaOp log_sqrt_two_pi = + ScalarLike(input, (std::log(2) + std::log(M_PI)) / 2); + + XlaOp lanczos_gamma_plus_one_half = ScalarLike(input, kLanczosGamma + 0.5); + XlaOp log_lanczos_gamma_plus_one_half = + ScalarLike(input, std::log(kLanczosGamma + 0.5)); + + XlaOp base_lanczos_coeff = ScalarLike(input, kBaseLanczosCoeff); + + // If the input is less than 0.5 use Euler's reflection formula: + // gamma(x) = pi / (sin(pi * x) * gamma(1 - x)) + XlaOp need_to_reflect = Lt(input, one_half); + XlaOp z = Select(need_to_reflect, -input, input - one); + + XlaOp x = base_lanczos_coeff; + for (int i = 0; i < kLanczosCoefficients.size(); ++i) { + XlaOp lanczos_coefficient = ScalarLike(input, kLanczosCoefficients[i]); + XlaOp index = ScalarLike(input, i); + x = x + lanczos_coefficient / (z + index + one); + } - // To improve accuracy on platforms with less-precise log implementations, - // compute log(lanczos_gamma_plus_one_half) at compile time and use log1p on - // the device. 
- // log(t) = log(kLanczosGamma + 0.5 + z) - // = log(kLanczosGamma + 0.5) + log1p(z / (kLanczosGamma + 0.5)) - XlaOp t = lanczos_gamma_plus_one_half + z; - XlaOp log_t = - log_lanczos_gamma_plus_one_half + Log1p(z / lanczos_gamma_plus_one_half); - - XlaOp log_y = log_sqrt_two_pi + (z + one_half) * log_t - t + Log(x); - - // If z = a + 0j, the analytic continuation of log reduces to taking the - // absolute value of the real part. - // Re(log(z)) = Re(log|z| + arg(z)j) - // = log|a| - XlaOp reflection = log_pi - Log(Abs(Sin(pi * input))) - log_y; - XlaOp result = Select(need_to_reflect, reflection, log_y); - return result; + // To improve accuracy on platforms with less-precise log implementations, + // compute log(lanczos_gamma_plus_one_half) at compile time and use log1p on + // the device. + // log(t) = log(kLanczosGamma + 0.5 + z) + // = log(kLanczosGamma + 0.5) + log1p(z / (kLanczosGamma + 0.5)) + XlaOp t = lanczos_gamma_plus_one_half + z; + XlaOp log_t = log_lanczos_gamma_plus_one_half + + Log1p(z / lanczos_gamma_plus_one_half); + + XlaOp log_y = log_sqrt_two_pi + (z + one_half) * log_t - t + Log(x); + + // Compute the reflected value, used when x < 0.5: + // + // lgamma(x) = log(pi) - lgamma(1-x) - log(abs(sin(pi * x))). + // + // (The abs is because lgamma is the log of the absolute value of the gamma + // function.) + // + // We have to be careful when computing the final term above. gamma(x) goes + // to +/-inf at every integer x < 0, and this is controlled by the + // sin(pi * x) term. The slope is large, so precision is particularly + // important. + // + // Because abs(sin(pi * x)) has period 1, we can equivalently use + // abs(sin(pi * frac(x))) = sin(pi * frac(x)), where frac(x) is the + // fractional part of x. This is more numerically accurate: It doesn't + // overflow to inf like pi * x can, and if x is an integer, it evaluates to + // 0 exactly, which is significant because we then take the log of this + // value, and log(0) is inf. 
+ // + // We don't have a frac(x) primitive in XLA and computing it is tricky, but + // because abs(sin(pi * x)) = abs(sin(pi * abs(x))), it's good enough for + // our purposes to use abs(frac(x)) = abs(x) - floor(abs(x)). + // + XlaOp abs_input = Abs(input); + XlaOp reflection_denom = Log(Sin(pi * (abs_input - Floor(abs_input)))); + + // Avoid computing -inf - inf, which is nan. If reflection_denom is +/-inf, + // then it "wins" and the result is +/-inf. + XlaOp reflection = + Select(IsFinite(reflection_denom), log_pi - reflection_denom - log_y, + -reflection_denom); + XlaOp result = Select(need_to_reflect, reflection, log_y); + + // lgamma(+/-inf) = +inf. + XlaOp inf_bcast = FullLike(input, std::numeric_limits::infinity()); + return Select(Or(IsFinite(input), // is finite, or + Not(Or(Lt(input, one), Ge(input, one)))), // is nan + result, inf_bcast); + }); } // Compute the Digamma function using Lanczos' approximation from "A Precision @@ -223,70 +332,84 @@ XlaOp Lgamma(XlaOp input) { // A(z) = kBaseLanczosCoeff + sigma(k = 1, n, kLanczosCoefficients[i] / (z + k)) // A'(z) = sigma(k = 1, n, kLanczosCoefficients[i] / (z + k) / (z + k)) XlaOp Digamma(XlaOp input) { - XlaOp zero = ScalarLike(input, 0); - XlaOp one_half = ScalarLike(input, 0.5); - XlaOp one = ScalarLike(input, 1); - - XlaOp pi = ScalarLike(input, M_PI); - - XlaOp lanczos_gamma = ScalarLike(input, kLanczosGamma); - XlaOp lanczos_gamma_plus_one_half = ScalarLike(input, kLanczosGamma + 0.5); - XlaOp log_lanczos_gamma_plus_one_half = - ScalarLike(input, std::log(kLanczosGamma + 0.5)); - - XlaOp base_lanczos_coeff = ScalarLike(input, kBaseLanczosCoeff); - - // If the input is less than 0.5 use Gauss's reflection formula: - // digamma(x) = digamma(1 - x) - pi * cot(pi * x) - XlaOp need_to_reflect = Lt(Real(input), one_half); - XlaOp z = Select(need_to_reflect, -input, input - one); - - XlaOp num = zero; - XlaOp denom = base_lanczos_coeff; - for (int i = 0; i < kLanczosCoefficients.size(); ++i) { - XlaOp 
lanczos_coefficient = ScalarLike(input, kLanczosCoefficients[i]); - XlaOp index = ScalarLike(input, i); - num = num - lanczos_coefficient / ((z + index + one) * (z + index + one)); - denom = denom + lanczos_coefficient / (z + index + one); - } + auto& b = *input.builder(); + return b.ReportErrorOrReturn([&]() -> StatusOr { + TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Digamma", input)); + + XlaOp zero = ScalarLike(input, 0); + XlaOp one_half = ScalarLike(input, 0.5); + XlaOp one = ScalarLike(input, 1); + + XlaOp pi = ScalarLike(input, M_PI); + + XlaOp lanczos_gamma = ScalarLike(input, kLanczosGamma); + XlaOp lanczos_gamma_plus_one_half = ScalarLike(input, kLanczosGamma + 0.5); + XlaOp log_lanczos_gamma_plus_one_half = + ScalarLike(input, std::log(kLanczosGamma + 0.5)); + + XlaOp base_lanczos_coeff = ScalarLike(input, kBaseLanczosCoeff); + + // If the input is less than 0.5 use Euler's reflection formula: + // digamma(x) = digamma(1 - x) - pi * cot(pi * x) + XlaOp need_to_reflect = Lt(input, one_half); + XlaOp z = Select(need_to_reflect, -input, input - one); + + XlaOp num = zero; + XlaOp denom = base_lanczos_coeff; + for (int i = 0; i < kLanczosCoefficients.size(); ++i) { + XlaOp lanczos_coefficient = ScalarLike(input, kLanczosCoefficients[i]); + XlaOp index = ScalarLike(input, i); + num = num - lanczos_coefficient / ((z + index + one) * (z + index + one)); + denom = denom + lanczos_coefficient / (z + index + one); + } - // To improve accuracy on platforms with less-precise log implementations, - // compute log(lanczos_gamma_plus_one_half) at compile time and use log1p on - // the device. 
- // log(t) = log(kLanczosGamma + 0.5 + z) - // = log(kLanczosGamma + 0.5) + log1p(z / (kLanczosGamma + 0.5)) - XlaOp t = lanczos_gamma_plus_one_half + z; - XlaOp log_t = - log_lanczos_gamma_plus_one_half + Log1p(z / lanczos_gamma_plus_one_half); - - XlaOp y = log_t + num / denom - lanczos_gamma / t; - XlaOp reflection = y - pi * Cos(pi * input) / Sin(pi * input); - XlaOp result = Select(need_to_reflect, reflection, y); - return result; + // To improve accuracy on platforms with less-precise log implementations, + // compute log(lanczos_gamma_plus_one_half) at compile time and use log1p on + // the device. + // log(t) = log(kLanczosGamma + 0.5 + z) + // = log(kLanczosGamma + 0.5) + log1p(z / (kLanczosGamma + 0.5)) + XlaOp t = lanczos_gamma_plus_one_half + z; + XlaOp log_t = log_lanczos_gamma_plus_one_half + + Log1p(z / lanczos_gamma_plus_one_half); + + XlaOp y = log_t + num / denom - lanczos_gamma / t; + XlaOp reflection = y - pi * Cos(pi * input) / Sin(pi * input); + return Select(need_to_reflect, reflection, y); + }); } // Implements Banker's rounding: numbers that are equidistant between two // integers are rounded towards even. XlaOp RoundToEven(XlaOp x) { - auto half = xla::ScalarLike(x, 0.5); - auto one = xla::ScalarLike(x, 1.0); - auto two = xla::ScalarLike(x, 2.0); - - auto round_val = xla::Floor(x); - auto fraction = x - round_val; - auto nearest_even_int = round_val - two * xla::Floor(half * x); - auto is_odd = xla::Eq(nearest_even_int, one); - return xla::Select(xla::Or(xla::Gt(fraction, half), - xla::And(xla::Eq(fraction, half), is_odd)), - round_val + one, round_val); + auto& b = *x.builder(); + return b.ReportErrorOrReturn([&]() -> StatusOr { + // Reject non-real non-fp inputs (What does it even mean to round a complex + // number? Do you round each component equally? In that case, you should + // just ask for that explicitly.) 
+ TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("RoundToEven", x)); + + auto half = ScalarLike(x, 0.5); + auto one = ScalarLike(x, 1.0); + auto two = ScalarLike(x, 2.0); + + auto round_val = Floor(x); + auto fraction = x - round_val; + auto nearest_even_int = round_val - two * Floor(half * x); + auto is_odd = Eq(nearest_even_int, one); + return Select(Or(Gt(fraction, half), And(Eq(fraction, half), is_odd)), + round_val + one, round_val); + }); } // Trigonometric functions. -// acos(x) = 2 * atan(sqrt(1 - x^2) / (1 + x)) +// acos(x) = 2 * atan(sqrt(1 - x^2) / (1 + x)) if x != -1 +// pi if x == -1 XlaOp Acos(XlaOp x) { - return ScalarLike(x, 2.0) * - Atan2(Sqrt(ScalarLike(x, 1.0) - x * x), ScalarLike(x, 1.0) + x); + return Select(Ne(x, FullLike(x, -1)), + ScalarLike(x, 2.0) * Atan2(Sqrt(ScalarLike(x, 1.0) - x * x), + ScalarLike(x, 1.0) + x), + FullLike(x, M_PI)); } // asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2))) @@ -320,4 +443,92 @@ XlaOp Cosh(XlaOp x) { return (Exp(x) + Exp(-x)) * ScalarLike(x, 0.5); } XlaOp Sinh(XlaOp x) { return (Exp(x) - Exp(-x)) * ScalarLike(x, 0.5); } +XlaOp MaybeConjugate(XlaOp x, bool conjugate) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); + auto perform_conj = + primitive_util::IsComplexType(shape.element_type()) && conjugate; + return perform_conj ? Conj(x) : x; + }); +} + +XlaOp NextAfter(XlaOp from, XlaOp to) { + auto builder = from.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(auto shape, builder->GetShape(from)); + int bitwidth = primitive_util::BitWidth(shape.element_type()); + auto int_type = primitive_util::UnsignedIntegralTypeForBitWidth(bitwidth); + auto from_as_int = BitcastConvertType(from, int_type); + auto to_as_int = BitcastConvertType(to, int_type); + + // The result is NaN if either "from" or "to" are NaN. 
+ auto from_is_nan = Ne(from, from); + auto to_is_nan = Ne(to, to); + auto nan_input = Or(from_is_nan, to_is_nan); + auto result_for_nan = + Broadcast(ScalarLike(from, std::numeric_limits::quiet_NaN()), + shape.dimensions()); + result_for_nan = BitcastConvertType(result_for_nan, int_type); + + // The sign bit is the MSB. + const int64 sign_mask = int64{1} << (bitwidth - 1); + // Discard the sign bit to make the result non-negative. + auto from_abs = And(from_as_int, ScalarLike(from_as_int, ~sign_mask)); + auto to_abs = And(to_as_int, ScalarLike(to_as_int, ~sign_mask)); + + // When both "from" and "to" are equal, the result is "to". + // N.B. It would not make a difference if we chose the result to be "from". + auto from_and_to_are_equal = Eq(from_as_int, to_as_int); + auto result_for_equal = to_as_int; + + // When both "from" and "to" are both 0, the result is "to". This ensures we + // get a zero signed like "to". + auto from_is_zero = Eq(from_abs, ZerosLike(from_abs)); + auto to_is_zero = Eq(to_abs, ZerosLike(to_abs)); + auto result_for_both_zero = to_as_int; + + auto from_sign = And(from_as_int, ScalarLike(from_as_int, sign_mask)); + auto to_sign = And(to_as_int, ScalarLike(to_as_int, sign_mask)); + + // If from == 0 && to != 0, we need to return the smallest subnormal number + // signed like "to". + auto result_for_from_zero_to_non_zero = + Or(to_sign, ScalarLike(from_as_int, 1)); + + // If the sign of "from" and "to" disagree: + // - we need to make the magnitude of "from" smaller so that it is closer to + // zero. + // + // Otherwise the signs agree: + // - "from" with a magnitude larger than "to" means we need to make the + // magnitude smaller. + // - "from" with a magnitude smaller than "to" means we need to make the + // magnitude larger. + // - "from" with the same magnitude and sign as "to" has already been + // handled. 
+ auto signs_disagree = Ne(from_sign, to_sign); + auto from_magnitude_larger_than_to = Gt(from_abs, to_abs); + auto result_has_smaller_magnitude = + Or(from_magnitude_larger_than_to, signs_disagree); + auto magnitude_adjustment = + Select(result_has_smaller_magnitude, + Broadcast(ScalarLike(from_as_int, -1), shape.dimensions()), + Broadcast(ScalarLike(from_as_int, 1), shape.dimensions())); + auto result = Add(from_as_int, magnitude_adjustment); + // Handle from == ±0. + result = Select(from_is_zero, + Select(to_is_zero, result_for_both_zero, + result_for_from_zero_to_non_zero), + result); + // Handle from == to. + result = Select(from_and_to_are_equal, result_for_equal, result); + // Handle isnan(from) || isnan(to). + result = Select(nan_input, result_for_nan, result); + + // Cast back to the original type. + return BitcastConvertType(result, shape.element_type()); + }); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/math.h b/tensorflow/compiler/xla/client/lib/math.h index 3f06d04b9ae98b3aa75e68cd07810b2b4c24d280..907571c9a3ec65b0be0087ad4837c842a0bdcc79 100644 --- a/tensorflow/compiler/xla/client/lib/math.h +++ b/tensorflow/compiler/xla/client/lib/math.h @@ -20,6 +20,18 @@ limitations under the License. namespace xla { +// Determines whether operand is +/-inf or nan. +// +// Raises an error if called on integral or complex values. +XlaOp IsPosInf(XlaOp operand); +XlaOp IsNegInf(XlaOp operand); +XlaOp IsInf(XlaOp operand); +XlaOp IsNan(XlaOp operand); + +// Returns the next number after 'from' in the direction of 'to' the same way +// std::nextafter(from, to) would. +XlaOp NextAfter(XlaOp from, XlaOp to); + // Computes the square root of 'operand'. XlaOp Sqrt(XlaOp operand); @@ -32,7 +44,7 @@ XlaOp Square(XlaOp operand); // Computes the reciprocal of 'operand'. XlaOp Reciprocal(XlaOp operand); -// Evaluates a polynomial given coefficients and `x`. +// Evaluates a polynomial given coefficients and 'x'. // N.B. 
Coefficients should be supplied in decreasing order. XlaOp EvaluatePolynomial(XlaOp x, absl::Span coefficients); @@ -86,6 +98,10 @@ XlaOp Cosh(XlaOp x); // Computes the hyperbolic sine of 'x'. XlaOp Sinh(XlaOp x); +// Applies a complex conjugation operation if 'a' is complex and 'conjugate' +// is true, otherwise returns its argument. +xla::XlaOp MaybeConjugate(xla::XlaOp x, bool conjugate); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATH_H_ diff --git a/tensorflow/compiler/xla/client/lib/math_exhaustive_test.cc b/tensorflow/compiler/xla/client/lib/math_exhaustive_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..f4ed0db56f7026d2c397c5beb1cc7ea3e4b06fee --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/math_exhaustive_test.cc @@ -0,0 +1,187 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/types.h" + +namespace xla { +namespace { + +using Eigen::half; + +struct Testcase { + Testcase(string name, XlaOp (*op)(XlaOp), float (*host_op)(float)) + : name(name), op(op), host_op(host_op) {} + + Testcase& set_tolerance(float abs_err, float rel_err) { + error.abs = abs_err; + error.rel = rel_err; + return *this; + } + + Testcase& set_relaxed_nans() { + error.relaxed_nans = true; + return *this; + } + + Testcase& set_fewer_infs_ok() { + error.fewer_infs_ok = true; + return *this; + } + + Testcase& set_skip_pos_inf() { + skip_pos_inf = true; + return *this; + } + + Testcase& set_skip_neg_inf() { + skip_neg_inf = true; + return *this; + } + + Testcase& set_skip_infs() { + skip_pos_inf = true; + skip_neg_inf = true; + return *this; + } + + Testcase& set_skip_neg_zero() { + skip_neg_zero = true; + return *this; + } + + string name; + XlaOp (*op)(XlaOp); + float (*host_op)(float); + + ErrorSpec error{0.01, 0.01}; + + // If true, don't test +/-infinity or negative 0. + bool skip_pos_inf = false; + bool skip_neg_inf = false; + bool skip_neg_zero = false; +}; + +void PrintTo(const Testcase& tc, std::ostream* os) { *os << tc.name; } + +class MathExhaustiveTest : public ClientLibraryTestBase, + public ::testing::WithParamInterface { + public: + MathExhaustiveTest() { + // Disable fast-math, otherwise we get the wrong results for e.g. + // sqrt(-inf). + SetFastMathDisabled(true); + } +}; + +// Checks a function's behavior on all fp16 values. +// +// TODO(jlebar): asin and lgamma tests fail on interpreter. 
+XLA_TEST_P(MathExhaustiveTest, DISABLED_ON_INTERPRETER(F16)) { + const Testcase& tc = GetParam(); + XlaBuilder b(TestName()); + + std::vector input; + for (uint32 i = 0; i < 1 << 16; ++i) { + half h; + h.x = i; + + // If we're not using infinity as an input, use 0 as a placeholder rather + // than simply skipping this element. We do this because when the test + // framework reports an incorrect answer, it tells us which index failed. + // So long as our inputs are a simple list of all possible float16s, we can + // convert an index to a half with e.g. the following Python: + // + // np.frombuffer(array('H', [12345]), dtype=np.float16)[0] + // + // but as soon as our list of inputs has any gaps, this doesn't work. + if (std::isinf(static_cast(h)) && + ((tc.skip_pos_inf && h > half{0}) || + (tc.skip_neg_inf && h < half{0}))) { + h = half{0}; + } + + if (h == half{0} && tc.skip_neg_zero && + std::signbit(static_cast(h))) { + h = half{0}; + } + + input.push_back(h); + } + + std::vector expected_result; + for (const auto& h : input) { + expected_result.push_back( + static_cast(tc.host_op(static_cast(h)))); + } + + XlaOp param = AddParam(LiteralUtil::CreateR1(input), &b); + tc.op(param); + ComputeAndCompareR1(&b, expected_result, {}, tc.error); +} + +// TODO(b/123355973): The following tests from math.cc are missing. +// +// - Many failures. +// +// Testcase{"acosh", Acosh, std::acosh}.set_relaxed_nans(), +// Testcase{"asinh", Asinh, std::asinh}, +// Testcase{"sinh", Sinh, std::sinh}, +// Testcase{"cosh", Cosh, std::cosh}.set_fewer_infs_ok(), +// Testcase{"round_to_even", RoundToEven, +// [](float x) { return std::nearbyint(x / 2) * 2; }}, +// +// - No equivalent std function to compare with. +// +// Testcase{"erfinv", ErfInv, std::erfinv}, +// Testcase{"digamma", Digamma, std::digamma}, +// +// - Needs a special test (function takes two args, and simply computing in f32 +// and downcasting to f16 doesn't give the correct answer). 
+// +// Testcase{"nextafter", NextAfter, std::nextafter}, +// +// TODO(b/123355973): Test math functions not from math.cc (e.g. log). +// TODO(b/123355973): Test bf16 and f32. +// TODO(b/123355973): Get rid of skip_infs / skip_neg_zero below if possible. +// TODO(b/123355973): Reduce lgamma error if possible; it is very high. +INSTANTIATE_TEST_CASE_P( + MathExhaustiveTest_Instantiation, MathExhaustiveTest, + ::testing::ValuesIn(std::vector{ + Testcase{"sqrt", Sqrt, std::sqrt}.set_skip_neg_inf(), + Testcase{"rsqrt", Rsqrt, [](float x) { return 1 / std::sqrt(x); }} + .set_tolerance(0.05, 0.05) + .set_skip_infs() + .set_skip_neg_zero(), + Testcase{"square", Square, [](float x) { return x * x; }}, + Testcase{"reciprocal", Reciprocal, [](float x) { return 1 / x; }}, + Testcase{"erf", Erf, std::erf}.set_tolerance(0.001, 0.0001), + Testcase{"erfc", Erfc, std::erfc}.set_tolerance(0.001, 0.0001), + Testcase{"lgamma", Lgamma, std::lgamma} + .set_tolerance(0.1, 0.15) + .set_fewer_infs_ok(), + Testcase{"asin", Asin, std::asin}.set_skip_infs(), + Testcase{"acos", Acos, std::acos}.set_skip_infs(), + Testcase{"atan", Atan, std::atan}, + Testcase{"tan", Tan, std::tan}.set_tolerance(0.05, 0.05), + })); + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/math_test.cc b/tensorflow/compiler/xla/client/lib/math_test.cc index ae2ea225d1aadd7b3a794eabeca866c498f34760..364ac5876abbec825834081518a6dfda84356048 100644 --- a/tensorflow/compiler/xla/client/lib/math_test.cc +++ b/tensorflow/compiler/xla/client/lib/math_test.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -30,6 +31,121 @@ class MathTest : public ClientLibraryTestBase { ErrorSpec error_spec_{0.0001}; }; +// Write TYPED_TESTs within the class definition so that we don't have to litter +// "this->" everywhere. +template +class MathTypedTest : public MathTest { + public: + void TestLogEdgeCases() { + SetFastMathDisabled(true); + + XlaBuilder b(TestName()); + Log(AddParam(LiteralUtil::CreateR1({T{0.0}, T{-0.0}}), &b)); + ComputeAndCompareR1(&b, + {-std::numeric_limits::infinity(), + -std::numeric_limits::infinity()}, + {}, error_spec_); + } + + void TestLog1pEdgeCases() { + SetFastMathDisabled(true); + + XlaBuilder b(TestName()); + Log1p(AddParam(LiteralUtil::CreateR1({T{0.0}, T{-0.0}, T{-1.0}}), &b)); + ComputeAndCompareR1( + &b, {T{0.0}, T{-0.0}, -std::numeric_limits::infinity()}, {}, + error_spec_); + } + + void TestIsInfOrNan() { + SetFastMathDisabled(true); + + XlaBuilder b(TestName()); + auto x = + ConstantR1(&b, { + T{0}, + T{100}, + T{-1000}, + T{std::numeric_limits::max()}, + T{std::numeric_limits::lowest()}, + T{std::numeric_limits::infinity()}, + T{-std::numeric_limits::infinity()}, + T{std::numeric_limits::quiet_NaN()}, + T{std::numeric_limits::signaling_NaN()}, + }); + Tuple(&b, {IsFinite(x), IsInf(x), IsPosInf(x), IsNegInf(x), IsNan(x)}); + + auto expected = LiteralUtil::MakeTupleOwned( + LiteralUtil::CreateR1( + {true, true, true, true, true, false, false, false, false}), + LiteralUtil::CreateR1( + {false, false, false, false, false, true, true, false, false}), + LiteralUtil::CreateR1( + {false, false, false, false, false, true, false, false, false}), + LiteralUtil::CreateR1( + 
{false, false, false, false, false, false, true, false, false}), + LiteralUtil::CreateR1( + {false, false, false, false, false, false, false, true, true})); + ComputeAndCompareLiteral(&b, expected, {}); + } +}; + +// TODO(b/123355973): Add bfloat16 to TestTypes once it's working. +#ifdef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16 +using TestTypes = ::testing::Types; +#else +using TestTypes = ::testing::Types; +#endif + +TYPED_TEST_CASE(MathTypedTest, TestTypes); + +XLA_TYPED_TEST(MathTypedTest, LogEdgeCases) { this->TestLogEdgeCases(); } +XLA_TYPED_TEST(MathTypedTest, Log1pEdgeCases) { this->TestLog1pEdgeCases(); } +XLA_TYPED_TEST(MathTypedTest, IsInfOrNan) { this->TestIsInfOrNan(); } + +// Check that certain ops only support real, floating-point inputs. +// +// TODO(jlebar): Expand this test to cover more ops. +XLA_TEST_F(MathTest, RealFpOnlyOps) { + for (int64 i = PrimitiveType_MIN; i <= PrimitiveType_MAX; ++i) { + auto ty = static_cast(i); + SCOPED_TRACE(PrimitiveType_Name(ty)); + Shape shape; + if (primitive_util::IsArrayType(ty)) { + shape = ShapeUtil::MakeShape(ty, {42}); + } else if (ty == PrimitiveType::TUPLE) { + shape = ShapeUtil::MakeTupleShape({}); + } else if (ty == PrimitiveType::OPAQUE) { + shape = ShapeUtil::MakeOpaqueShape(); + } else if (ty == PrimitiveType::TOKEN) { + shape = ShapeUtil::MakeTokenShape(); + } else { + continue; + } + + for (const auto& test : + std::vector, string>>({ + {IsFinite, "is_finite"}, + {IsInf, "is_inf"}, + {IsPosInf, "is_pos_inf"}, + {IsNegInf, "is_neg_inf"}, + {IsNan, "is_nan"}, + {Erf, "erf"}, + {Erfc, "erfc"}, + {Lgamma, "lgamma"}, + {Digamma, "digamma"}, + {RoundToEven, "round_to_even"}, + })) { + SCOPED_TRACE(test.second); + XlaBuilder b(TestName()); + XlaOp p = Parameter(&b, 0, shape, "p0"); + test.first(p); + + EXPECT_EQ(b.first_error().ok(), primitive_util::IsFloatingPointType(ty)); + } + } +} + XLA_TEST_F(MathTest, SqrtF32) { XlaBuilder builder(TestName()); Literal zero_literal = 
LiteralUtil::Zero(PrimitiveType::F32); @@ -106,6 +222,27 @@ XLA_TEST_F(MathTest, Lgamma) { ComputeAndCompareR1(&builder, expected, {}, error_spec_); } +XLA_TEST_F(MathTest, LgammaF16) { + SetFastMathDisabled(true); + + XlaBuilder b(TestName()); + + // These seemingly arbitrary inputs came from debugging the lgamma + // implementation against a test which tried all possible f16 values. + auto x = ConstantR1(&b, { + half(-7360.0), + half(-4066.0), + half(-5.9605e-08), + }); + Lgamma(x); + std::vector expected = { + std::numeric_limits::infinity(), + std::numeric_limits::infinity(), + half(16.64), + }; + ComputeAndCompareR1(&b, expected, {}, ErrorSpec{0.1}); +} + XLA_TEST_F(MathTest, Digamma) { XlaBuilder builder(TestName()); auto x = ConstantR1(&builder, {1.0, 0.5, 1 / 3.0, 0.25, 1 / 6.0, 0.125, @@ -148,5 +285,40 @@ XLA_TEST_F(MathTest, RoundToEven) { ComputeAndCompareR1(&builder, expected, {}, error_spec_); } +XLA_TEST_F(MathTest, ErfRejectsComplexInputs) { + XlaBuilder b(TestName()); + auto x = ConstantR1>(&b, {{0, 0}}); + Erf(x); + EXPECT_FALSE(b.Build().status().ok()); +} + +XLA_TEST_F(MathTest, ErfcRejectsComplexInputs) { + XlaBuilder b(TestName()); + auto x = ConstantR1>(&b, {{0, 0}}); + Erfc(x); + EXPECT_FALSE(b.Build().status().ok()); +} + +XLA_TEST_F(MathTest, LgammaRejectsComplexInputs) { + XlaBuilder b(TestName()); + auto x = ConstantR1>(&b, {{0, 0}}); + Lgamma(x); + EXPECT_FALSE(b.Build().status().ok()); +} + +XLA_TEST_F(MathTest, DigammaRejectsComplexInputs) { + XlaBuilder b(TestName()); + auto x = ConstantR1>(&b, {{0, 0}}); + Digamma(x); + EXPECT_FALSE(b.Build().status().ok()); +} + +XLA_TEST_F(MathTest, RoundToEvenRejectsComplexInputs) { + XlaBuilder b(TestName()); + auto x = ConstantR1>(&b, {{0, 0}}); + RoundToEven(x); + EXPECT_FALSE(b.Build().status().ok()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/matrix.cc b/tensorflow/compiler/xla/client/lib/matrix.cc new file mode 100644 index 
0000000000000000000000000000000000000000..a5aea96090c59c78d20cfc10a4bd6b312be592c1 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/matrix.cc @@ -0,0 +1,339 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/matrix.h" + +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/ascii.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { + +XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64 m, + int64 n) { + auto a = Iota(builder, U32, m); + auto b = Iota(builder, U32, n); + auto indicator = Eq(a, Broadcast(b, {m}), /*broadcast_dimensions=*/{0}); + return ConvertElementType(indicator, type); +} + +XlaOp GetMatrixDiagonal(XlaOp x) { + XlaBuilder* builder = x.builder(); + 
return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); + const int64 n_dims = shape.rank(); + TF_RET_CHECK(n_dims >= 2); + const int64 m = shape.dimensions(n_dims - 2); + const int64 n = shape.dimensions(n_dims - 1); + absl::Span major_dims = + AsInt64Slice(shape.dimensions()).subspan(/*pos=*/0, /*len=*/n_dims - 2); + auto a = Iota(builder, U32, n); + auto b = Iota(builder, U32, m); + auto indicator = Eq(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0}); + auto mask = Broadcast(indicator, major_dims); + + // TPUs don't support S64 add reduction at the moment. But fortunately + // OR-reductions work just as well for integers. + XlaComputation reducer = + primitive_util::IsIntegralType(shape.element_type()) + ? CreateScalarOrComputation(shape.element_type(), builder) + : CreateScalarAddComputation(shape.element_type(), builder); + + return Reduce(Select(mask, x, Zeros(builder, shape)), ScalarLike(x, 0), + reducer, {m >= n ? n_dims - 2 : n_dims - 1}); + }); +} + +XlaOp TriangleMask(XlaOp x, int diagonal) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); + const int64 n_dims = shape.rank(); + TF_RET_CHECK(n_dims >= 2); + const int64 m = shape.dimensions(n_dims - 2); + const int64 n = shape.dimensions(n_dims - 1); + absl::Span major_dims = + AsInt64Slice(shape.dimensions()).subspan(/*pos=*/0, /*len=*/n_dims - 2); + auto a = Iota(builder, S32, n); + auto b = Iota(builder, S32, m) + ConstantR0(builder, diagonal); + XlaOp indicator; + indicator = Ge(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0}); + return Broadcast(indicator, major_dims); + }); +} + +XlaOp Triangle(XlaOp x, bool lower) { + return lower ? 
Select(TriangleMask(x, 0), x, ZerosLike(x)) + : Select(TriangleMask(x, -1), ZerosLike(x), x); +} + +XlaOp UpperTriangle(XlaOp x) { return Triangle(x, false); } + +XlaOp LowerTriangle(XlaOp x) { return Triangle(x, true); } + +Status ValidateEinsumNumericDimensions(absl::Span x_config, + absl::Span y_config, + absl::Span output_config) { + for (auto dim : output_config) { + if (absl::c_linear_search(x_config, dim) || + absl::c_linear_search(y_config, dim)) { + if (absl::c_count(output_config, dim) > 1) { + return InvalidArgument("Einsum has repeated output dimension."); + } + continue; + } + return InvalidArgument( + "Einsum has output dimension without corresponding input dimension."); + } + for (auto dim : x_config) { + if (absl::c_linear_search(y_config, dim) || + absl::c_linear_search(output_config, dim)) { + if (absl::c_count(x_config, dim) > 1) { + return InvalidArgument("Einsum has repeated lhs dimension."); + } + continue; + } + return InvalidArgument( + "Einsum has lhs dimension without corresponding rhs or output " + "dimension."); + } + for (auto dim : y_config) { + if (absl::c_linear_search(x_config, dim) || + absl::c_linear_search(output_config, dim)) { + if (absl::c_count(y_config, dim) > 1) { + return InvalidArgument("Einsum has repeated rhs dimension."); + } + continue; + } + return InvalidArgument( + "Einsum has rhs dimension without corresponding lhs or output " + "dimension."); + } + return Status::OK(); +} + +xla::XlaOp Einsum(xla::XlaOp x, absl::Span x_config, xla::XlaOp y, + absl::Span y_config, + absl::Span output_config, + xla::PrecisionConfig::Precision precision) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_RETURN_IF_ERROR( + ValidateEinsumNumericDimensions(x_config, y_config, output_config)); + const int64 x_rank = x_config.size(); + const int64 y_rank = y_config.size(); + const int64 output_rank = output_config.size(); + absl::flat_hash_set x_map; + absl::flat_hash_set y_map; + 
absl::flat_hash_set output_map; + + auto find = [&](const absl::flat_hash_set& map, int64 d) { + return map.count(d) != 0; + }; + + auto insert = [&](absl::flat_hash_set& map, char d) { + CHECK(!find(map, d)); + map.insert(d); + }; + + for (auto d : x_config) { + insert(x_map, d); + } + + for (auto d : y_config) { + insert(y_map, d); + } + + for (auto d : output_config) { + insert(output_map, d); + } + + DotDimensionNumbers dnums; + std::vector lhs_outer_dims; + auto is_batch_dim = [&](int64 d) { + return find(x_map, d) && find(y_map, d) && find(output_map, d); + }; + auto is_contracting = [&](int64 d) { + return find(x_map, d) && find(y_map, d); + }; + auto rhs_dimension_number = [&](int64 d) { + return absl::c_find(y_config, d) - y_config.begin(); + }; + for (int64 i = 0; i < x_rank; ++i) { + auto dim_name = x_config[i]; + if (is_batch_dim(dim_name)) { + dnums.add_lhs_batch_dimensions(i); + dnums.add_rhs_batch_dimensions(rhs_dimension_number(dim_name)); + } else if (is_contracting(dim_name)) { + dnums.add_lhs_contracting_dimensions(i); + dnums.add_rhs_contracting_dimensions(rhs_dimension_number(dim_name)); + } else { + lhs_outer_dims.push_back(i); + } + } + + std::vector rhs_outer_dims; + for (int64 i = 0; i < y_rank; ++i) { + auto dim_name = y_config[i]; + if (!is_batch_dim(dim_name) && !is_contracting(dim_name)) { + rhs_outer_dims.push_back(i); + } + } + + auto output_dimension_number = [&](char d) { + return absl::c_find(output_config, d) - output_config.begin(); + }; + + std::vector output_dims; + output_dims.reserve(output_rank); + for (auto d : dnums.lhs_batch_dimensions()) { + output_dims.push_back(output_dimension_number(x_config[d])); + } + for (auto d : lhs_outer_dims) { + output_dims.push_back(output_dimension_number(x_config[d])); + } + for (auto d : rhs_outer_dims) { + output_dims.push_back(output_dimension_number(y_config[d])); + } + + std::vector transpose_dims(output_rank); + for (int64 i = 0; i < output_rank; ++i) { + 
transpose_dims[output_dims[i]] = i; + } + + PrecisionConfig precision_proto; + precision_proto.add_operand_precision(precision); + precision_proto.add_operand_precision(precision); + return Transpose(DotGeneral(x, y, dnums, &precision_proto), transpose_dims); + }); +} + +XlaOp BatchDot(XlaOp x, XlaOp y, PrecisionConfig::Precision precision) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x)); + TF_ASSIGN_OR_RETURN(Shape y_shape, builder->GetShape(y)); + + // The batch dimensions must be equal and the matrix dimensions must be + // valid. + std::vector batch_dimension_numbers; + const int ndims = x_shape.rank(); + batch_dimension_numbers.reserve(ndims - 2); + for (int i = 0; i < ndims - 2; ++i) { + batch_dimension_numbers.push_back(i); + } + std::vector x_config = batch_dimension_numbers; + x_config.push_back(ndims - 2); + x_config.push_back(ndims); + std::vector y_config = batch_dimension_numbers; + y_config.push_back(ndims); + y_config.push_back(ndims - 1); + std::vector output_config = batch_dimension_numbers; + output_config.push_back(ndims - 2); + output_config.push_back(ndims - 1); + return Einsum(x, x_config, y, y_config, output_config, precision); + }); +} + +StatusOr, 3>> ParseEinsumString( + absl::string_view einsum_config) { + std::array, 3> einsum_config_numeric; + std::vector main_split = + absl::StrSplit(einsum_config, ','); + + if (main_split.size() != 2) { + return InvalidArgument("Expected one \",\" in einsum_config."); + } + + auto maybe_invalid_character = [](char d) { + if (absl::ascii_isalpha(d)) { + return Status::OK(); + } + if (d == '.') { + return InvalidArgument("Unsupported \"...\" or \".\" in einsum config."); + } + return InvalidArgument("Unexpected character in einsum config."); + }; + + auto& x_config = einsum_config_numeric[0]; + x_config.reserve(main_split[0].size()); + for (auto d : main_split[0]) { + 
TF_RETURN_IF_ERROR(maybe_invalid_character(d)); + x_config.push_back(static_cast(d)); + } + std::vector y_output_split = + absl::StrSplit(main_split[1], "->"); + if (y_output_split.size() != 2) { + return InvalidArgument("Expected one \"->\" in einsum_config."); + } + auto& y_config = einsum_config_numeric[1]; + y_config.reserve(y_output_split[0].size()); + for (auto d : y_output_split[0]) { + TF_RETURN_IF_ERROR(maybe_invalid_character(d)); + y_config.push_back(static_cast(d)); + } + auto& output_config = einsum_config_numeric[2]; + output_config.reserve(y_output_split[1].size()); + for (auto d : y_output_split[1]) { + TF_RETURN_IF_ERROR(maybe_invalid_character(d)); + output_config.push_back(static_cast(d)); + } + return einsum_config_numeric; +} + +XlaOp Einsum(XlaOp x, XlaOp y, absl::string_view einsum_config, + PrecisionConfig::Precision precision) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(auto einsum_config_numeric, + ParseEinsumString(einsum_config)); + return Einsum(x, einsum_config_numeric[0], y, einsum_config_numeric[1], + einsum_config_numeric[2], precision); + }); +} + +XlaOp TransposeInMinorDims(XlaOp x) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); + const int64 n_dims = shape.rank(); + TF_RET_CHECK(n_dims >= 2); + std::vector permutation(n_dims); + std::iota(permutation.begin(), permutation.end(), 0); + std::swap(permutation[n_dims - 1], permutation[n_dims - 2]); + return Transpose(x, permutation); + }); +} + +XlaOp MaybeTransposeInMinorDims(XlaOp x, bool transpose) { + return transpose ? 
TransposeInMinorDims(x) : x; +} +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/matrix.h b/tensorflow/compiler/xla/client/lib/matrix.h new file mode 100644 index 0000000000000000000000000000000000000000..491f1eab4cbffbbf9df70d4c35a61351df3e98aa --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/matrix.h @@ -0,0 +1,115 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATRIX_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATRIX_H_ + +#include +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { + +// Returns an m x n matrix with 1s on the diagonal elements, zeros everywhere +// else. +XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64 m, int64 n); + +// Get the diagonals of the last two dimensions. If 'x' has shape +// [..., M, N], then the output has shape [..., min(M, N)], containing the +// diagonal elements (i.e., with indices [..., i, i]). +XlaOp GetMatrixDiagonal(XlaOp x); + +// Returns a lower-triangular mask, i.e., true below the `diagonal`-th diagonal +// and false above that diagonal. 
+XlaOp TriangleMask(XlaOp x, int diagonal); + +// Get the upper or lower triangle part of the last two dimensions +XlaOp Triangle(XlaOp x, bool lower); + +// Get the upper triangle part of the last two dimensions +XlaOp UpperTriangle(XlaOp x); + +// Get the lower triangle part of the last two dimensions +XlaOp LowerTriangle(XlaOp x); + +// Multiplies slices of two tensors in batches. + +// Multiplies all slices of `Tensor` `x` and `y` (each slice can be +// viewed as an element of a batch), and arranges the individual results +// in a single output tensor of the same batch size. +// +// The input tensors `x` and `y` are 2-D or higher with shape `[..., r_x, c_x]` +// and `[..., r_y, c_y]`. +// +// The output tensor is 2-D or higher with shape `[..., r_o, c_o]`, where: +// +// r_o = r_x +// c_o = c_y +// +// It is computed as: +// +// output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :]) +xla::XlaOp BatchDot( + xla::XlaOp x, xla::XlaOp y, + xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT); + +// Parse an einsum string into dimension numbers: +// "ab,cb->ac" +// becomes: +// {{0, 1},{2, 1},{0, 2}} +// +// NOTE: This function is meant for testing, there is no need to call it +// directly. + +StatusOr, 3>> ParseEinsumString( + absl::string_view einsum_config); + +// Determine if each dimension label is in at least two inputs. +// +// NOTE: This function is meant for testing, there is no need to call it +// directly. +Status ValidateEinsumNumericDimensions(absl::Span x_config, + absl::Span y_config, + absl::Span output_config); + +// Supports two operand einsum notation like "ab,cb->ac". +xla::XlaOp Einsum( + xla::XlaOp x, xla::XlaOp y, absl::string_view einsum_config, + xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT); + +// Same as above but supporting numeric labels on dimensions.
So "ab,cb->ac" +// becomes: +// x_config = {0, 1} +// y_config = {2, 1} +// output_config = {0, 2} +xla::XlaOp Einsum( + xla::XlaOp x, absl::Span x_config, xla::XlaOp y, + absl::Span y_config, absl::Span output_config, + xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT); + +// Transposes a stack of matrices `x` by swapping the last two dimensions. +xla::XlaOp TransposeInMinorDims(xla::XlaOp x); + +// Transposes `x` in its minor dimensions if `transpose` is true, otherwise +// returns `x` unchanged. +xla::XlaOp MaybeTransposeInMinorDims(xla::XlaOp x, bool transpose); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATRIX_H_ diff --git a/tensorflow/compiler/xla/client/lib/matrix_test.cc b/tensorflow/compiler/xla/client/lib/matrix_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..79cf529ee94b044ee0af788522200cd28c778997 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/matrix_test.cc @@ -0,0 +1,181 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/matrix.h" + +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/xla/client/lib/slicing.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace { + +class MatrixTest : public ClientLibraryTestBase { + protected: + template + void TestMatrixDiagonal(); +}; + +XLA_TEST_F(MatrixTest, Triangle) { + XlaBuilder builder(TestName()); + Array3D input(2, 3, 4); + input.FillIota(0); + + XlaOp a; + auto a_data = CreateR3Parameter(input, 0, "a", &builder, &a); + LowerTriangle(a); + Array3D expected({{{0, 0, 0, 0}, {4, 5, 0, 0}, {8, 9, 10, 0}}, + {{12, 0, 0, 0}, {16, 17, 0, 0}, {20, 21, 22, 0}}}); + + ComputeAndCompareR3(&builder, expected, {a_data.get()}); +} + +template +void MatrixTest::TestMatrixDiagonal() { + XlaBuilder builder("GetMatrixDiagonal"); + Array3D input(2, 3, 4); + input.FillIota(0); + + XlaOp a; + auto a_data = CreateR3Parameter(input, 0, "a", &builder, &a); + GetMatrixDiagonal(a); + Array2D expected({{0, 5, 10}, {12, 17, 22}}); + + ComputeAndCompareR2(&builder, expected, {a_data.get()}); +} + +XLA_TEST_F(MatrixTest, GetMatrixDiagonal_S32) { TestMatrixDiagonal(); } + +XLA_TEST_F(MatrixTest, GetMatrixDiagonal_S64) { TestMatrixDiagonal(); } + +XLA_TEST_F(MatrixTest, GetMatrixDiagonal_F32) { TestMatrixDiagonal(); } + +Array3D BatchedAValsFull() { + return {{ + {2, 0, 1, 2}, + {3, 6, 0, 1}, + {4, 7, 9, 0}, + {5, 8, 10, 11}, + }, + { + {16, 24, 8, 12}, + {24, 61, 82, 48}, + {8, 82, 456, 106}, + {12, 48, 106, 62}, + }}; +} + 
+XLA_TEST_F(MatrixTest, RowBatchDot) { + XlaBuilder builder(TestName()); + + int n = 4; + + XlaOp a, row, index; + auto a_data = + CreateR3Parameter(BatchedAValsFull(), 0, "a", &builder, &a); + auto row_data = CreateR3Parameter({{{9, 1, 0, 0}}, {{2, 4, 0, 0}}}, 1, + "row", &builder, &row); + // Select {{3, 6, 0, 1}, {24, 61, 82, 48}} out of BatchedAValsFull(). + auto index_data = CreateR0Parameter(1, 2, "index", &builder, &index); + + auto l_index = DynamicSliceInMinorDims( + a, {index, ConstantR0(&builder, 0)}, {1, n}); + BatchDot(l_index, TransposeInMinorDims(row)); + + ComputeAndCompareR3(&builder, {{{33}}, {{292}}}, + {a_data.get(), row_data.get(), index_data.get()}); +} + +XLA_TEST_F(MatrixTest, Einsum) { + XlaBuilder builder(TestName()); + + int n = 4; + + XlaOp a, row, index; + auto a_data = + CreateR3Parameter(BatchedAValsFull(), 0, "a", &builder, &a); + auto row_data = CreateR3Parameter({{{9, 1, 0, 0}}, {{2, 4, 0, 0}}}, 1, + "row", &builder, &row); + // Select {{3, 6, 0, 1}, {24, 61, 82, 48}} out of BatchedAValsFull(). 
+ auto index_data = CreateR0Parameter(1, 2, "index", &builder, &index); + + auto l_index = DynamicSliceInMinorDims( + a, {index, ConstantR0(&builder, 0)}, {1, n}); + Einsum(l_index, row, "abc,adc->abd"); + + ComputeAndCompareR3(&builder, {{{33}}, {{292}}}, + {a_data.get(), row_data.get(), index_data.get()}); +} + +XLA_TEST_F(MatrixTest, ParseEinsumString) { + auto to_vec = [](absl::string_view s) { + std::vector v; + v.reserve(s.size()); + for (auto c : s) { + v.push_back(int64{c}); + } + return v; + }; + + auto to_string = [&](absl::string_view x, absl::string_view y, + absl::string_view o) { + return absl::StrCat(x, ",", y, "->", o); + }; + + std::vector> good_test_cases = {{"ab", "bc", "ac"}, + {"Bab", "Bbc", "Bac"}, + {"ab", "cd", "dcba"}, + {"abc", "abd", "cbd"}}; + for (auto test_case : good_test_cases) { + auto parse_result_or_status = + ParseEinsumString(to_string(test_case[0], test_case[1], test_case[2])); + EXPECT_TRUE(parse_result_or_status.status().ok()); + auto parse_result = parse_result_or_status.ValueOrDie(); + for (int i = 0; i < 3; ++i) { + EXPECT_EQ(parse_result[i], to_vec(test_case[i])); + } + EXPECT_TRUE(ValidateEinsumNumericDimensions( + parse_result[0], parse_result[1], parse_result[2]) + .ok()); + } + + std::vector einsum_strings_that_fail_parsing = { + "", "a", "ab->ba", "ab,bc,cd->ad", "a...b,bc->a...c"}; + for (auto test_case : einsum_strings_that_fail_parsing) { + auto parse_result_or_status = ParseEinsumString(test_case); + EXPECT_FALSE(parse_result_or_status.status().ok()); + } + + std::vector einsum_strings_that_fail_numeric_validation = { + "a,b->c", "ab,bc->acd", "abz,bc->ac", "ab,bcz->ac"}; + for (auto test_case : einsum_strings_that_fail_numeric_validation) { + auto parse_result_or_status = ParseEinsumString(test_case); + EXPECT_TRUE(parse_result_or_status.status().ok()); + auto parse_result = parse_result_or_status.ValueOrDie(); + EXPECT_FALSE(ValidateEinsumNumericDimensions( + parse_result[0], parse_result[1], parse_result[2]) + 
.ok()); + } +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/numeric.cc b/tensorflow/compiler/xla/client/lib/numeric.cc deleted file mode 100644 index 377654220b5df4487e9e194361473d54ff46a54e..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/client/lib/numeric.cc +++ /dev/null @@ -1,89 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include - -#include "absl/types/span.h" -#include "tensorflow/compiler/xla/client/lib/arithmetic.h" -#include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" - -namespace xla { - -XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64 m, - int64 n) { - auto a = Iota(builder, type, m); - auto b = Iota(builder, type, n); - auto indicator = Eq(a, Broadcast(b, {m}), /*broadcast_dimensions=*/{0}); - return ConvertElementType(indicator, type); -} - -XlaOp GetMatrixDiagonal(XlaOp x) { - XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> StatusOr { - TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); - const int64 n_dims = ShapeUtil::Rank(shape); - TF_RET_CHECK(n_dims >= 2); - const int64 m = shape.dimensions(n_dims - 2); - const int64 n = shape.dimensions(n_dims - 1); - absl::Span major_dims = - 
AsInt64Slice(shape.dimensions()).subspan(/*pos=*/0, /*len=*/n_dims - 2); - auto a = Iota(builder, U32, n); - auto b = Iota(builder, U32, m); - auto indicator = Eq(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0}); - auto mask = Broadcast(indicator, major_dims); - - // TPUs don't support S64 add reduction at the moment. But fortunately - // OR-reductions work just as well for integers. - XlaComputation reducer = - primitive_util::IsIntegralType(shape.element_type()) - ? CreateScalarOrComputation(shape.element_type(), builder) - : CreateScalarAddComputation(shape.element_type(), builder); - - return Reduce(Select(mask, x, Zeros(builder, shape)), ScalarLike(x, 0), - reducer, {m >= n ? n_dims - 2 : n_dims - 1}); - }); -} - -XlaOp Triangle(XlaOp x, bool lower) { - XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> StatusOr { - TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); - const int64 n_dims = ShapeUtil::Rank(shape); - TF_RET_CHECK(n_dims >= 2); - const int64 m = shape.dimensions(n_dims - 2); - const int64 n = shape.dimensions(n_dims - 1); - absl::Span major_dims = - AsInt64Slice(shape.dimensions()).subspan(/*pos=*/0, /*len=*/n_dims - 2); - auto a = Iota(builder, U32, n); - auto b = Iota(builder, U32, m); - xla::XlaOp indicator; - if (lower) { - indicator = Ge(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0}); - } else { - indicator = Le(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0}); - } - auto mask = Broadcast(indicator, major_dims); - - return Select(mask, x, Zeros(builder, shape)); - }); -} - -XlaOp UpperTriangle(XlaOp x) { return Triangle(x, false); } - -XlaOp LowerTriangle(XlaOp x) { return Triangle(x, true); } - -} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/numeric.h b/tensorflow/compiler/xla/client/lib/numeric.h deleted file mode 100644 index f62fdab4b0e5e84347cfaa1424a8c2e5c58dd3ce..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/client/lib/numeric.h +++ /dev/null @@ 
-1,45 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_NUMERIC_H_ -#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_NUMERIC_H_ - -#include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" - -namespace xla { - -// Returns an m x n matrix with 1s on the diagonal elements, zeros everywhere -// else. -XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64 m, int64 n); - -// Get the diagonals of the last two dimensions. If 'x' has shape -// [..., M, N], then the output has shape [..., min(M, N)], containing the -// diagonal elements (i.e., with indices [..., i, i]). 
-XlaOp GetMatrixDiagonal(XlaOp x); - -// Get the upper or lower triangle part of the last two dimensions -XlaOp Triangle(XlaOp x, bool lower); - -// Get the upper triangle part of the last two dimensions -XlaOp UpperTriangle(XlaOp x); - -// Get the lower triangle part of the last two dimensions -XlaOp LowerTriangle(XlaOp x); - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_NUMERIC_H_ diff --git a/tensorflow/compiler/xla/client/lib/numeric_test.cc b/tensorflow/compiler/xla/client/lib/numeric_test.cc deleted file mode 100644 index 7d6aedd49462bd4f075f90d0b0f85c40f1191aa1..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/client/lib/numeric_test.cc +++ /dev/null @@ -1,68 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/compiler/xla/client/lib/numeric.h" -#include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/client_library_test_base.h" -#include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" - -namespace xla { -namespace { - -class NumericTest : public ClientLibraryTestBase { - protected: - template - void TestMatrixDiagonal(); -}; - -XLA_TEST_F(NumericTest, Triangle) { - XlaBuilder builder(TestName()); - Array3D input(2, 3, 4); - input.FillIota(0); - - XlaOp a; - auto a_data = CreateR3Parameter(input, 0, "a", &builder, &a); - LowerTriangle(a); - Array3D expected({{{0, 0, 0, 0}, {4, 5, 0, 0}, {8, 9, 10, 0}}, - {{12, 0, 0, 0}, {16, 17, 0, 0}, {20, 21, 22, 0}}}); - - ComputeAndCompareR3(&builder, expected, {a_data.get()}); -} - -template -void NumericTest::TestMatrixDiagonal() { - XlaBuilder builder("GetMatrixDiagonal"); - Array3D input(2, 3, 4); - input.FillIota(0); - - XlaOp a; - auto a_data = CreateR3Parameter(input, 0, "a", &builder, &a); - GetMatrixDiagonal(a); - Array2D expected({{0, 5, 10}, {12, 17, 22}}); - - ComputeAndCompareR2(&builder, expected, {a_data.get()}); -} - -XLA_TEST_F(NumericTest, GetMatrixDiagonal_S32) { TestMatrixDiagonal(); } - -XLA_TEST_F(NumericTest, GetMatrixDiagonal_S64) { TestMatrixDiagonal(); } - -XLA_TEST_F(NumericTest, GetMatrixDiagonal_F32) { TestMatrixDiagonal(); } - -} // namespace -} // namespace xla diff --git a/tensorflow/compiler/tf2xla/lib/qr.cc b/tensorflow/compiler/xla/client/lib/qr.cc similarity index 55% rename from tensorflow/compiler/tf2xla/lib/qr.cc rename to tensorflow/compiler/xla/client/lib/qr.cc index 6b3f2b6e065b5c99e2d0248237369ecc30188aa5..640412ec8bcffd2565b11ba25b87f6bf6438d848 100644 --- a/tensorflow/compiler/tf2xla/lib/qr.cc +++ 
b/tensorflow/compiler/xla/client/lib/qr.cc @@ -13,18 +13,17 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/tf2xla/lib/qr.h" +#include "tensorflow/compiler/xla/client/lib/qr.h" #include #include -#include "tensorflow/compiler/tf2xla/lib/batch_dot.h" -#include "tensorflow/compiler/tf2xla/lib/util.h" -#include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/loops.h" #include "tensorflow/compiler/xla/client/lib/math.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" +#include "tensorflow/compiler/xla/client/lib/slicing.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -32,10 +31,18 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/core/errors.h" -namespace tensorflow { +namespace xla { namespace { +std::vector ConcatVectors(absl::Span xs, + absl::Span ys) { + std::vector output(xs.size() + ys.size()); + std::copy(xs.begin(), xs.end(), output.begin()); + std::copy(ys.begin(), ys.end(), output.begin() + xs.size()); + return output; +} + // Computes a Householder reflection of the form: // H = I - tau v v.T. // such that @@ -65,52 +72,47 @@ namespace { // return (v, tau, beta) // TODO(phawkins): LAPACK's xLARFG implementation has code for handling // overflows in the norm/beta calculations. Perhaps do the same here. 
-xla::Status House(xla::XlaOp x, xla::XlaOp k, - absl::Span batch_dims, const int64 m, - xla::XlaOp* v, xla::XlaOp* tau, xla::XlaOp* beta) { - xla::XlaBuilder* const builder = x.builder(); - TF_ASSIGN_OR_RETURN(xla::Shape x_shape, builder->GetShape(x)); - const xla::PrimitiveType type = x_shape.element_type(); +Status House(XlaOp x, XlaOp k, absl::Span batch_dims, + const int64 m, XlaOp* v, XlaOp* tau, XlaOp* beta) { + XlaBuilder* const builder = x.builder(); + TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x)); + const PrimitiveType type = x_shape.element_type(); std::vector batch_dim_ids(batch_dims.size()); std::iota(batch_dim_ids.begin(), batch_dim_ids.end(), 0); const int64 minor_dim = batch_dims.size(); - xla::XlaOp zero = xla::ScalarLike(x, 0.0); - xla::XlaOp one = xla::ScalarLike(x, 1.0); + XlaOp zero = ScalarLike(x, 0.0); + XlaOp one = ScalarLike(x, 1.0); // alpha = x[k] - xla::XlaOp alpha = - xla::Reshape(DynamicSliceInMinorDims(x, {k}, {1}), batch_dims); + XlaOp alpha = Reshape(DynamicSliceInMinorDims(x, {k}, {1}), batch_dims); // Compute x[k+1:] (padded with zeros in elements 0..k) - xla::XlaOp iota = xla::Iota(builder, xla::S32, m); - xla::XlaOp x_after_k = - xla::Mul(x, xla::ConvertElementType(xla::Gt(iota, k), type), - /*broadcast_dimensions=*/{minor_dim}); + XlaOp iota = Iota(builder, S32, m); + XlaOp x_after_k = Mul(x, ConvertElementType(Gt(iota, k), type), + /*broadcast_dimensions=*/{minor_dim}); // sigma = np.dot(x[k+1:], x[k+1:]) - auto sigma = - xla::Reduce(x_after_k * x_after_k, zero, - xla::CreateScalarAddComputation(type, builder), {minor_dim}); + auto sigma = Reduce(x_after_k * x_after_k, zero, + CreateScalarAddComputation(type, builder), {minor_dim}); // mu = np.sqrt(x[k]*x[k] + sigma) - auto mu = xla::Sqrt(xla::Square(alpha) + sigma); + auto mu = Sqrt(Square(alpha) + sigma); - auto sigma_is_zero = xla::Eq(sigma, zero); + auto sigma_is_zero = Eq(sigma, zero); - *beta = xla::Select(sigma_is_zero, alpha, -xla::Sign(alpha) * mu); - *tau 
= xla::Select(sigma_is_zero, xla::Broadcast(zero, batch_dims), - (*beta - alpha) / *beta); - auto divisor = xla::Select(sigma_is_zero, xla::Broadcast(one, batch_dims), - alpha - *beta); + *beta = Select(sigma_is_zero, alpha, -Sign(alpha) * mu); + *tau = Select(sigma_is_zero, Broadcast(zero, batch_dims), + (*beta - alpha) / *beta); + auto divisor = + Select(sigma_is_zero, Broadcast(one, batch_dims), alpha - *beta); - auto e_k = xla::Broadcast(xla::ConvertElementType(xla::Eq(iota, k), type), - std::vector(batch_dims.size(), 1)); + auto e_k = Broadcast(ConvertElementType(Eq(iota, k), type), + std::vector(batch_dims.size(), 1)); // Form v as [0, 0, ..., 1] ++ x[k+1:] / divisor // If sigma is zero, x[k+1:] is zero, so use any non-zero divisor. - *v = e_k + - xla::Div(x_after_k, divisor, /*broadcast_dimensions=*/batch_dim_ids); + *v = e_k + Div(x_after_k, divisor, /*broadcast_dimensions=*/batch_dim_ids); return Status::OK(); } @@ -143,94 +145,86 @@ xla::Status House(xla::XlaOp x, xla::XlaOp k, // return (q, vs, taus) struct QRBlockResult { // The factored R value - xla::XlaOp r; + XlaOp r; // Representation of the Householder matrices I - beta v v.T - xla::XlaOp taus; // Shape: [..., n] - xla::XlaOp vs; // Shape: [..., m, n] + XlaOp taus; // Shape: [..., n] + XlaOp vs; // Shape: [..., m, n] }; -xla::StatusOr QRBlock( - xla::XlaOp a, xla::PrecisionConfig::Precision precision) { - xla::XlaBuilder* builder = a.builder(); - TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); - const int num_dims = xla::ShapeUtil::Rank(a_shape); +StatusOr QRBlock(XlaOp a, PrecisionConfig::Precision precision) { + XlaBuilder* builder = a.builder(); + TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); + const int num_dims = a_shape.rank(); if (num_dims < 2) { - return errors::InvalidArgument("Arguments to QR must have rank >= 2: ", - num_dims); + return InvalidArgument("Argument to QR must have rank >= 2; got shape %s", + a_shape.ToString()); } - xla::PrimitiveType type = 
a_shape.element_type(); + PrimitiveType type = a_shape.element_type(); - const int64 m = xla::ShapeUtil::GetDimension(a_shape, -2); - const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1); + const int64 m = ShapeUtil::GetDimension(a_shape, -2); + const int64 n = ShapeUtil::GetDimension(a_shape, -1); const int64 num_batch_dims = num_dims - 2; std::vector batch_dims(num_batch_dims); for (int i = 0; i < num_batch_dims; ++i) { - batch_dims[i] = xla::ShapeUtil::GetDimension(a_shape, i); + batch_dims[i] = ShapeUtil::GetDimension(a_shape, i); } std::vector batch_dim_indices(num_batch_dims); std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0); - auto qr_body_fn = - [&](xla::XlaOp j, absl::Span values, - xla::XlaBuilder* builder) -> xla::StatusOr> { + auto qr_body_fn = [&](XlaOp j, absl::Span values, + XlaBuilder* builder) -> StatusOr> { auto a = values[0]; auto vs = values[1]; auto taus = values[2]; // v, beta = house(a[:, j], j) auto x = DynamicSliceInMinorDims(a, {j}, {1}); - xla::XlaOp v, tau, beta; - TF_RETURN_IF_ERROR(House(xla::Collapse(x, {num_dims - 2, num_dims - 1}), j, + XlaOp v, tau, beta; + TF_RETURN_IF_ERROR(House(Collapse(x, {num_dims - 2, num_dims - 1}), j, batch_dims, m, &v, &tau, &beta)); std::vector shape = batch_dims; shape.push_back(1); shape.push_back(m); - auto v_broadcast = xla::Reshape(v, shape); + auto v_broadcast = Reshape(v, shape); // a[:, :] -= tau * np.dot(v[:, np.newaxis], // np.dot(v[np.newaxis, :], a[:, :])) - auto vva = - BatchDot(v_broadcast, a, /*transpose_x=*/false, /*transpose_y=*/false, - /*conjugate_x=*/false, /*conjugate_y=*/false, precision); - vva = - BatchDot(v_broadcast, vva, /*transpose_x=*/true, /*transpose_y=*/false, - /*conjugate_x=*/false, /*conjugate_y=*/false, precision); - a = a - xla::Mul(tau, vva, - /*broadcast_dimensions=*/batch_dim_indices); + auto vva = BatchDot(v_broadcast, a, precision); + vva = BatchDot(TransposeInMinorDims(v_broadcast), vva, precision); + a = a - Mul(tau, vva, + 
/*broadcast_dimensions=*/batch_dim_indices); // It is more precise to populate column 'k' explicitly, rather than // computing it implicitly by applying the Householder transformation. // a[k,k] = beta // a[k+1:,k] = np.zeros([m-k-1], dtype=a.dtype) - auto iota = xla::Reshape(xla::Iota(a.builder(), xla::S32, m), {m, 1}); - auto predecessor_mask = xla::ConvertElementType(xla::Lt(iota, j), type); - auto mask = xla::Broadcast(xla::ConvertElementType(xla::Eq(iota, j), type), - std::vector(batch_dims.size(), 1)); - auto new_x = - xla::Mul(x, predecessor_mask, - /*broadcast_dimensions=*/{num_dims - 2, num_dims - 1}) + - xla::Mul(beta, mask, /*broadcast_dimensions=*/batch_dim_indices); + auto iota = Reshape(Iota(a.builder(), S32, m), {m, 1}); + auto predecessor_mask = ConvertElementType(Lt(iota, j), type); + auto mask = Broadcast(ConvertElementType(Eq(iota, j), type), + std::vector(batch_dims.size(), 1)); + auto new_x = Mul(x, predecessor_mask, + /*broadcast_dimensions=*/{num_dims - 2, num_dims - 1}) + + Mul(beta, mask, /*broadcast_dimensions=*/batch_dim_indices); a = DynamicUpdateSliceInMinorDims(a, new_x, {j}); // vs[:, j] = v vs = DynamicUpdateSliceInMinorDims( - vs, xla::Reshape(v, ConcatVectors(batch_dims, {m, 1})), {j}); + vs, Reshape(v, ConcatVectors(batch_dims, {m, 1})), {j}); // taus[j] = tau taus = DynamicUpdateSliceInMinorDims( - taus, xla::Reshape(tau, ConcatVectors(batch_dims, {1})), {j}); - return std::vector{a, vs, taus}; + taus, Reshape(tau, ConcatVectors(batch_dims, {1})), {j}); + return std::vector{a, vs, taus}; }; - auto vs = xla::Zeros(builder, xla::ShapeUtil::MakeShape( - type, ConcatVectors(batch_dims, {m, n}))); - auto taus = xla::Zeros( - builder, xla::ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {n}))); + auto vs = Zeros( + builder, ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {m, n}))); + auto taus = Zeros(builder, + ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {n}))); - TF_ASSIGN_OR_RETURN(auto values, - 
XlaForEachIndex(std::min(m, n), xla::S32, qr_body_fn, - {a, vs, taus}, "qr", builder)); + TF_ASSIGN_OR_RETURN(auto values, ForEachIndex(std::min(m, n), S32, qr_body_fn, + {a, vs, taus}, "qr", builder)); QRBlockResult result; result.r = values[0]; @@ -254,62 +248,58 @@ xla::StatusOr QRBlock( // return W // There is no need to return Y since at termination of the loop it is equal to // vs. -xla::StatusOr ComputeWYRepresentation( - xla::PrimitiveType type, absl::Span batch_dims, xla::XlaOp vs, - xla::XlaOp taus, int64 m, int64 n, - xla::PrecisionConfig::Precision precision) { +StatusOr ComputeWYRepresentation(PrimitiveType type, + absl::Span batch_dims, + XlaOp vs, XlaOp taus, int64 m, int64 n, + PrecisionConfig::Precision precision) { std::vector batch_dim_indices(batch_dims.size()); std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0); int64 n_index = batch_dims.size() + 1; - auto body_fn = - [&](xla::XlaOp j, absl::Span values, - xla::XlaBuilder* builder) -> xla::StatusOr> { + auto body_fn = [&](XlaOp j, absl::Span values, + XlaBuilder* builder) -> StatusOr> { auto w = values[0]; auto y = values[1]; const auto vs = values[2]; const auto taus = values[3]; // Want j values in range [1, ... n). 
- j = j + xla::ConstantR0(builder, 1); + j = j + ConstantR0(builder, 1); // vs has shape [..., m, 1] auto v = DynamicSliceInMinorDims(vs, {j}, {1}); // beta has shape [..., 1] auto beta = DynamicSliceInMinorDims(taus, {j}, {1}); // yv has shape [..., n, 1] - auto yv = BatchDot(y, v, /*transpose_x=*/true, /*transpose_y=*/false, - /*conjugate_x=*/false, /*conjugate_y=*/false, precision); + auto yv = BatchDot(TransposeInMinorDims(y), v, precision); // wyv has shape [..., m, 1] - auto wyv = - BatchDot(w, yv, /*transpose_x=*/false, /*transpose_y=*/false, - /*conjugate_x=*/false, /*conjugate_y=*/false, precision); + auto wyv = BatchDot(w, yv, precision); - auto z = xla::Mul( + auto z = Mul( -beta, v + wyv, /*broadcast_dimensions=*/ConcatVectors(batch_dim_indices, {n_index})); w = DynamicUpdateSliceInMinorDims(w, z, {j}); y = DynamicUpdateSliceInMinorDims(y, v, {j}); - return std::vector{w, y, vs, taus}; + return std::vector{w, y, vs, taus}; }; - xla::XlaBuilder* builder = vs.builder(); - auto w = xla::Zeros(builder, xla::ShapeUtil::MakeShape( - type, ConcatVectors(batch_dims, {m, n}))); + XlaBuilder* builder = vs.builder(); + auto w = Zeros(builder, + ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {m, n}))); auto y = w; auto v = SliceInMinorDims(vs, {0}, {1}); auto beta = SliceInMinorDims(taus, {0}, {1}); y = UpdateSliceInMinorDims(y, v, {0}); - auto bv = xla::Mul( - -beta, v, - /*broadcast_dimensions=*/ConcatVectors(batch_dim_indices, {n_index})); + auto bv = + Mul(-beta, v, + /*broadcast_dimensions=*/ConcatVectors(batch_dim_indices, {n_index})); w = UpdateSliceInMinorDims(w, bv, {0}); TF_ASSIGN_OR_RETURN( - auto values, XlaForEachIndex(n - 1, xla::S32, body_fn, {w, y, vs, taus}, - "wy", builder)); + auto values, + ForEachIndex(n - 1, S32, body_fn, {w, y, vs, taus}, "wy", builder)); return values[0]; } @@ -330,34 +320,34 @@ xla::StatusOr ComputeWYRepresentation( // return (q, a) // TODO(phawkins): consider using UT transformations (in the form I - V U V') // 
rather than WY transformations. -xla::StatusOr QRDecomposition( - xla::XlaOp a, bool full_matrices, int64 block_size, - xla::PrecisionConfig::Precision precision) { - xla::XlaBuilder* builder = a.builder(); - TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); - const int num_dims = xla::ShapeUtil::Rank(a_shape); +StatusOr QRDecomposition( + XlaOp a, bool full_matrices, int64 block_size, + PrecisionConfig::Precision precision) { + XlaBuilder* builder = a.builder(); + TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); + const int num_dims = a_shape.rank(); if (num_dims < 2) { - return errors::InvalidArgument("Arguments to QR must have rank >= 2: ", - num_dims); + return InvalidArgument("Arguments to QR must have rank >= 2: got shape %s", + a_shape.ToString()); } - xla::PrimitiveType type = a_shape.element_type(); + PrimitiveType type = a_shape.element_type(); - const int64 m = xla::ShapeUtil::GetDimension(a_shape, -2); - const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1); + const int64 m = ShapeUtil::GetDimension(a_shape, -2); + const int64 n = ShapeUtil::GetDimension(a_shape, -1); const int64 p = std::min(m, n); if (block_size < 1) { - return errors::InvalidArgument( - "block_size argument to QR must be >= 1; got ", block_size); + return InvalidArgument("block_size argument to QR must be >= 1; got %d", + block_size); } const int64 num_batch_dims = num_dims - 2; std::vector batch_dims(num_batch_dims); for (int i = 0; i < num_batch_dims; ++i) { - batch_dims[i] = xla::ShapeUtil::GetDimension(a_shape, i); + batch_dims[i] = ShapeUtil::GetDimension(a_shape, i); } - auto q = xla::Broadcast(xla::IdentityMatrix(builder, type, m, m), batch_dims); + auto q = Broadcast(IdentityMatrix(builder, type, m, m), batch_dims); for (int64 i = 0; i < p; i += block_size) { int64 k = std::min(block_size, p - i); @@ -375,23 +365,15 @@ xla::StatusOr QRDecomposition( // a[i:, i+k:] += np.dot(Y, np.dot(W.T, a[i:, i+k:])) auto a_panel = SliceInMinorDims(a, {i, i + k}, 
{m, n}); - auto a_update = - BatchDot(w, a_panel, /*transpose_x=*/true, /*transpose_y=*/false, - /*conjugate_x=*/false, /*conjugate_y=*/false, precision); - a_update = - BatchDot(y, a_update, /*transpose_x=*/false, /*transpose_y=*/false, - /*conjugate_x=*/false, /*conjugate_y=*/false, precision); + auto a_update = BatchDot(TransposeInMinorDims(w), a_panel, precision); + a_update = BatchDot(y, a_update, precision); a_panel = a_panel + a_update; a = UpdateSliceInMinorDims(a, a_panel, {i, i + k}); // q[:, i:] += np.dot(np.dot(q[:, i:], W), Y.T)) auto q_panel = SliceInMinorDims(q, {0, i}, {m, m}); - auto q_update = - BatchDot(q_panel, w, /*transpose_x=*/false, /*transpose_y=*/false, - /*conjugate_x=*/false, /*conjugate_y=*/false, precision); - q_update = BatchDot(q_update, y, /*transpose_x=*/false, - /*transpose_y=*/true, /*conjugate_x=*/false, - /*conjugate_y=*/false, precision); + auto q_update = BatchDot(q_panel, w, precision); + q_update = BatchDot(q_update, TransposeInMinorDims(y), precision); q_panel = q_panel + q_update; q = UpdateSliceInMinorDims(q, q_panel, {0, i}); } @@ -408,4 +390,4 @@ xla::StatusOr QRDecomposition( return result; } -} // namespace tensorflow +} // namespace xla diff --git a/tensorflow/compiler/tf2xla/lib/qr.h b/tensorflow/compiler/xla/client/lib/qr.h similarity index 74% rename from tensorflow/compiler/tf2xla/lib/qr.h rename to tensorflow/compiler/xla/client/lib/qr.h index 24b537ac8b63b93e734c3d0e335ea455f7d51a54..827c8eeca05ef09a0d77363eb3c40961b95813d8 100644 --- a/tensorflow/compiler/tf2xla/lib/qr.h +++ b/tensorflow/compiler/xla/client/lib/qr.h @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_QR_H_ -#define TENSORFLOW_COMPILER_TF2XLA_LIB_QR_H_ +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_QR_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_QR_H_ #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -namespace tensorflow { +namespace xla { // Computes the QR decompositions of a batch of matrices. That is, // given a (batched) matrix a, computes an orthonormal matrix Q and an @@ -29,14 +29,14 @@ namespace tensorflow { // the block size to use. // TODO(phawkins): handle the complex case. struct QRDecompositionResult { - xla::XlaOp q; - xla::XlaOp r; + XlaOp q; + XlaOp r; }; -xla::StatusOr QRDecomposition( - xla::XlaOp a, bool full_matrices, int64 block_size = 128, - xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::HIGHEST); +StatusOr QRDecomposition( + XlaOp a, bool full_matrices, int64 block_size = 128, + PrecisionConfig::Precision precision = PrecisionConfig::HIGHEST); -} // namespace tensorflow +} // namespace xla -#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_QR_H_ +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_QR_H_ diff --git a/tensorflow/compiler/xla/client/lib/qr_test.cc b/tensorflow/compiler/xla/client/lib/qr_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..b27d364b62444d6d5fb1278b6e6461affc15b2e6 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/qr_test.cc @@ -0,0 +1,93 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/qr.h" + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/array3d.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace { + +using QrTest = xla::ClientLibraryTestBase; + +XLA_TEST_F(QrTest, Simple) { + xla::XlaBuilder builder(TestName()); + + xla::Array2D a_vals({ + {4, 6, 8, 10}, + {6, 45, 54, 63}, + {8, 54, 146, 166}, + {10, 63, 166, 310}, + }); + + xla::XlaOp a; + auto a_data = CreateR2Parameter(a_vals, 0, "a", &builder, &a); + TF_ASSERT_OK_AND_ASSIGN( + auto result, + xla::QRDecomposition(a, /*full_matrices=*/true, /*block_size=*/2)); + + // Verifies that the decomposition composes back to the original matrix. + // + // This isn't a terribly demanding test, (e.g., we should verify that Q is + // orthonormal and R is upper-triangular) but it's awkward to write such tests + // without more linear algebra libraries. It's easier to test the numerics + // from Python, anyway, where we have access to numpy and scipy. 
+ xla::BatchDot(result.q, result.r, xla::PrecisionConfig::HIGHEST); + + ComputeAndCompareR2(&builder, a_vals, {a_data.get()}, + xla::ErrorSpec(1e-4, 1e-4)); +} + +XLA_TEST_F(QrTest, SimpleBatched) { + xla::XlaBuilder builder(TestName()); + + xla::Array3D a_vals({ + { + {4, 6, 8, 10}, + {6, 45, 54, 63}, + {8, 54, 146, 166}, + {10, 63, 166, 310}, + }, + { + {16, 24, 8, 12}, + {24, 61, 82, 48}, + {8, 82, 456, 106}, + {12, 48, 106, 62}, + }, + }); + + xla::XlaOp a; + auto a_data = CreateR3Parameter(a_vals, 0, "a", &builder, &a); + TF_ASSERT_OK_AND_ASSIGN( + auto result, + xla::QRDecomposition(a, /*full_matrices=*/true, /*block_size=*/2)); + + xla::BatchDot(result.q, result.r, xla::PrecisionConfig::HIGHEST); + + ComputeAndCompareR3(&builder, a_vals, {a_data.get()}, + xla::ErrorSpec(1e-4, 1e-4)); +} + +} // namespace diff --git a/tensorflow/compiler/xla/client/lib/quantize.h b/tensorflow/compiler/xla/client/lib/quantize.h new file mode 100644 index 0000000000000000000000000000000000000000..26dbbd5b00bd1a29f4047c9a4294fcac7340cf6c --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/quantize.h @@ -0,0 +1,186 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_QUANTIZE_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_QUANTIZE_H_ + +#include +#include +#include + +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/bfloat16/bfloat16.h" + +namespace xla { + +constexpr int64 kBitsOfByte = 8; + +// Represents the range used for quantization +struct QuantizedRange { + QuantizedRange() = default; + QuantizedRange(float min_in, float max_in) : min(min_in), max(max_in) {} + + bool operator==(const QuantizedRange& rhs) const { + return this->min == rhs.min && this->max == rhs.max; + } + + bool operator!=(const QuantizedRange& rhs) const { return !(*this == rhs); } + + tensorflow::bfloat16 min = tensorflow::bfloat16(0.0f); + tensorflow::bfloat16 max = tensorflow::bfloat16(0.0f); +}; + +template +inline std::vector PackToUint32(absl::Span input) { + const int64 kElementsPerPack = sizeof(uint32) / sizeof(T); + const int64 input_size = input.size(); + const int64 output_size = CeilOfRatio(input_size, kElementsPerPack); + + std::vector output_vec; + constexpr int64 kShiftBits = sizeof(T) / sizeof(uint8) * kBitsOfByte; + + for (int64 i = 0; i < output_size; i++) { + uint32 result = 0; + for (int64 p = 0; p < kElementsPerPack; p++) { + int64 index = i * kElementsPerPack + p; + if (index < input_size) { + int64 total_shift_bits = kShiftBits * (kElementsPerPack - p - 1); + result |= (input[index] << total_shift_bits); + } + } + output_vec.push_back(result); + } + + return output_vec; +} + +// Dequantize the quantized input of packed uint32 to bfloat16. +// Only uint8 or uint16 is supported for the original unpacked input. 
+// Returns a tensor of shape [d0,..., dn * unpack_size] if +// input shape is [d0, ..., dn], where unpack_size = sizeof(uint32) / sizeof(T). +// If transpose_output is true, will return a tensor of shape +// [dn * unpack_size, dn-1, ..., d1, d0]. transpose_output is faster when +// input's rank is higher than 1. The input needs to be transposed to use +// the transpose_output feature. +template +inline XlaOp Dequantize(XlaOp input, const QuantizedRange& range, + absl::string_view mode_string = "MIN_COMBINED", + bool transpose_output = false) { + XlaBuilder* const builder = input.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + float half_range = + !std::is_signed::value + ? 0.0f + : (static_cast(std::numeric_limits::max()) - + std::numeric_limits::min() + 1) / + 2.0f; + const int64 unpack_size = sizeof(uint32) / sizeof(T); + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(input)); + + auto element_type = shape.element_type(); + if (element_type != U32) { + return InvalidArgument( + "Only U32 is supported for input type of xla::Dequantize Op."); + } + + // Broadcast the input to [unpack_size, d0, ..., dn] if input size is + // [d0, ..., dn]. + auto broadcast_input = Broadcast(input, {unpack_size}); + + XlaOp iota_r1 = Iota(builder, U32, unpack_size); + // Higher significant bytes need to shift more bytes than lower + // significant bytes. + XlaOp shift_bytes = + xla::ConstantR0(builder, unpack_size - 1) - iota_r1; + + const int bytes_of_type = sizeof(T) / sizeof(uint8); + std::vector shift_vec(unpack_size, kBitsOfByte * bytes_of_type); + XlaOp shift_bits = + shift_bytes * xla::ConstantR1(builder, shift_vec); + + // Make bit_mask for different data type T. 
+ uint32 bit_mask = 0x00000000; + for (int i = 0; i < bytes_of_type; i++) { + bit_mask <<= kBitsOfByte; + bit_mask |= 0x000000ff; + } + + std::vector shift_transpose_dimensions(shape.dimensions_size()); + std::iota(shift_transpose_dimensions.begin(), + shift_transpose_dimensions.end(), 0); + shift_transpose_dimensions.insert(shift_transpose_dimensions.begin(), 1, + shape.dimensions_size()); + + // Shift the input by sizeof(T) bytes and apply bit_mask to unpack. + XlaOp shifted_input = ShiftRightLogical( + broadcast_input, Transpose(Broadcast(shift_bits, shape.dimensions()), + shift_transpose_dimensions)); + XlaOp unpack_input = + And(shifted_input, xla::ConstantR0(builder, bit_mask)); + + XlaOp result; + + if (mode_string == "MIN_COMBINED") { + const tensorflow::bfloat16 scale_factor = + (range.max - range.min) / + (static_cast(std::numeric_limits::max() - + std::numeric_limits::min())); + // result = bfloat16(input + half_range) * scale_factor + range.min + XlaOp unpack_input_bf16 = ConvertElementType(unpack_input, BF16); + XlaOp half_range_bf16 = xla::ConstantR0( + builder, static_cast(half_range)); + XlaOp sum = unpack_input_bf16 + half_range_bf16; + + result = + sum * xla::ConstantR0(builder, scale_factor) + + xla::ConstantR0(builder, range.min); + } else { + // TODO(wangtao): support other modes. + return InvalidArgument( + "Only MIN_COMBINED mode is supported in xla::Dequantize Op."); + } + + std::vector transpose_dimensions(shape.dimensions_size()); + std::iota(transpose_dimensions.begin(), transpose_dimensions.end(), 1); + std::reverse(transpose_dimensions.begin(), transpose_dimensions.end()); + transpose_dimensions.insert(transpose_dimensions.begin() + 1, 1, 0); + + // Transpose the result to be [dn, unpack_size, dn-1, ..., d1, d0]. + XlaOp transposed_result = Transpose(result, transpose_dimensions); + + // Reshape to be [dn * unpack_size, dn-1, ..., d1, d0]. 
+ XlaOp reshaped_result = Collapse(transposed_result, {0, 1}); + + // Return the transpose result if transpose_output is true. + if (transpose_output) { + return reshaped_result; + } + + // Transpose the result to be [d0, d1, ..., dn-1, dn * unpack_size]. + std::vector result_dimensions(shape.dimensions_size()); + std::iota(result_dimensions.begin(), result_dimensions.end(), 0); + std::reverse(result_dimensions.begin(), result_dimensions.end()); + + return Transpose(reshaped_result, result_dimensions); + }); +} + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_QUANTIZE_H_ diff --git a/tensorflow/compiler/xla/client/lib/quantize_test.cc b/tensorflow/compiler/xla/client/lib/quantize_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..be3603d9e11670913c21a834d2216a999306d582 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/quantize_test.cc @@ -0,0 +1,337 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/quantize.h" + +#include + +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" + +namespace xla { +namespace { + +using bfloat16 = tensorflow::bfloat16; + +template +std::vector GenerateInput() { + std::vector input; + + for (int64 i = std::numeric_limits::min(); + i < std::numeric_limits::max(); ++i) { + input.push_back(static_cast(i)); + } + + return input; +} + +template +Array2D GenerateLargeSizeInput(int num_columns, int num_rows) { + Array2D input(num_columns, num_rows); + + input.FillRandom(6, 128); + + return input; +} + +template +Array2D PackLargeInput(Array2D &input) { + const int64 size_per_pack = sizeof(uint32) / sizeof(NativeT); + int64 width = input.width(); + + int64 padded_output_width = CeilOfRatio(width, size_per_pack); + + Array2D pack_input(input.height(), padded_output_width); + + for (int h = 0; h < input.height(); h++) { + std::vector input_row; + for (int w = 0; w < width; w++) { + input_row.push_back(input({h, w})); + } + + auto pack_input_vec = PackToUint32(input_row); + + for (int w = 0; w < padded_output_width; w++) { + pack_input(h, w) = pack_input_vec[w]; + } + } + + return pack_input; +} + +template +Array2D GenerateLargeSizeMinCombinedOutput( + Array2D &input, const QuantizedRange &range, + bool transpose_output = false) { + const int64 size_per_pack = sizeof(uint32) / sizeof(NativeT); + int64 width = input.width(); + + int64 padded_output_width = CeilOfRatio(width, size_per_pack) * size_per_pack; + + int64 output_height; + int64 output_width; + + if (transpose_output) { + output_height = padded_output_width; + output_width = input.height(); + } else 
{ + output_height = input.height(); + output_width = padded_output_width; + } + + Array2D output(output_height, output_width, bfloat16(0.0)); + + float half_range = + !std::is_signed::value + ? 0.0f + : (static_cast(std::numeric_limits::max() - + std::numeric_limits::min() + 1)) / + 2.0f; + const bfloat16 scale_factor = + (range.max - range.min) / + (static_cast(std::numeric_limits::max() - + std::numeric_limits::min())); + + for (int h = 0; h < input.height(); h++) { + std::vector input_row; + for (int w = 0; w < width; w++) { + bfloat16 result = + static_cast(input(h, w) + half_range) * scale_factor + + range.min; + if (transpose_output) { + output(w, h) = result; + } else { + output(h, w) = result; + } + } + } + + return output; +} + +template +std::vector GenerateMinCombinedOutput(const QuantizedRange &range) { + float half_range = + !std::is_signed::value + ? 0.0f + : (static_cast(std::numeric_limits::max() - + std::numeric_limits::min() + 1)) / + 2.0f; + const bfloat16 scale_factor = + (range.max - range.min) / + (static_cast(std::numeric_limits::max() - + std::numeric_limits::min())); + std::vector output; + for (int64 i = std::numeric_limits::min(); + i < std::numeric_limits::max(); ++i) { + bfloat16 result = + static_cast(i + half_range) * scale_factor + range.min; + output.push_back(result); + } + + const int64 pack_size = sizeof(uint32) / sizeof(NativeT); + const int64 output_size = output.size(); + + int64 num_tailing_zeros = + CeilOfRatio(output_size, pack_size) * pack_size - output_size; + + output.insert(output.end(), num_tailing_zeros, bfloat16(0.0)); + return output; +} + +// TODO(wangtao): add a test to make sure this op is the inverse of the existing +// TF quantize op defined in: third_party/tensorflow/core/kernels/quantize_op.cc + +using DequantizeTest = ClientLibraryTestBase; + +TEST(PackTest, PackUint8ToUint32) { + std::vector input = {0xAB, 0x0B, 0x00, 0xF0, 0x01}; + auto output = PackToUint32(input); + EXPECT_THAT(output, 
::testing::ElementsAre(0xAB0B00F0, 0x01000000)); +} + +TEST(PackTest, PackInt8ToUint32) { + std::vector input = {static_cast(0x81), 0x0B, 0x00, 0x20, + 0x01}; + auto output = PackToUint32(input); + EXPECT_THAT(output, ::testing::ElementsAre(0x810B0020, 0x01000000)); +} + +TEST(PackTest, PackUint8ToUint32PerfectSize) { + std::vector input = {3, 2, 1, 0}; + auto output = PackToUint32(input); + EXPECT_THAT(output, ::testing::ElementsAre(0x03020100)); +} + +XLA_TEST_F(DequantizeTest, MinCombinedUint16R1) { + XlaBuilder builder(TestName()); + auto input = GenerateInput(); + auto x = ConstantR1(&builder, PackToUint32(input)); + QuantizedRange range(0, 255.0f); + xla::Dequantize(x, range, "MIN_COMBINED"); + auto expected = GenerateMinCombinedOutput(range); + ComputeAndCompareR1(&builder, expected, {}); +} + +XLA_TEST_F(DequantizeTest, MinCombinedUint8R1) { + XlaBuilder builder(TestName()); + auto input = GenerateInput(); + auto x = ConstantR1(&builder, PackToUint32(input)); + QuantizedRange range(0, 127.0f); + xla::Dequantize(x, range, "MIN_COMBINED"); + auto expected = GenerateMinCombinedOutput(range); + ComputeAndCompareR1(&builder, expected, {}); +} + +XLA_TEST_F(DequantizeTest, MinCombinedUint8R2) { + XlaBuilder builder(TestName()); + std::vector> input = { + {0, 1, 2, 3}, + {4, 5, 6, 7}, + {8, 9, 10, 11}, + {12, 13, 16, 15}, + }; + auto x = ConstantR2(&builder, {{PackToUint32(input[0])[0]}, + {PackToUint32(input[1])[0]}, + {PackToUint32(input[2])[0]}, + {PackToUint32(input[3])[0]}}); + QuantizedRange range(0, 255.0f); + xla::Dequantize(x, range, "MIN_COMBINED"); + const Array2D expected = { + {bfloat16(0.0), bfloat16(1.0), bfloat16(2.0), bfloat16(3.0)}, + {bfloat16(4.0), bfloat16(5.0), bfloat16(6.0), bfloat16(7.0)}, + {bfloat16(8.0), bfloat16(9.0), bfloat16(10.0), bfloat16(11.0)}, + {bfloat16(12.0), bfloat16(13.0), bfloat16(16.0), bfloat16(15.0)}, + }; + ComputeAndCompareR2(&builder, expected, {}); +} + +XLA_TEST_F(DequantizeTest, MinCombinedUint8R2TransposeOutput) { 
+ XlaBuilder builder(TestName()); + std::vector> input = { + {0, 1, 2, 3}, + {4, 5, 6, 7}, + {8, 9, 10, 11}, + {12, 13, 16, 15}, + }; + auto x = ConstantR2(&builder, {{PackToUint32(input[0])[0]}, + {PackToUint32(input[1])[0]}, + {PackToUint32(input[2])[0]}, + {PackToUint32(input[3])[0]}}); + QuantizedRange range(0, 255.0f); + xla::Dequantize(x, range, "MIN_COMBINED", /*transpose_output=*/true); + const Array2D expected = { + {bfloat16(0.0), bfloat16(4.0), bfloat16(8.0), bfloat16(12.0)}, + {bfloat16(1.0), bfloat16(5.0), bfloat16(9.0), bfloat16(13.0)}, + {bfloat16(2.0), bfloat16(6.0), bfloat16(10.0), bfloat16(16.0)}, + {bfloat16(3.0), bfloat16(7.0), bfloat16(11.0), bfloat16(15.0)}, + }; + ComputeAndCompareR2(&builder, expected, {}); +} + +XLA_TEST_F(DequantizeTest, MinCombinedUint8R2TailingZero) { + XlaBuilder builder(TestName()); + std::vector> input = { + {0, 1, 2, 3, 16}, + {4, 5, 6, 7, 17}, + {8, 9, 10, 11, 18}, + {12, 13, 16, 15, 19}, + }; + auto x = ConstantR2( + &builder, + {{PackToUint32(input[0])[0], PackToUint32(input[0])[1]}, + {PackToUint32(input[1])[0], PackToUint32(input[1])[1]}, + {PackToUint32(input[2])[0], PackToUint32(input[2])[1]}, + {PackToUint32(input[3])[0], PackToUint32(input[3])[1]}}); + QuantizedRange range(0, 255.0f); + xla::Dequantize(x, range, "MIN_COMBINED"); + + const Array2D expected = { + {bfloat16(0.0), bfloat16(1.0), bfloat16(2.0), bfloat16(3.0), + bfloat16(16.0), bfloat16(0.0), bfloat16(0.0), bfloat16(0.0)}, + {bfloat16(4.0), bfloat16(5.0), bfloat16(6.0), bfloat16(7.0), + bfloat16(17.0), bfloat16(0.0), bfloat16(0.0), bfloat16(0.0)}, + {bfloat16(8.0), bfloat16(9.0), bfloat16(10.0), bfloat16(11.0), + bfloat16(18.0), bfloat16(0.0), bfloat16(0.0), bfloat16(0.0)}, + {bfloat16(12.0), bfloat16(13.0), bfloat16(16.0), bfloat16(15.0), + bfloat16(19.0), bfloat16(0.0), bfloat16(0.0), bfloat16(0.0)}, + }; + ComputeAndCompareR2(&builder, expected, {}); +} + +XLA_TEST_F(DequantizeTest, MinCombinedUint8R2TailingZeroTransposeOutput) { + XlaBuilder 
builder(TestName()); + std::vector> input = { + {0, 1, 2, 3, 16}, + {4, 5, 6, 7, 17}, + {8, 9, 10, 11, 18}, + {12, 13, 16, 15, 19}, + }; + auto x = ConstantR2( + &builder, + {{PackToUint32(input[0])[0], PackToUint32(input[0])[1]}, + {PackToUint32(input[1])[0], PackToUint32(input[1])[1]}, + {PackToUint32(input[2])[0], PackToUint32(input[2])[1]}, + {PackToUint32(input[3])[0], PackToUint32(input[3])[1]}}); + QuantizedRange range(0, 255.0f); + xla::Dequantize(x, range, "MIN_COMBINED", /*transpose_output=*/true); + + const Array2D expected = { + {bfloat16(0.0), bfloat16(4.0), bfloat16(8.0), bfloat16(12.0)}, + {bfloat16(1.0), bfloat16(5.0), bfloat16(9.0), bfloat16(13.0)}, + {bfloat16(2.0), bfloat16(6.0), bfloat16(10.0), bfloat16(16.0)}, + {bfloat16(3.0), bfloat16(7.0), bfloat16(11.0), bfloat16(15.0)}, + {bfloat16(16.0), bfloat16(17.0), bfloat16(18.0), bfloat16(19.0)}, + {bfloat16(0.0), bfloat16(0.0), bfloat16(0.0), bfloat16(0.0)}, + {bfloat16(0.0), bfloat16(0.0), bfloat16(0.0), bfloat16(0.0)}, + {bfloat16(0.0), bfloat16(0.0), bfloat16(0.0), bfloat16(0.0)}, + }; + ComputeAndCompareR2(&builder, expected, {}); +} + +XLA_TEST_F(DequantizeTest, MinCombinedUint8LargeSizeTest) { + XlaBuilder builder(TestName()); + Array2D input = GenerateLargeSizeInput(500, 3547); + Array2D input_packed = PackLargeInput(input); + + auto x = ConstantR2FromArray2D(&builder, input_packed); + QuantizedRange range(0, 255.0f); + xla::Dequantize(x, range, "MIN_COMBINED"); + + const Array2D expected = + GenerateLargeSizeMinCombinedOutput(input, range); + ComputeAndCompareR2(&builder, expected, {}); +} + +XLA_TEST_F(DequantizeTest, MinCombinedUint8LargeSizeTestTransposeOutput) { + XlaBuilder builder(TestName()); + Array2D input = GenerateLargeSizeInput(500, 3547); + Array2D input_packed = PackLargeInput(input); + + auto x = ConstantR2FromArray2D(&builder, input_packed); + QuantizedRange range(0, 255.0f); + xla::Dequantize(x, range, "MIN_COMBINED", /*transpose_output=*/true); + + const Array2D expected = 
GenerateLargeSizeMinCombinedOutput( + input, range, /*transpose_output=*/true); + ComputeAndCompareR2(&builder, expected, {}); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/slicing.cc b/tensorflow/compiler/xla/client/lib/slicing.cc new file mode 100644 index 0000000000000000000000000000000000000000..77145ba7d4c72435450d3e33d57b2507eb84d2fc --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/slicing.cc @@ -0,0 +1,137 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/slicing.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" + +namespace xla { + +XlaOp SliceInMinorDims(XlaOp x, absl::Span start, + absl::Span end) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_RET_CHECK(start.size() == end.size()); + int64 n_minor_dims = start.size(); + + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); + + const int64 n_dims = shape.rank(); + TF_RET_CHECK(n_minor_dims <= n_dims); + auto major_dims = AsInt64Slice(shape.dimensions()) + .subspan( + /*pos=*/0, + /*len=*/n_dims - n_minor_dims); + + // Prepends 0s in the major dim + std::vector padded_start(n_dims, 0); + std::copy(start.begin(), start.end(), + padded_start.begin() + major_dims.size()); + + // Prepends the shape of the major dims. 
+ std::vector padded_end(n_dims); + std::copy(major_dims.begin(), major_dims.end(), padded_end.begin()); + std::copy(end.begin(), end.end(), padded_end.begin() + major_dims.size()); + + std::vector strides(n_dims, 1); + return Slice(x, padded_start, padded_end, strides); + }); +} + +XlaOp UpdateSlice(XlaOp x, XlaOp update, absl::Span start) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); + const int64 n_dims = shape.rank(); + TF_RET_CHECK(start.size() == n_dims); + + // TODO(phawkins): make int64 work on all backends, remove the int32 cast. + std::vector start_as_int32(start.begin(), start.end()); + std::vector start_ops(start.size()); + for (int i = 0; i < start.size(); ++i) { + start_ops[i] = ConstantR0(builder, start_as_int32[i]); + } + return DynamicUpdateSlice(x, update, start_ops); + }); +} + +XlaOp UpdateSliceInMinorDims(XlaOp x, XlaOp update, + absl::Span start) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); + const int64 n_dims = shape.rank(); + const int64 n_minor_dims = start.size(); + TF_RET_CHECK(n_minor_dims <= n_dims); + std::vector padded_start(n_dims, 0); + std::copy(start.begin(), start.end(), + padded_start.begin() + (n_dims - n_minor_dims)); + return UpdateSlice(x, update, padded_start); + }); +} + +namespace { + +std::vector ConcatVectors(absl::Span xs, + absl::Span ys) { + std::vector output(xs.size() + ys.size()); + std::copy(xs.begin(), xs.end(), output.begin()); + std::copy(ys.begin(), ys.end(), output.begin() + xs.size()); + return output; +} + +StatusOr> PrependZerosInMajorDims( + XlaOp x, absl::Span starts) { + XlaBuilder* builder = x.builder(); + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); + const int64 n_dims = shape.rank(); + auto zero = ConstantR0(builder, 0); + std::vector padded_starts(n_dims, zero); + for 
(int i = 0; i < starts.size(); ++i) { + padded_starts[n_dims - starts.size() + i] = starts[i]; + } + return padded_starts; +} + +} // namespace + +XlaOp DynamicSliceInMinorDims(XlaOp x, absl::Span starts, + absl::Span sizes) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); + const int64 n_dims = shape.rank(); + int64 n_minor_dims = starts.size(); + TF_RET_CHECK(n_minor_dims == sizes.size()); + TF_RET_CHECK(n_minor_dims <= n_dims); + auto major_dims = AsInt64Slice(shape.dimensions()) + .subspan( + /*pos=*/0, + /*len=*/n_dims - sizes.size()); + TF_ASSIGN_OR_RETURN(auto padded_starts, PrependZerosInMajorDims(x, starts)); + auto padded_sizes = ConcatVectors(major_dims, sizes); + return DynamicSlice(x, padded_starts, padded_sizes); + }); +} + +XlaOp DynamicUpdateSliceInMinorDims(XlaOp x, XlaOp update, + absl::Span starts) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(auto padded_starts, PrependZerosInMajorDims(x, starts)); + return DynamicUpdateSlice(x, update, padded_starts); + }); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/slicing.h b/tensorflow/compiler/xla/client/lib/slicing.h new file mode 100644 index 0000000000000000000000000000000000000000..6c482a38b5489c9fb17c3dca9ee3d2a1b8fd1890 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/slicing.h @@ -0,0 +1,48 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/types.h" + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SLICING_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SLICING_H_ + +namespace xla { + +// Updates a slice of 'x', i.e., +// x[start[0], ..., start[n]] = update +XlaOp UpdateSlice(XlaOp x, XlaOp update, absl::Span start); + +// Performs a slice in the minor dimensions of a tensor. +// x[..., start[0]:end[0], ..., start[n]:end[n]] +XlaOp SliceInMinorDims(XlaOp x, absl::Span start, + absl::Span end); + +// Updates a slice of 'x', where 'start' contains a list of minor dimensions: +// x[..., start[0]:..., ..., start[n]:...] = update +XlaOp UpdateSliceInMinorDims(XlaOp x, XlaOp update, + absl::Span start); + +// Performs a dynamic slice in the minor dimensions of a tensor. +XlaOp DynamicSliceInMinorDims(XlaOp x, absl::Span starts, + absl::Span sizes); + +XlaOp DynamicUpdateSliceInMinorDims(XlaOp x, XlaOp update, + absl::Span starts); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SLICING_H_ diff --git a/tensorflow/compiler/tf2xla/lib/util_test.cc b/tensorflow/compiler/xla/client/lib/slicing_test.cc similarity index 67% rename from tensorflow/compiler/tf2xla/lib/util_test.cc rename to tensorflow/compiler/xla/client/lib/slicing_test.cc index 442fe92c34ca26cb1a854cc90da8dc034bca79bb..8d362119e01006555db0f82d02626175936e1d05 100644 --- a/tensorflow/compiler/tf2xla/lib/util_test.cc +++ b/tensorflow/compiler/xla/client/lib/slicing_test.cc @@ -13,28 +13,19 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/compiler/tf2xla/lib/util.h" +#include "tensorflow/compiler/xla/client/lib/slicing.h" -#include -#include -#include - -#include "tensorflow/compiler/tf2xla/lib/batch_dot.h" -#include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" -#include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/status_test_util.h" -namespace tensorflow { +namespace xla { namespace { -using UtilTest = xla::ClientLibraryTestBase; -using UtilLeftLookingTest = xla::ClientLibraryTestBase; +using SlicingTest = xla::ClientLibraryTestBase; xla::Array2D BValsRight() { return {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}; @@ -63,7 +54,7 @@ xla::Array3D BatchedAValsFull() { }}; } -XLA_TEST_F(UtilTest, Simple2dLookup) { +XLA_TEST_F(SlicingTest, Simple2dLookup) { xla::XlaBuilder builder(TestName()); xla::XlaOp a, x, y; @@ -77,7 +68,7 @@ XLA_TEST_F(UtilTest, Simple2dLookup) { xla::ErrorSpec(1e-2, 1e-2)); } -XLA_TEST_F(UtilTest, Simple3dLookup) { +XLA_TEST_F(SlicingTest, Simple3dLookup) { xla::XlaBuilder builder(TestName()); xla::XlaOp a, index; @@ -92,7 +83,7 @@ XLA_TEST_F(UtilTest, Simple3dLookup) { {a_data.get(), index_data.get()}); } -XLA_TEST_F(UtilTest, SimpleSliceUpdate) { +XLA_TEST_F(SlicingTest, SimpleSliceUpdate) { xla::XlaBuilder builder(TestName()); xla::XlaOp a, b, x, y; @@ -111,26 +102,5 @@ XLA_TEST_F(UtilTest, SimpleSliceUpdate) { {a_data.get(), b_data.get(), x_data.get(), y_data.get()}); } -XLA_TEST_F(UtilTest, RowBatchDot) { - xla::XlaBuilder 
builder(TestName()); - - int n = 4; - - xla::XlaOp a, row, index; - auto a_data = - CreateR3Parameter(BatchedAValsFull(), 0, "a", &builder, &a); - auto row_data = CreateR3Parameter({{{9, 1, 0, 0}}, {{2, 4, 0, 0}}}, 1, - "row", &builder, &row); - // Select {{3, 6, 0, 1}, {24, 61, 82, 48}} out of BatchedAValsFull(). - auto index_data = CreateR0Parameter(1, 2, "index", &builder, &index); - - auto l_index = DynamicSliceInMinorDims( - a, {index, xla::ConstantR0(&builder, 0)}, {1, n}); - BatchDot(l_index, row, /*transpose_x=*/false, /*transpose_y=*/true); - - ComputeAndCompareR3(&builder, {{{33}}, {{292}}}, - {a_data.get(), row_data.get(), index_data.get()}); -} - } // namespace -} // namespace tensorflow +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/sorting_test.cc b/tensorflow/compiler/xla/client/lib/sorting_test.cc index 27ff36c7491ab8397d46f3a49493ff2b904deb2d..0fbd138aca1e86f219d0459086fc09d20844f135 100644 --- a/tensorflow/compiler/xla/client/lib/sorting_test.cc +++ b/tensorflow/compiler/xla/client/lib/sorting_test.cc @@ -77,7 +77,7 @@ XLA_TEST_F(SortingTest, TopKFullSort) { auto x = ConstantR1(&builder, inputs); xla::GetTupleElement(xla::TopK(x, kSize), 0); - std::sort(inputs.begin(), inputs.end(), std::greater()); + absl::c_sort(inputs, std::greater()); ComputeAndCompareR1(&builder, inputs, {}); } diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc index a95bbf2c8c860914877d3195b97342097dafc725..9f520bcdadfabc8ca9f9ee82b20804fd2c50d1db 100644 --- a/tensorflow/compiler/xla/client/lib/testing.cc +++ b/tensorflow/compiler/xla/client/lib/testing.cc @@ -34,7 +34,7 @@ namespace { // specified shape. In case of a (nested) tuple shape this is the total byte // size of all sub-shapes within the tuple. 
int64 DataSizeOfShape(const Shape& shape) { - if (ShapeUtil::IsArray(shape)) { + if (shape.IsArray()) { return ShapeUtil::ByteSizeOf(shape); } @@ -47,7 +47,7 @@ int64 DataSizeOfShape(const Shape& shape) { // Creates a XlaOp for an op what generates fake data with the given shape. XlaOp BuildFakeDataOpOnDevice(const Shape& shape, XlaBuilder* builder) { - if (ShapeUtil::IsArray(shape)) { + if (shape.IsArray()) { return Broadcast( ConstantLiteral(builder, LiteralUtil::One(shape.element_type())), AsInt64Slice(shape.dimensions())); @@ -59,22 +59,25 @@ XlaOp BuildFakeDataOpOnDevice(const Shape& shape, XlaBuilder* builder) { return Tuple(builder, parts); } -std::unique_ptr MakeFakeDataViaDeviceOrDie(const Shape& shape, - Client* client) { +std::unique_ptr MakeFakeDataViaDeviceOrDie( + const Shape& shape, Client* client, DebugOptions* debug_opts) { XlaBuilder b(absl::StrCat("make_fake_", ShapeUtil::HumanString(shape))); BuildFakeDataOpOnDevice(shape, &b); XlaComputation computation = b.Build().ConsumeValueOrDie(); auto execution_options = CreateDefaultExecutionOptions(); *execution_options.mutable_shape_with_output_layout() = shape.ToProto(); + if (debug_opts) { + *execution_options.mutable_debug_options() = *debug_opts; + } return client->Execute(computation, /*arguments=*/{}, &execution_options) .ConsumeValueOrDie(); } } // namespace -std::unique_ptr MakeFakeDataOrDie(const Shape& shape, - Client* client) { +std::unique_ptr MakeFakeDataOrDie( + const Shape& shape, Client* client, DebugOptions* debug_opts /*=nullptr*/) { if (DataSizeOfShape(shape) < (1LL << 20)) { StatusOr literal_status = MakeFakeLiteral(shape); if (!literal_status.ok()) { @@ -82,24 +85,25 @@ std::unique_ptr MakeFakeDataOrDie(const Shape& shape, // an on-device computation. 
CHECK_EQ(literal_status.status().code(), tensorflow::error::UNIMPLEMENTED); - return MakeFakeDataViaDeviceOrDie(shape, client); + return MakeFakeDataViaDeviceOrDie(shape, client, debug_opts); } return client->TransferToServer(literal_status.ValueOrDie()).ValueOrDie(); } // If the data is large, generate it on-device. - return MakeFakeDataViaDeviceOrDie(shape, client); + return MakeFakeDataViaDeviceOrDie(shape, client, debug_opts); } std::vector> MakeFakeArgumentsOrDie( - const XlaComputation& computation, Client* client) { + const XlaComputation& computation, Client* client, + DebugOptions* debug_opts /*=nullptr*/) { CHECK(computation.proto().has_host_program_shape()) << "Computation should have progran shape."; auto program_shape = computation.proto().host_program_shape(); std::vector> results; for (const ShapeProto& shape : program_shape.parameters()) { - results.push_back(MakeFakeDataOrDie(Shape(shape), client)); + results.push_back(MakeFakeDataOrDie(Shape(shape), client, debug_opts)); } return results; } diff --git a/tensorflow/compiler/xla/client/lib/testing.h b/tensorflow/compiler/xla/client/lib/testing.h index 03695ce2a339735e3e49522f4fe1bbf2d83a3834..428fa3e93d1b46983aae60176e7c2242d2552fdb 100644 --- a/tensorflow/compiler/xla/client/lib/testing.h +++ b/tensorflow/compiler/xla/client/lib/testing.h @@ -29,14 +29,19 @@ namespace xla { // Generates fake data of the given shape on the device or dies. The fake data // is created by performing a computation on the device rather than transferring // data from the host to the device. -std::unique_ptr MakeFakeDataOrDie(const Shape& shape, - Client* client); +// +// The optional DebugOptions are used when generating fake data on the device. +std::unique_ptr MakeFakeDataOrDie( + const Shape& shape, Client* client, DebugOptions* debug_opts = nullptr); // Returns vector of GlobalData handles of fake data (created using // MakeFakeDataOrDie) that are correctly shaped arguments for the given // xla computation. 
+// +// The optional DebugOptions are used when generating fake data on the device. std::vector> MakeFakeArgumentsOrDie( - const XlaComputation& computation, Client* client); + const XlaComputation& computation, Client* client, + DebugOptions* debug_opts = nullptr); } // namespace xla diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc b/tensorflow/compiler/xla/client/lib/triangular_solve.cc similarity index 60% rename from tensorflow/compiler/tf2xla/lib/triangular_solve.cc rename to tensorflow/compiler/xla/client/lib/triangular_solve.cc index 6524c2a9b1ada632d80edd234272760c2b545cc4..ba7fde118fde990fbb4aa9a34dd0f0e67ff5a93b 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc +++ b/tensorflow/compiler/xla/client/lib/triangular_solve.cc @@ -13,15 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/tf2xla/lib/triangular_solve.h" +#include "tensorflow/compiler/xla/client/lib/triangular_solve.h" #include #include -#include "tensorflow/compiler/tf2xla/lib/batch_dot.h" -#include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" +#include "tensorflow/compiler/xla/client/lib/slicing.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" @@ -29,21 +29,20 @@ limitations under the License. 
#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/math/math_util.h" -namespace tensorflow { +namespace xla { // Get the diagonal blocks of the coefficient matrix -xla::XlaOp DiagonalBlocks(xla::XlaOp a, int64 block_size) { - xla::XlaBuilder* builder = a.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(a)); - int ndims = xla::ShapeUtil::Rank(shape); - int64 n = xla::ShapeUtil::GetDimension(shape, -1); +XlaOp DiagonalBlocks(XlaOp a, int64 block_size) { + XlaBuilder* builder = a.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(a)); + int ndims = shape.rank(); + int64 n = ShapeUtil::GetDimension(shape, -1); int64 num_blocks = n / block_size; - xla::XlaOp diag_blocks; + XlaOp diag_blocks; // If the coefficient matrix is exactly the block size, we just add a // singleton dimension i.e. 
[..., n, n] -> [..., 1, n, n] @@ -58,20 +57,31 @@ xla::XlaOp DiagonalBlocks(xla::XlaOp a, int64 block_size) { if (n > block_size) { // Construct the starting indices of the diagonal blocks auto start_indices = - Transpose(Broadcast(Mul(Iota(builder, xla::S32, num_blocks), - xla::ConstantR0(builder, block_size)), + Transpose(Broadcast(Mul(Iota(builder, S32, num_blocks), + ConstantR0(builder, block_size)), /*broadcast_sizes=*/{2}), /*permutation=*/{1, 0}); + PaddingConfig padding_config = + MakeEdgePaddingConfig({{0, 0}, {ndims - 2, 0}}); + start_indices = + Pad(start_indices, ConstantR0(builder, 0), padding_config); + // Gather the diagonal blocks - xla::GatherDimensionNumbers dim_numbers; + std::vector slice_sizes(ndims); + GatherDimensionNumbers dim_numbers; + for (int i = 0; i < ndims - 2; ++i) { + dim_numbers.add_offset_dims(i); + dim_numbers.add_start_index_map(i); + slice_sizes[i] = ShapeUtil::GetDimension(shape, i); + } + slice_sizes[ndims - 2] = slice_sizes[ndims - 1] = block_size; dim_numbers.add_offset_dims(ndims - 1); dim_numbers.add_offset_dims(ndims); dim_numbers.add_start_index_map(ndims - 2); dim_numbers.add_start_index_map(ndims - 1); dim_numbers.set_index_vector_dim(1); - diag_blocks = Gather(a, start_indices, dim_numbers, - /*slice_sizes=*/{block_size, block_size}); + diag_blocks = Gather(a, start_indices, dim_numbers, slice_sizes); } // The last block might be smaller than the block size, @@ -80,7 +90,7 @@ xla::XlaOp DiagonalBlocks(xla::XlaOp a, int64 block_size) { // Pad with zeros auto last_blocks = SliceInMinorDims(a, {n - n % block_size, n - n % block_size}, {n, n}); - xla::PaddingConfig config = xla::MakeNoPaddingConfig(ndims); + PaddingConfig config = MakeNoPaddingConfig(ndims); int64 padding = block_size - n % block_size; config.mutable_dimensions(ndims - 1)->set_edge_padding_high(padding); config.mutable_dimensions(ndims - 2)->set_edge_padding_high(padding); @@ -89,9 +99,8 @@ xla::XlaOp DiagonalBlocks(xla::XlaOp a, int64 block_size) { // 
Add a singleton dimension // i.e. [..., block_size, block_size] -> [..., 1, block_size, block_size] - TF_ASSIGN_OR_RETURN(xla::Shape blocks_shape, - builder->GetShape(last_blocks)); - auto shape_dims = xla::AsInt64Slice(blocks_shape.dimensions()); + TF_ASSIGN_OR_RETURN(Shape blocks_shape, builder->GetShape(last_blocks)); + auto shape_dims = AsInt64Slice(blocks_shape.dimensions()); auto last_blocks_dims = std::vector(ndims); std::copy(shape_dims.begin(), shape_dims.end(), last_blocks_dims.begin()); last_blocks_dims.insert(last_blocks_dims.end() - 2, 1); @@ -100,7 +109,7 @@ xla::XlaOp DiagonalBlocks(xla::XlaOp a, int64 block_size) { // Concatenate with the other blocks if necessary if (n > block_size) { diag_blocks = - xla::ConcatInDim(builder, {diag_blocks, last_blocks}, ndims - 2); + ConcatInDim(builder, {diag_blocks, last_blocks}, ndims - 2); } else { diag_blocks = last_blocks; } @@ -110,16 +119,16 @@ xla::XlaOp DiagonalBlocks(xla::XlaOp a, int64 block_size) { }); } -xla::XlaOp InvertDiagonalBlocks(xla::XlaOp diag_blocks, bool lower, - bool transpose_a, bool conjugate_a, - xla::PrecisionConfig::Precision precision) { - xla::XlaBuilder* builder = diag_blocks.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { +XlaOp InvertDiagonalBlocks(XlaOp diag_blocks, bool lower, bool transpose_a, + bool conjugate_a, + PrecisionConfig::Precision precision) { + XlaBuilder* builder = diag_blocks.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { // Input is a batch of square lower triangular square matrices. Its shape is // (..., size, size). We resize this to (num_blocks, size, size). 
- TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(diag_blocks)); - int64 block_size = xla::ShapeUtil::GetDimension(shape, -1); - int64 num_blocks = xla::ShapeUtil::ElementsIn(shape) / + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(diag_blocks)); + int64 block_size = ShapeUtil::GetDimension(shape, -1); + int64 num_blocks = ShapeUtil::ElementsIn(shape) / tensorflow::MathUtil::IPow(block_size, 2); diag_blocks = Reshape(diag_blocks, {num_blocks, block_size, block_size}); @@ -131,9 +140,7 @@ xla::XlaOp InvertDiagonalBlocks(xla::XlaOp diag_blocks, bool lower, // zero (which can happen if the last block was padded) otherwise it will // introduce nans which will propagate auto diags = GetMatrixDiagonal(diag_blocks); - TF_ASSIGN_OR_RETURN(xla::Shape diags_shape, builder->GetShape(diags)); - auto one = ScalarLike(diags, 1); - auto ones = Broadcast(one, xla::AsInt64Slice(diags_shape.dimensions())); + auto ones = FullLike(diags, 1); diags = Select(Eq(diags, Zero(builder, shape.element_type())), ones, diags); auto scaled_diag_blocks = Div(diag_blocks, diags, {0, 2}); @@ -156,43 +163,43 @@ xla::XlaOp InvertDiagonalBlocks(xla::XlaOp diag_blocks, bool lower, // The first or last diagonal element should be set to 1 instead of -1 // though, since we never update it auto pos_one = Reshape(One(builder, shape.element_type()), {1, 1}); - auto start_index = (lower) ? 0 : block_size - 1; - auto output_block = DynamicUpdateSlice( - neg_identity, pos_one, - /*start_indices=*/xla::ConstantR1(builder, 2, start_index)); + auto start_index = ConstantR0(builder, (lower) ? 
0 : block_size - 1); + auto output_block = + DynamicUpdateSlice(neg_identity, pos_one, + /*start_indices=*/{start_index, start_index}); // Broadcast diag([1, -1, -1, ...]) to every block - xla::XlaOp output = Broadcast(output_block, - /*broadcast_sizes=*/{num_blocks}); + XlaOp output = Broadcast(output_block, + /*broadcast_sizes=*/{num_blocks}); // Now we construct a loop that performs matrix-vector multiplications // inverting the blocks one row at a time - std::vector tuple_shapes = { + std::vector tuple_shapes = { // The loop iteration counter is a scalar, incremented each iteration. - xla::ShapeUtil::MakeShape(xla::S32, {}), + ShapeUtil::MakeShape(S32, {}), // The output has the shape of A, with one row updated each iteration. - xla::ShapeUtil::MakeShape(shape.element_type(), - {num_blocks, block_size, block_size}), + ShapeUtil::MakeShape(shape.element_type(), + {num_blocks, block_size, block_size}), // The input is a loop invariant. - xla::ShapeUtil::MakeShape(shape.element_type(), - {num_blocks, block_size, block_size})}; - xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes); + ShapeUtil::MakeShape(shape.element_type(), + {num_blocks, block_size, block_size})}; + Shape tuple_shape = ShapeUtil::MakeTupleShape(tuple_shapes); - auto init_i = One(builder, xla::S32); - auto init = xla::Tuple(builder, {init_i, output, scaled_diag_blocks}); + auto init_i = One(builder, S32); + auto init = Tuple(builder, {init_i, output, scaled_diag_blocks}); // Construct the loop condition function. - std::unique_ptr condb = + std::unique_ptr condb = builder->CreateSubBuilder("InvertDiagCond"); { auto i = GetTupleElement( Parameter(condb.get(), 0, tuple_shape, "InvertDiagCondTuple"), 0); - Lt(i, xla::ConstantR0(condb.get(), block_size)); + Lt(i, ConstantR0(condb.get(), block_size)); } TF_ASSIGN_OR_RETURN(auto cond, condb->Build()); // Construct the loop body function. 
- std::unique_ptr bodyb = + std::unique_ptr bodyb = builder->CreateSubBuilder("InvertDiagBody"); { auto input_tuple = @@ -202,29 +209,27 @@ xla::XlaOp InvertDiagonalBlocks(xla::XlaOp diag_blocks, bool lower, auto body_out = GetTupleElement(input_tuple, 1); auto body_input = GetTupleElement(input_tuple, 2); - auto zero = xla::ConstantR1(bodyb.get(), 1, 0); + auto zero = ConstantR0(bodyb.get(), 0); auto j = (lower) ? i : ScalarLike(i, block_size - 1) - i; - auto start_indices = - xla::ConcatInDim(bodyb.get(), {zero, Reshape(j, {1}), zero}, 0); auto input_row = - DynamicSlice(body_input, start_indices, + DynamicSlice(body_input, {zero, j, zero}, /*slice_sizes=*/{num_blocks, 1, block_size}); // We want -L21 L11^{-1} - xla::DotDimensionNumbers dnums; + DotDimensionNumbers dnums; dnums.add_lhs_batch_dimensions(0); dnums.add_rhs_batch_dimensions(0); dnums.add_lhs_contracting_dimensions(2); dnums.add_rhs_contracting_dimensions(1); - xla::PrecisionConfig precision_proto; + PrecisionConfig precision_proto; precision_proto.add_operand_precision(precision); precision_proto.add_operand_precision(precision); auto update = -DotGeneral(input_row, body_out, dnums, &precision_proto); - body_out = DynamicUpdateSlice(body_out, update, start_indices); + body_out = DynamicUpdateSlice(body_out, update, {zero, j, zero}); auto next_i = i + ScalarLike(i, 1); - xla::Tuple(bodyb.get(), {next_i, body_out, body_input}); + Tuple(bodyb.get(), {next_i, body_out, body_input}); } TF_ASSIGN_OR_RETURN(auto body, bodyb->Build()); @@ -238,27 +243,26 @@ xla::XlaOp InvertDiagonalBlocks(xla::XlaOp diag_blocks, bool lower, /*broadcast_dimensions=*/{0, 1}); // Reshape back to original batch major dimensions - return Reshape(inv_diag_blocks, xla::AsInt64Slice(shape.dimensions())); + return Reshape(inv_diag_blocks, AsInt64Slice(shape.dimensions())); }); } -xla::XlaOp SolveWithInvertedDiagonalBlocks( - xla::XlaOp a, xla::XlaOp b, xla::XlaOp inv_diag_blocks, bool left_side, - bool lower, bool transpose_a, bool 
conjugate_a, - xla::PrecisionConfig::Precision precision) { - xla::XlaBuilder* builder = a.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_ASSIGN_OR_RETURN(xla::Shape blocks_shape, - builder->GetShape(inv_diag_blocks)); - TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b)); - int64 block_size = xla::ShapeUtil::GetDimension(blocks_shape, -1); - - TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); - int64 ndims = xla::ShapeUtil::Rank(a_shape); - int64 n = xla::ShapeUtil::GetDimension(a_shape, -1); +XlaOp SolveWithInvertedDiagonalBlocks(XlaOp a, XlaOp b, XlaOp inv_diag_blocks, + bool left_side, bool lower, + bool transpose_a, bool conjugate_a, + PrecisionConfig::Precision precision) { + XlaBuilder* builder = a.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape blocks_shape, builder->GetShape(inv_diag_blocks)); + TF_ASSIGN_OR_RETURN(Shape b_shape, builder->GetShape(b)); + int64 block_size = ShapeUtil::GetDimension(blocks_shape, -1); + + TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); + int64 ndims = a_shape.rank(); + int64 n = ShapeUtil::GetDimension(a_shape, -1); int64 num_blocks = n / block_size + (n % block_size != 0); int64 m_dim = (left_side) ? 
-1 : -2; - int64 m = xla::ShapeUtil::GetDimension(b_shape, m_dim); + int64 m = ShapeUtil::GetDimension(b_shape, m_dim); // Initialize the solution auto x = ZerosLike(b); @@ -294,7 +298,7 @@ xla::XlaOp SolveWithInvertedDiagonalBlocks( } auto b_row = SliceInMinorDims(b, start, end); - xla::XlaOp remainder; + XlaOp remainder; if (i == 0) { remainder = b_row; } else { @@ -311,29 +315,27 @@ xla::XlaOp SolveWithInvertedDiagonalBlocks( auto a_row = MaybeConjugate(SliceInMinorDims(a, start, end), conjugate_a); if (left_side) { - remainder = b_row - BatchDot(a_row, x, transpose_a, false, - /*conjugate_x=*/false, - /*conjugate_y=*/false, precision); + remainder = + b_row - BatchDot(MaybeTransposeInMinorDims(a_row, transpose_a), x, + precision); } else { - remainder = b_row - BatchDot(x, a_row, false, transpose_a, - /*conjugate_x=*/false, - /*conjugate_y=*/false, precision); + remainder = + b_row - BatchDot(x, MaybeTransposeInMinorDims(a_row, transpose_a), + precision); } } - xla::XlaOp x_update; - auto zero = Zero(builder, xla::S32); - auto start_index = - xla::ConstantR0WithType(builder, xla::S32, j * block_size); - std::vector update_starts = {start_index, zero}; + XlaOp x_update; + auto zero = Zero(builder, S32); + auto start_index = ConstantR0WithType(builder, S32, j * block_size); + std::vector update_starts = {start_index, zero}; if (left_side) { - x_update = - BatchDot(inv_block, remainder, transpose_a, false, - /*conjugate_x=*/false, /*conjugate_y=*/false, precision); + x_update = BatchDot(MaybeTransposeInMinorDims(inv_block, transpose_a), + remainder, precision); } else { - x_update = - BatchDot(remainder, inv_block, false, transpose_a, - /*conjugate_x=*/false, /*conjugate_y=*/false, precision); + x_update = BatchDot(remainder, + MaybeTransposeInMinorDims(inv_block, transpose_a), + precision); std::swap(update_starts[0], update_starts[1]); } x = DynamicUpdateSliceInMinorDims(x, x_update, /*starts=*/update_starts); @@ -343,24 +345,24 @@ xla::XlaOp 
SolveWithInvertedDiagonalBlocks( }); } -xla::XlaOp TriangularSolve(xla::XlaOp a, xla::XlaOp b, bool left_side, - bool lower, bool transpose_a, bool conjugate_a, - int64 block_size, - xla::PrecisionConfig::Precision precision) { - xla::XlaBuilder* builder = a.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); - TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b)); - if (xla::ShapeUtil::Rank(a_shape) != xla::ShapeUtil::Rank(b_shape)) { - return errors::InvalidArgument( - "Arguments to TriangularSolve have different ranks: ", - xla::ShapeUtil::HumanString(a_shape), " vs. ", - xla::ShapeUtil::HumanString(b_shape)); +XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower, + bool transpose_a, bool conjugate_a, int64 block_size, + PrecisionConfig::Precision precision) { + XlaBuilder* builder = a.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); + TF_ASSIGN_OR_RETURN(Shape b_shape, builder->GetShape(b)); + if (a_shape.rank() != b_shape.rank()) { + return InvalidArgument( + "Arguments to TriangularSolve have shapes with different ranks: " + "%s vs. %s", + ShapeUtil::HumanString(a_shape), ShapeUtil::HumanString(b_shape)); } - const int64 ndims = xla::ShapeUtil::Rank(a_shape); + const int64 ndims = a_shape.rank(); if (ndims < 2) { - return errors::InvalidArgument( - "Arguments to TriangularSolve must have rank >= 2: ", ndims); + return InvalidArgument( + "Arguments to TriangularSolve was rank %d but must have rank >= 2.", + ndims); } // The batch dimensions must be equal. 
std::vector batch_dimensions; @@ -368,35 +370,42 @@ xla::XlaOp TriangularSolve(xla::XlaOp a, xla::XlaOp b, bool left_side, int64 a_size = a_shape.dimensions(i); int64 b_size = b_shape.dimensions(i); if (a_size != b_size) { - return errors::InvalidArgument( - "Batch dimensions of arguments to TriangularSolve must be equal: ", - xla::ShapeUtil::HumanString(a_shape), " vs ", - xla::ShapeUtil::HumanString(b_shape)); + return InvalidArgument( + "Batch dimensions of arguments to TriangularSolve must be equal; " + "shapes were %s and %s.", + ShapeUtil::HumanString(a_shape), ShapeUtil::HumanString(b_shape)); } batch_dimensions.push_back(a_size); } - if (xla::ShapeUtil::GetDimension(a_shape, -1) != - xla::ShapeUtil::GetDimension(a_shape, -2)) { - return errors::InvalidArgument( - "The 'a' arguments to TriangularSolve must be square matrices: ", - xla::ShapeUtil::HumanString(a_shape)); + if (ShapeUtil::GetDimension(a_shape, -1) != + ShapeUtil::GetDimension(a_shape, -2)) { + return InvalidArgument( + "The 'a' argument to TriangularSolve must be a batched square matrix;" + " shape was: %s", + ShapeUtil::HumanString(a_shape)); } - const int64 m = xla::ShapeUtil::GetDimension(b_shape, -2); - const int64 n = xla::ShapeUtil::GetDimension(b_shape, -1); - if ((left_side ? m : n) != xla::ShapeUtil::GetDimension(a_shape, -1)) { - return errors::InvalidArgument( - "Arguments to TriangularSolve have incompatible matrix shapes: ", - xla::ShapeUtil::HumanString(a_shape), " vs ", - xla::ShapeUtil::HumanString(b_shape)); + const int64 m = ShapeUtil::GetDimension(b_shape, -2); + const int64 n = ShapeUtil::GetDimension(b_shape, -1); + if ((left_side ? 
m : n) != ShapeUtil::GetDimension(a_shape, -1)) { + return InvalidArgument( + "Arguments to TriangularSolve have incompatible matrix shapes %s and " + "%s", + ShapeUtil::HumanString(a_shape), ShapeUtil::HumanString(b_shape)); } if (block_size < 1) { - return errors::InvalidArgument( - "block_size argument to TriangularSolve must be >= 1; got ", + return InvalidArgument( + "block_size argument to TriangularSolve must be >= 1; got %d", block_size); } + if (ShapeUtil::IsZeroElementArray(b_shape)) { + // The output has the same shape as 'b', and since the output has zero + // elements, any such array will do. + return b; + } + // We find the diagonal blocks of the coefficient matrix auto diag_blocks = DiagonalBlocks(a, block_size); @@ -404,6 +413,11 @@ xla::XlaOp TriangularSolve(xla::XlaOp a, xla::XlaOp b, bool left_side, auto inv_diag_blocks = InvertDiagonalBlocks(diag_blocks, lower, transpose_a, conjugate_a, precision); + // Mask off the ignored elements of the triangular matrix a. + // TODO(phawkins): it would probably be preferable to perform this masking + // block by block inside SolveWithInvertedDiagonalBlocks. 
+ a = Triangle(a, lower); + // We now find the solution using GEMMs auto x = SolveWithInvertedDiagonalBlocks(a, b, inv_diag_blocks, left_side, lower, @@ -413,4 +427,4 @@ xla::XlaOp TriangularSolve(xla::XlaOp a, xla::XlaOp b, bool left_side, }); } -} // namespace tensorflow +} // namespace xla diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.h b/tensorflow/compiler/xla/client/lib/triangular_solve.h similarity index 88% rename from tensorflow/compiler/tf2xla/lib/triangular_solve.h rename to tensorflow/compiler/xla/client/lib/triangular_solve.h index 2303234f361e54cd2a0ad495cb03b371bed76877..50a3b30ebd1c15eb6d2ace4e351cb41f21db7093 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve.h +++ b/tensorflow/compiler/xla/client/lib/triangular_solve.h @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_TRIANGULAR_SOLVE_H_ -#define TENSORFLOW_COMPILER_TF2XLA_LIB_TRIANGULAR_SOLVE_H_ +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_TRIANGULAR_SOLVE_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_TRIANGULAR_SOLVE_H_ #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -namespace tensorflow { +namespace xla { // Solves systems of linear equations with lower or upper triangular coefficient // matrices by forward- or back-substitution. Broadcasting along leading @@ -57,11 +57,11 @@ namespace tensorflow { // // Uses a blocked algorithm if `block_size` is > 1; if block_size == 1 then no // blocking is used. 
-xla::XlaOp TriangularSolve( - xla::XlaOp a, xla::XlaOp b, bool left_side, bool lower, bool transpose_a, +XlaOp TriangularSolve( + XlaOp a, XlaOp b, bool left_side, bool lower, bool transpose_a, bool conjugate_a, int64 block_size = 128, - xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::HIGHEST); + PrecisionConfig::Precision precision = PrecisionConfig::HIGHEST); -} // namespace tensorflow +} // namespace xla -#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_TRIANGULAR_SOLVE_H_ +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_TRIANGULAR_SOLVE_H_ diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc b/tensorflow/compiler/xla/client/lib/triangular_solve_test.cc similarity index 58% rename from tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc rename to tensorflow/compiler/xla/client/lib/triangular_solve_test.cc index aeebf16028d40189203cdfd815f06a339ee72902..284a2e9d183a6a7923fb59ac134ce3b3a3a96e35 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc +++ b/tensorflow/compiler/xla/client/lib/triangular_solve_test.cc @@ -13,13 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/tf2xla/lib/triangular_solve.h" +#include "tensorflow/compiler/xla/client/lib/triangular_solve.h" #include #include #include #include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/statusor.h" @@ -30,59 +32,81 @@ limitations under the License. 
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/status_test_util.h" -namespace tensorflow { +namespace xla { namespace { -using TriangularSolveTest = xla::ClientLibraryTestBase; -using TriangularSolveLeftLookingTest = xla::ClientLibraryTestBase; -using complex64 = xla::complex64; +using TriangularSolveTest = ClientLibraryTestBase; +using TriangularSolveLeftLookingTest = ClientLibraryTestBase; -xla::Array2D AValsLower() { - return {{2, 0, 0, 0}, {3, 6, 0, 0}, {4, 7, 9, 0}, {5, 8, 10, 11}}; +static constexpr float kNan = std::numeric_limits::quiet_NaN(); + +Array2D AValsLower() { + return {{2, kNan, kNan, kNan}, + {3, 6, kNan, kNan}, + {4, 7, 9, kNan}, + {5, 8, 10, 11}}; } -xla::Array2D AValsUpper() { - return {{2, 3, 4, 5}, {0, 6, 7, 8}, {0, 0, 9, 10}, {0, 0, 0, 11}}; +Array2D AValsUpper() { + return {{2, 3, 4, 5}, + {kNan, 6, 7, 8}, + {kNan, kNan, 9, 10}, + {kNan, kNan, kNan, 11}}; } -xla::Array2D BValsRight() { +Array2D BValsRight() { return {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}; } -xla::Array2D BValsLeft() { +Array2D BValsLeft() { return {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {10, 11, 12}}; } -xla::Array2D AValsLowerComplex() { - return {{2, 0, 0, 0}, - {complex64(3, 1), 6, 0, 0}, - {4, complex64(7, 2), 9, 0}, +static constexpr complex64 kNanC64 = complex64(kNan, kNan); + +Array2D AValsLowerComplex() { + return {{2, kNanC64, kNanC64, kNanC64}, + {complex64(3, 1), 6, kNanC64, kNanC64}, + {4, complex64(7, 2), 9, kNanC64}, {5, 8, complex64(10, 3), 11}}; } -xla::Array2D AValsUpperComplex() { +Array2D AValsUpperComplex() { return {{2, 3, complex64(4, 3), 5}, - {0, 6, complex64(7, 2), 8}, - {0, 0, complex64(9, 1), 10}, - {0, 0, 0, 11}}; + {kNanC64, 6, complex64(7, 2), 8}, + {kNanC64, kNanC64, complex64(9, 1), 10}, + {kNanC64, kNanC64, kNanC64, 11}}; } -xla::Array2D BValsRightComplex() { +Array2D BValsRightComplex() { return {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}; } -xla::Array2D BValsLeftComplex() { +Array2D BValsLeftComplex() { 
return {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {10, 11, 12}}; } -xla::Array2D AValsFull() { - return {{2, 0, 1, 2}, {3, 6, 0, 1}, {4, 7, 9, 0}, {5, 8, 10, 11}}; +XLA_TEST_F(TriangularSolveTest, EmptyArrays) { + XlaBuilder builder(TestName()); + + XlaOp a, b; + auto a_data = + CreateR2Parameter(Array2D(0, 0), 0, "a", &builder, &a); + auto b_data = + CreateR2Parameter(Array2D(0, 10), 1, "b", &builder, &b); + TriangularSolve(a, b, + /*left_side=*/true, /*lower=*/true, + /*transpose_a=*/true, /*conjugate_a=*/false, + /*block_size=*/2); + + ComputeAndCompareR2(&builder, Array2D(0, 10), + {a_data.get(), b_data.get()}); } XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTranspose) { - xla::XlaBuilder builder(TestName()); + XlaBuilder builder(TestName()); - xla::XlaOp a, b; + XlaOp a, b; auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); TriangularSolve(a, b, @@ -90,20 +114,20 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTranspose) { /*transpose_a=*/true, /*conjugate_a=*/false, /*block_size=*/2); - xla::Array2D expected({ + Array2D expected({ {0.5, 0.08333334, 0.04629629, 0.03367003}, {2.5, -0.25, -0.1388889, -0.1010101}, {4.5, -0.58333331, -0.32407406, -0.23569024}, }); ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, - xla::ErrorSpec(1e-2, 1e-2)); + ErrorSpec(1e-2, 1e-2)); } XLA_TEST_F(TriangularSolveTest, SimpleRightLowerNotranspose) { - xla::XlaBuilder builder(TestName()); + XlaBuilder builder(TestName()); - xla::XlaOp a, b; + XlaOp a, b; auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); TriangularSolve(a, b, @@ -111,20 +135,20 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightLowerNotranspose) { /*transpose_a=*/false, /*conjugate_a=*/false, /*block_size=*/2); - xla::Array2D expected({ + Array2D expected({ {-0.16414141, -0.06902357, -0.07070707, 0.36363636}, {0.64393939, 
0.06565657, -0.03030303, 0.72727273}, {1.4520202, 0.2003367, 0.01010101, 1.09090909}, }); ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, - xla::ErrorSpec(1e-2, 1e-2)); + ErrorSpec(1e-2, 1e-2)); } XLA_TEST_F(TriangularSolveTest, SimpleRightUpperTranspose) { - xla::XlaBuilder builder(TestName()); + XlaBuilder builder(TestName()); - xla::XlaOp a, b; + XlaOp a, b; auto a_data = CreateR2Parameter(AValsUpper(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); TriangularSolve(a, b, @@ -132,20 +156,20 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightUpperTranspose) { /*transpose_a=*/true, /*conjugate_a=*/false, /*block_size=*/2); - xla::Array2D expected({ + Array2D expected({ {-0.16414141, -0.06902357, -0.07070707, 0.36363636}, {0.64393939, 0.06565657, -0.03030303, 0.72727273}, {1.4520202, 0.2003367, 0.01010101, 1.09090909}, }); ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, - xla::ErrorSpec(1e-2, 1e-2)); + ErrorSpec(1e-2, 1e-2)); } XLA_TEST_F(TriangularSolveTest, SimpleRightUpperNotranspose) { - xla::XlaBuilder builder(TestName()); + XlaBuilder builder(TestName()); - xla::XlaOp a, b; + XlaOp a, b; auto a_data = CreateR2Parameter(AValsUpper(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); TriangularSolve(a, b, @@ -153,20 +177,20 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightUpperNotranspose) { /*transpose_a=*/false, /*conjugate_a=*/false, /*block_size=*/2); - xla::Array2D expected({ + Array2D expected({ {0.5, 0.08333334, 0.04629629, 0.03367003}, {2.5, -0.25, -0.1388889, -0.1010101}, {4.5, -0.58333331, -0.32407406, -0.23569024}, }); ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, - xla::ErrorSpec(1e-2, 1e-2)); + ErrorSpec(1e-2, 1e-2)); } XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerTranspose) { - xla::XlaBuilder builder(TestName()); + XlaBuilder builder(TestName()); - xla::XlaOp a, b; + XlaOp a, b; auto a_data = 
CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); TriangularSolve(a, b, @@ -174,7 +198,7 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerTranspose) { /*transpose_a=*/true, /*conjugate_a=*/false, /*block_size=*/2); - xla::Array2D expected({ + Array2D expected({ {-0.89646465, -0.69444444, -0.49242424}, {-0.27441077, -0.24074074, -0.20707071}, {-0.23232323, -0.22222222, -0.21212121}, @@ -182,13 +206,13 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerTranspose) { }); ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, - xla::ErrorSpec(1e-2, 1e-2)); + ErrorSpec(1e-2, 1e-2)); } XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotranspose) { - xla::XlaBuilder builder(TestName()); + XlaBuilder builder(TestName()); - xla::XlaOp a, b; + XlaOp a, b; auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); TriangularSolve(a, b, @@ -196,7 +220,7 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotranspose) { /*transpose_a=*/false, /*conjugate_a=*/false, /*block_size=*/2); - xla::Array2D expected({ + Array2D expected({ {0.5, 1.0, 1.5}, {0.41666667, 0.33333333, 0.25}, {0.23148148, 0.18518519, 0.13888889}, @@ -204,13 +228,13 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotranspose) { }); ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, - xla::ErrorSpec(1e-2, 1e-2)); + ErrorSpec(1e-2, 1e-2)); } XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotransposeIrregularblock) { - xla::XlaBuilder builder(TestName()); + XlaBuilder builder(TestName()); - xla::XlaOp a, b; + XlaOp a, b; auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); TriangularSolve(a, b, @@ -218,7 +242,7 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotransposeIrregularblock) { /*transpose_a=*/false, /*conjugate_a=*/false, 
/*block_size=*/3); - xla::Array2D expected({ + Array2D expected({ {0.5, 1.0, 1.5}, {0.41666667, 0.33333333, 0.25}, {0.23148148, 0.18518519, 0.13888889}, @@ -226,13 +250,13 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotransposeIrregularblock) { }); ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, - xla::ErrorSpec(1e-2, 1e-2)); + ErrorSpec(1e-2, 1e-2)); } XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTranspose) { - xla::XlaBuilder builder(TestName()); + XlaBuilder builder(TestName()); - xla::XlaOp a, b; + XlaOp a, b; auto a_data = CreateR2Parameter(AValsUpper(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); TriangularSolve(a, b, @@ -240,7 +264,7 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTranspose) { /*transpose_a=*/true, /*conjugate_a=*/false, /*block_size=*/2); - xla::Array2D expected({ + Array2D expected({ {0.5, 1.0, 1.5}, {0.41666667, 0.33333333, 0.25}, {0.23148148, 0.18518519, 0.13888889}, @@ -248,13 +272,13 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTranspose) { }); ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, - xla::ErrorSpec(1e-2, 1e-2)); + ErrorSpec(1e-2, 1e-2)); } XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperNotranspose) { - xla::XlaBuilder builder(TestName()); + XlaBuilder builder(TestName()); - xla::XlaOp a, b; + XlaOp a, b; auto a_data = CreateR2Parameter(AValsUpper(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); TriangularSolve(a, b, @@ -262,7 +286,7 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperNotranspose) { /*transpose_a=*/false, /*conjugate_a=*/false, /*block_size=*/2); - xla::Array2D expected({ + Array2D expected({ {-0.89646465, -0.69444444, -0.49242424}, {-0.27441077, -0.24074074, -0.20707071}, {-0.23232323, -0.22222222, -0.21212121}, @@ -270,13 +294,13 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperNotranspose) { }); ComputeAndCompareR2(&builder, expected, {a_data.get(), 
b_data.get()}, - xla::ErrorSpec(1e-2, 1e-2)); + ErrorSpec(1e-2, 1e-2)); } XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTransposeConjugate) { - xla::XlaBuilder builder(TestName()); + XlaBuilder builder(TestName()); - xla::XlaOp a, b; + XlaOp a, b; auto a_data = CreateR2Parameter(AValsLowerComplex(), 0, "a", &builder, &a); auto b_data = @@ -286,7 +310,7 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTransposeConjugate) { /*transpose_a=*/true, /*conjugate_a=*/true, /*block_size=*/2); - xla::Array2D expected({ + Array2D expected({ {0.5, complex64(0.08333333, 0.08333333), complex64(0.02777778, -0.0462963), complex64(0.06313131, -0.01094276)}, {2.5, complex64(-0.25, 0.41666667), complex64(-0.23148148, -0.37962963), @@ -295,15 +319,14 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTransposeConjugate) { complex64(0.11026936, -0.03114478)}, }); - ComputeAndCompareR2(&builder, expected, - {a_data.get(), b_data.get()}, - xla::ErrorSpec(1e-2, 1e-2)); + ComputeAndCompareR2( + &builder, expected, {a_data.get(), b_data.get()}, ErrorSpec(1e-2, 1e-2)); } XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTransposeNoconjugate) { - xla::XlaBuilder builder(TestName()); + XlaBuilder builder(TestName()); - xla::XlaOp a, b; + XlaOp a, b; auto a_data = CreateR2Parameter(AValsUpperComplex(), 0, "a", &builder, &a); auto b_data = @@ -313,7 +336,7 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTransposeNoconjugate) { /*transpose_a=*/true, /*conjugate_a=*/false, /*block_size=*/2); - xla::Array2D expected({ + Array2D expected({ {0.5, 1., 1.5}, {0.41666667, 0.33333333, 0.25}, {complex64(0.20020325, -2.81504065e-01), @@ -324,10 +347,101 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTransposeNoconjugate) { complex64(0.15798226, 5.12749446e-01)}, }); - ComputeAndCompareR2(&builder, expected, - {a_data.get(), b_data.get()}, - xla::ErrorSpec(1e-2, 1e-2)); + ComputeAndCompareR2( + &builder, expected, {a_data.get(), b_data.get()}, ErrorSpec(1e-2, 1e-2)); } +XLA_TEST_F(TriangularSolveTest, 
BatchedLeftUpper) { + XlaBuilder builder(TestName()); + + Array3D bvals(7, 5, 5); + bvals.FillIota(1.); + + // Set avals to the upper triangle of bvals. + Array3D avals = bvals; + avals.Each([](absl::Span indices, float* value) { + if (indices[1] > indices[2]) { + *value = 0; + } + }); + + XlaOp a, b; + auto a_data = CreateR3Parameter(avals, 0, "a", &builder, &a); + auto b_data = CreateR3Parameter(bvals, 1, "b", &builder, &b); + BatchDot(ConstantR3FromArray3D(&builder, avals), + TriangularSolve(a, b, + /*left_side=*/true, /*lower=*/false, + /*transpose_a=*/false, /*conjugate_a=*/false, + /*block_size=*/2)); + + ComputeAndCompareR3(&builder, bvals, {a_data.get(), b_data.get()}, + ErrorSpec(1e-2, 1e-2)); +} + +struct TriangularSolveTestSpec { + int m, n; // A is mxm, B is mxn + bool left_side; + bool lower; + bool transpose_a; +}; + +class TriangularSolveParametricTest + : public ClientLibraryTestBase, + public ::testing::WithParamInterface {}; + +XLA_TEST_P(TriangularSolveParametricTest, Random) { + TriangularSolveTestSpec spec = GetParam(); + + XlaBuilder builder(TestName()); + + Array2D avals(spec.m, spec.m); + avals.FillRandom(1.0); + for (int i = 0; i < spec.m; ++i) { + avals(i, i) += 10; + } + + std::pair bdims = spec.left_side ? 
std::make_pair(spec.m, spec.n) + : std::make_pair(spec.n, spec.m); + Array2D bvals(bdims.first, bdims.second); + bvals.FillRandom(1.0); + + XlaOp a, b; + auto a_data = CreateR2Parameter(avals, 0, "a", &builder, &a); + auto b_data = CreateR2Parameter(bvals, 1, "b", &builder, &b); + auto x = TriangularSolve(a, b, spec.left_side, spec.lower, spec.transpose_a, + /*conjugate_a=*/false, + /*block_size=*/3); + auto a_tri = Triangle(a, spec.lower); + a_tri = MaybeTransposeInMinorDims(a_tri, spec.transpose_a); + if (spec.left_side) { + BatchDot(a_tri, x); + } else { + BatchDot(x, a_tri); + } + + ComputeAndCompareR2(&builder, bvals, {a_data.get(), b_data.get()}, + ErrorSpec(1e-2, 1e-2)); +} + +std::vector TriangularSolveTests() { + std::vector specs; + for (int m : {5, 10}) { + for (int n : {5, 10}) { + for (bool left_side : {false, true}) { + for (bool lower : {false, true}) { + for (bool transpose_a : {false, true}) { + specs.push_back({m, n, left_side, lower, transpose_a}); + } + } + } + } + } + return specs; +} + +INSTANTIATE_TEST_SUITE_P(TriangularSolveParametricTestInstantiation, + TriangularSolveParametricTest, + ::testing::ValuesIn(TriangularSolveTests())); + } // namespace -} // namespace tensorflow +} // namespace xla diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index aaa5d6989eefb94edb8921d13f96e3705aa3e3a4..48b5f94538f453785194bc434a91ee0a10c020c2 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -71,9 +71,9 @@ Status LocalExecutable::ValidateExecutionOptions( "parameter " "%d: want %s, got %s", i, - ShapeUtil::HumanString( + ShapeUtil::HumanStringWithLayout( computation_layout.parameter_layout(i).shape()), - ShapeUtil::HumanString(arguments[i]->on_host_shape())); + ShapeUtil::HumanStringWithLayout(arguments[i]->on_host_shape())); } } @@ -164,9 +164,8 @@ StatusOr LocalExecutable::Run( // ExecutableRunOptions.eigen_intra_op_thread_pool. 
// *) The thread pool used for XLA CPU ops is from // backend_->eigen_intra_op_thread_pool(). - ServiceExecutableRunOptions service_options( - run_options, backend_->StreamBorrower(), - backend_->eigen_intra_op_thread_pool()); + ServiceExecutableRunOptions service_options(run_options, + backend_->StreamBorrower()); if (executable_->dumping_snapshot()) { return ExecuteAndDump(&service_options, arguments); diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h index ddb36680e8b185b053368baffa6f1d5cac50dc07..4f4fc8df31c633749ae9b6dafcdc38d4fd1eba40 100644 --- a/tensorflow/compiler/xla/client/local_client.h +++ b/tensorflow/compiler/xla/client/local_client.h @@ -114,7 +114,7 @@ class LocalClient : public Client { // Build and return a LocalExecutable object. The executable is compiled using // the given XlaComputation, argument layouts and options. // - // The given ExecutableBuildOptions override any values from TF_XLA_FLAGS + // The given ExecutableBuildOptions overrides any values from XLA_FLAGS // environment variable. 
StatusOr> Compile( const XlaComputation& computation, diff --git a/tensorflow/compiler/xla/client/sharding_builder.cc b/tensorflow/compiler/xla/client/sharding_builder.cc index fb9ea6ec3fc41d5e04ca125798a8199350470a44..b9bff06cbdbc3525eb19d5df885952c3971d9d6a 100644 --- a/tensorflow/compiler/xla/client/sharding_builder.cc +++ b/tensorflow/compiler/xla/client/sharding_builder.cc @@ -50,7 +50,7 @@ OpSharding Tile1D(const Shape& tile_shape, int64 num_tiles) { OpSharding result; result.set_type(OpSharding::Type::OpSharding_Type_OTHER); - CHECK_EQ(ShapeUtil::Rank(tile_shape), 1); + CHECK_EQ(tile_shape.rank(), 1); std::vector dimensions(1, num_tiles); *result.mutable_tile_shape() = tile_shape.ToProto(); auto& tile_dimension = diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index 60df2ec3959216b0564846ad47c21c5bcc01ea57..c00ba26295a30c192fedae48f5aabf78cbd7d831 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -22,6 +22,8 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" @@ -29,6 +31,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/sharding_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/execution_options_util.h" +#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/shape_inference.h" @@ -192,9 +195,9 @@ StatusOr XlaBuilder::GetProgramShape(XlaOp root) const { } void XlaBuilder::IsConstantVisitor(const int64 op_handle, - std::set* visited, + absl::flat_hash_set* visited, bool* is_constant) const { - if (visited->count(op_handle) != 0 || !*is_constant) { + if (visited->contains(op_handle) || !*is_constant) { return; } @@ -208,11 +211,21 @@ void XlaBuilder::IsConstantVisitor(const int64 op_handle, } // TODO(b/32495713): We aren't checking the called computations. break; + case HloOpcode::kGetDimensionSize: { + int64 dimension_number = instr.dimensions(0); + const HloInstructionProto& operand = + *(LookUpInstructionByHandle(instr.operand_ids(0)).ValueOrDie()); + Shape operand_shape(operand.shape()); + if (operand_shape.is_dynamic_dimension(dimension_number)) { + *is_constant = false; + } + break; + } // Non functional ops. case HloOpcode::kRng: - case HloOpcode::kCrossReplicaSum: - // TODO(b/33009255): Implmement constant folding for cross replica sum. + case HloOpcode::kAllReduce: + // TODO(b/33009255): Implement constant folding for cross replica sum. 
case HloOpcode::kInfeed: case HloOpcode::kOutfeed: case HloOpcode::kCall: @@ -244,6 +257,29 @@ Status XlaBuilder::SetDynamicBinding(int64 dynamic_size_param_num, int64 target_param_num, ShapeIndex target_param_index, int64 target_dim_num) { + bool param_exists = false; + for (HloInstructionProto& instr : instructions_) { + if (instr.opcode() == HloOpcodeString(HloOpcode::kParameter) && + instr.parameter_number() == target_param_num) { + param_exists = true; + Shape param_shape(instr.shape()); + Shape* param_shape_ptr = ¶m_shape; + for (int64 index : target_param_index) { + param_shape_ptr = param_shape_ptr->mutable_tuple_shapes(index); + } + param_shape_ptr->set_dynamic_dimension(target_dim_num, + /*is_dynamic=*/true); + *instr.mutable_shape() = param_shape.ToProto(); + } + } + + if (!param_exists) { + return InvalidArgument( + "Asked to mark parameter %lld as dynamic sized parameter, but the " + "doesn't exists", + target_param_num); + } + TF_RETURN_IF_ERROR(dynamic_parameter_binding_.Bind( DynamicParameterBinding::DynamicParameter{dynamic_size_param_num, dynamic_size_param_index}, @@ -263,29 +299,52 @@ XlaComputation XlaBuilder::BuildAndNoteError() { return build_status.ConsumeValueOrDie(); } -StatusOr XlaBuilder::Build() { +StatusOr XlaBuilder::Build(bool remove_dynamic_dimensions) { if (!first_error_.ok()) { string backtrace; first_error_backtrace_.Dump(tensorflow::DebugWriteToString, &backtrace); return AppendStatus(first_error_, backtrace); } - return Build(instructions_.back().id()); + return Build(instructions_.back().id(), remove_dynamic_dimensions); } -StatusOr XlaBuilder::Build(XlaOp root) { +StatusOr XlaBuilder::Build(XlaOp root, + bool remove_dynamic_dimensions) { if (root.builder_ != this) { return InvalidArgument("Given root operation is not in this computation."); } - return Build(root.handle()); + return Build(root.handle(), remove_dynamic_dimensions); } -StatusOr XlaBuilder::Build(int64 root_id) { +StatusOr XlaBuilder::Build(int64 root_id, + bool 
remove_dynamic_dimensions) { if (!first_error_.ok()) { string backtrace; first_error_backtrace_.Dump(tensorflow::DebugWriteToString, &backtrace); return AppendStatus(first_error_, backtrace); } + // TODO(b/121223198): XLA backend cannot handle dynamic dimensions yet, remove + // all dynamic dimensions before building xla program until we have support in + // the backend. + if (remove_dynamic_dimensions) { + std::function remove_dynamic_dimension = + [&](ShapeProto* shape) { + if (shape->tuple_shapes_size() != 0) { + for (int64 i = 0; i < shape->tuple_shapes_size(); ++i) { + remove_dynamic_dimension(shape->mutable_tuple_shapes(i)); + } + } + for (int64 i = 0; i < shape->dimensions_size(); ++i) { + shape->set_is_dynamic_dimension(i, false); + } + }; + + for (auto& instruction : instructions_) { + remove_dynamic_dimension(instruction.mutable_shape()); + } + } + HloComputationProto entry; SetProtoIdAndName(&entry, name_, kNameSeparator, GetNextId()); TF_ASSIGN_OR_RETURN(ProgramShape program_shape, GetProgramShape(root_id)); @@ -310,7 +369,10 @@ StatusOr XlaBuilder::Build(int64 root_id) { module->add_computations()->Swap(&e.second); } module->add_computations()->Swap(&entry); - + if (!input_output_aliases_.empty()) { + TF_RETURN_IF_ERROR( + PopulateInputOutputAlias(module, program_shape, input_output_aliases_)); + } *(module->mutable_dynamic_parameter_binding()) = dynamic_parameter_binding_.ToProto(); @@ -323,6 +385,35 @@ StatusOr XlaBuilder::Build(int64 root_id) { return std::move(computation); } +/* static */ Status XlaBuilder::PopulateInputOutputAlias( + HloModuleProto* module, const ProgramShape& program_shape, + const std::vector& input_output_aliases) { + HloInputOutputAliasConfig config(program_shape.result()); + for (auto& alias : input_output_aliases) { + // The HloInputOutputAliasConfig does not do parameter validation as it only + // carries the result shape. Maybe it should be constructed with a + // ProgramShape to allow full validation. 
We will still get an error when + // trying to compile the HLO module, but would be better to have validation + // at this stage. + if (alias.param_number >= program_shape.parameters_size()) { + return InvalidArgument("Invalid parameter number %ld (total %ld)", + alias.param_number, + program_shape.parameters_size()); + } + const Shape& parameter_shape = program_shape.parameters(alias.param_number); + if (!ShapeUtil::IndexIsValid(parameter_shape, alias.param_index)) { + return InvalidArgument("Invalid parameter %ld index: %s", + alias.param_number, + alias.param_index.ToString().c_str()); + } + TF_RETURN_IF_ERROR(config.SetUpAlias( + alias.output_index, alias.param_number, alias.param_index, + HloInputOutputAliasConfig::AliasKind::kUserAlias)); + } + *module->mutable_input_output_alias() = config.ToProto(); + return Status::OK(); +} + StatusOr XlaBuilder::InDimBroadcast( const Shape& shape, const XlaOp& operand, absl::Span broadcast_dimensions) { @@ -343,7 +434,7 @@ StatusOr XlaBuilder::AddBroadcastSequence(const Shape& output_shape, TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); CHECK(ShapeUtil::IsScalar(operand_shape) || - ShapeUtil::Rank(operand_shape) == ShapeUtil::Rank(output_shape)); + operand_shape.rank() == output_shape.rank()); Shape broadcast_shape = ShapeUtil::ChangeElementType(output_shape, operand_shape.element_type()); @@ -355,7 +446,7 @@ StatusOr XlaBuilder::AddBroadcastSequence(const Shape& output_shape, // Do explicit broadcast for degenerate broadcast. 
std::vector broadcast_dimensions; std::vector reshaped_dimensions; - for (int i = 0; i < ShapeUtil::Rank(operand_shape); i++) { + for (int i = 0; i < operand_shape.rank(); i++) { if (operand_shape.dimensions(i) == output_shape.dimensions(i)) { broadcast_dimensions.push_back(i); reshaped_dimensions.push_back(operand_shape.dimensions(i)); @@ -398,8 +489,8 @@ XlaOp XlaBuilder::BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs, binop, lhs_shape, rhs_shape, broadcast_dimensions)); *instr.mutable_shape() = shape.ToProto(); - const int64 lhs_rank = ShapeUtil::Rank(lhs_shape); - const int64 rhs_rank = ShapeUtil::Rank(rhs_shape); + const int64 lhs_rank = lhs_shape.rank(); + const int64 rhs_rank = rhs_shape.rank(); XlaOp updated_lhs = lhs; XlaOp updated_rhs = rhs; @@ -410,17 +501,19 @@ XlaOp XlaBuilder::BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs, const Shape& from_shape = should_broadcast_lhs ? lhs_shape : rhs_shape; std::vector to_size; - for (int64 size : shape.dimensions()) { - to_size.push_back(size); + std::vector to_size_is_dynamic; + for (int i = 0; i < shape.rank(); i++) { + to_size.push_back(shape.dimensions(i)); + to_size_is_dynamic.push_back(shape.is_dynamic_dimension(i)); } - for (int64 from_dim = 0; from_dim < ShapeUtil::Rank(from_shape); - from_dim++) { + for (int64 from_dim = 0; from_dim < from_shape.rank(); from_dim++) { int64 to_dim = broadcast_dimensions[from_dim]; to_size[to_dim] = from_shape.dimensions(from_dim); + to_size_is_dynamic[to_dim] = from_shape.is_dynamic_dimension(from_dim); } - const Shape& broadcasted_shape = - ShapeUtil::MakeShape(from_shape.element_type(), to_size); + const Shape& broadcasted_shape = ShapeUtil::MakeShape( + from_shape.element_type(), to_size, to_size_is_dynamic); TF_ASSIGN_OR_RETURN( XlaOp broadcasted_operand, InDimBroadcast(broadcasted_shape, from, broadcast_dimensions)); @@ -458,18 +551,18 @@ XlaOp XlaBuilder::TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs, XlaOp updated_lhs = 
lhs; XlaOp updated_rhs = rhs; XlaOp updated_ehs = ehs; - if (!ShapeUtil::IsTuple(shape)) { - if (!ShapeUtil::IsTuple(lhs_shape) && + if (!shape.IsTuple()) { + if (!lhs_shape.IsTuple() && !ShapeUtil::SameDimensions(shape, lhs_shape)) { // lhs is being implicitly broadcasted. Change to explicit. TF_ASSIGN_OR_RETURN(updated_lhs, AddBroadcastSequence(shape, lhs)); } - if (!ShapeUtil::IsTuple(rhs_shape) && + if (!rhs_shape.IsTuple() && !ShapeUtil::SameDimensions(shape, rhs_shape)) { // rhs is being implicitly broadcasted. Change to explicit. TF_ASSIGN_OR_RETURN(updated_rhs, AddBroadcastSequence(shape, rhs)); } - if (!ShapeUtil::IsTuple(ehs_shape) && + if (!ehs_shape.IsTuple() && !ShapeUtil::SameDimensions(shape, ehs_shape)) { // ehs is being implicitly broadcasted. Change to explicit. TF_ASSIGN_OR_RETURN(updated_ehs, AddBroadcastSequence(shape, ehs)); @@ -563,10 +656,10 @@ XlaOp XlaBuilder::Broadcast(const XlaOp& operand, // output, so to append dimensions on the left the instruction's dimensions // should just be the n highest dimension numbers of the output shape where // n is the number of input dimensions. - const int64 operand_rank = ShapeUtil::Rank(operand_shape); + const int64 operand_rank = operand_shape.rank(); std::vector dimensions(operand_rank); for (int i = 0; i < operand_rank; ++i) { - dimensions[i] = i + ShapeUtil::Rank(shape) - operand_rank; + dimensions[i] = i + shape.rank() - operand_rank; } return InDimBroadcast(shape, operand, dimensions); }); @@ -579,8 +672,17 @@ XlaOp XlaBuilder::BroadcastInDim( TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); // Output shape, in the case of degenerate broadcast, the out_dim_size is // not necessarily the same as the dimension sizes of the output shape. 
- const auto& output_shape = + auto output_shape = ShapeUtil::MakeShape(operand_shape.element_type(), out_dim_size); + for (int i = 0; i < broadcast_dimensions.size(); i++) { + if (broadcast_dimensions[i] < 0 || + broadcast_dimensions[i] > out_dim_size.size()) { + return InvalidArgument("Broadcast dimension %lld is out of bound", + broadcast_dimensions[i]); + } + output_shape.set_dynamic_dimension(broadcast_dimensions[i], + operand_shape.is_dynamic_dimension(i)); + } TF_RETURN_IF_ERROR(ShapeInference::InferBroadcastShape( operand_shape, output_shape, broadcast_dimensions) @@ -639,10 +741,10 @@ XlaOp XlaBuilder::SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index, int64 stride, int64 dimno) { return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand)); - std::vector starts(ShapeUtil::Rank(shape), 0); + std::vector starts(shape.rank(), 0); std::vector limits(shape.dimensions().begin(), shape.dimensions().end()); - std::vector strides(ShapeUtil::Rank(shape), 1); + std::vector strides(shape.rank(), 1); starts[dimno] = start_index; limits[dimno] = limit_index; strides[dimno] = stride; @@ -660,7 +762,7 @@ XlaOp XlaBuilder::DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, GetShape(start_indices)); TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferDynamicSliceShape( - operand_shape, start_indices_shape, slice_sizes)); + operand_shape, {start_indices_shape}, slice_sizes)); *instr.mutable_shape() = shape.ToProto(); for (int64 size : slice_sizes) { @@ -672,6 +774,34 @@ XlaOp XlaBuilder::DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, }); } +XlaOp XlaBuilder::DynamicSlice(const XlaOp& operand, + absl::Span start_indices, + absl::Span slice_sizes) { + return ReportErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + std::vector start_indices_shape_ptrs; + TF_ASSIGN_OR_RETURN(const auto& 
start_indices_shapes, + GetOperandShapes(start_indices)); + absl::c_transform(start_indices_shapes, + std::back_inserter(start_indices_shape_ptrs), + [](const Shape& shape) { return &shape; }); + TF_ASSIGN_OR_RETURN(Shape shape, + ShapeInference::InferDynamicSliceShape( + operand_shape, start_indices_shapes, slice_sizes)); + *instr.mutable_shape() = shape.ToProto(); + + for (int64 size : slice_sizes) { + instr.add_dynamic_slice_sizes(size); + } + + std::vector operands = {operand}; + operands.insert(operands.end(), start_indices.begin(), start_indices.end()); + return AddInstruction(std::move(instr), HloOpcode::kDynamicSlice, operands); + }); +} + XlaOp XlaBuilder::DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, const XlaOp& start_indices) { return ReportErrorOrReturn([&]() -> StatusOr { @@ -681,13 +811,38 @@ XlaOp XlaBuilder::DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, TF_ASSIGN_OR_RETURN(const Shape& update_shape, GetShape(update)); TF_ASSIGN_OR_RETURN(const Shape& start_indices_shape, GetShape(start_indices)); + TF_ASSIGN_OR_RETURN( + Shape shape, ShapeInference::InferDynamicUpdateSliceShape( + operand_shape, update_shape, {start_indices_shape})); + *instr.mutable_shape() = shape.ToProto(); + + return AddInstruction(std::move(instr), HloOpcode::kDynamicUpdateSlice, + {operand, update, start_indices}); + }); +} + +XlaOp XlaBuilder::DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, + absl::Span start_indices) { + return ReportErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + TF_ASSIGN_OR_RETURN(const Shape& update_shape, GetShape(update)); + std::vector start_indices_shape_ptrs; + TF_ASSIGN_OR_RETURN(const auto& start_indices_shapes, + GetOperandShapes(start_indices)); + absl::c_transform(start_indices_shapes, + std::back_inserter(start_indices_shape_ptrs), + [](const Shape& shape) { return &shape; }); TF_ASSIGN_OR_RETURN(Shape shape, 
ShapeInference::InferDynamicUpdateSliceShape( - operand_shape, update_shape, start_indices_shape)); + operand_shape, update_shape, start_indices_shapes)); *instr.mutable_shape() = shape.ToProto(); + std::vector operands = {operand, update}; + operands.insert(operands.end(), start_indices.begin(), start_indices.end()); return AddInstruction(std::move(instr), HloOpcode::kDynamicUpdateSlice, - {operand, update, start_indices}); + operands); }); } @@ -780,7 +935,7 @@ XlaOp XlaBuilder::Collapse(const XlaOp& operand, VLOG(3) << "dims to collapse: " << absl::StrJoin(dimensions, ","); std::vector new_sizes; - for (int i = 0; i < ShapeUtil::Rank(original_shape); ++i) { + for (int i = 0; i < original_shape.rank(); ++i) { if (i <= dimensions.front() || i > dimensions.back()) { new_sizes.push_back(original_shape.dimensions(i)); } else { @@ -808,10 +963,9 @@ XlaOp XlaBuilder::Select(const XlaOp& pred, const XlaOp& on_true, return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape& true_shape, GetShape(on_true)); TF_ASSIGN_OR_RETURN(const Shape& false_shape, GetShape(on_false)); - TF_RET_CHECK(ShapeUtil::IsTuple(true_shape) == - ShapeUtil::IsTuple(false_shape)); - HloOpcode opcode = ShapeUtil::IsTuple(true_shape) ? HloOpcode::kTupleSelect - : HloOpcode::kSelect; + TF_RET_CHECK(true_shape.IsTuple() == false_shape.IsTuple()); + HloOpcode opcode = + true_shape.IsTuple() ? 
HloOpcode::kTupleSelect : HloOpcode::kSelect; return TernaryOp(opcode, pred, on_true, on_false); }); } @@ -835,7 +989,7 @@ XlaOp XlaBuilder::GetTupleElement(const XlaOp& tuple_data, int64 index) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& tuple_shape, GetShape(tuple_data)); - if (!ShapeUtil::IsTuple(tuple_shape)) { + if (!tuple_shape.IsTuple()) { return InvalidArgument( "Operand to GetTupleElement() is not a tuple; got %s", ShapeUtil::HumanString(tuple_shape)); @@ -915,13 +1069,13 @@ XlaOp XlaBuilder::DotGeneral(const XlaOp& lhs, const XlaOp& rhs, Status XlaBuilder::VerifyConvolution( const Shape& lhs_shape, const Shape& rhs_shape, const ConvolutionDimensionNumbers& dimension_numbers) const { - if (ShapeUtil::Rank(lhs_shape) != ShapeUtil::Rank(rhs_shape)) { + if (lhs_shape.rank() != rhs_shape.rank()) { return InvalidArgument( "Convolution arguments must have same number of " "dimensions. Got: %s and %s", ShapeUtil::HumanString(lhs_shape), ShapeUtil::HumanString(rhs_shape)); } - int num_dims = ShapeUtil::Rank(lhs_shape); + int num_dims = lhs_shape.rank(); if (num_dims < 2) { return InvalidArgument( "Convolution expects argument arrays with >= 3 dimensions. 
" @@ -959,27 +1113,29 @@ Status XlaBuilder::VerifyConvolution( XlaOp XlaBuilder::Conv(const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, Padding padding, - int64 feature_group_count, + int64 feature_group_count, int64 batch_group_count, const PrecisionConfig* precision_config) { return ConvWithGeneralDimensions( lhs, rhs, window_strides, padding, CreateDefaultConvDimensionNumbers(window_strides.size()), - feature_group_count, precision_config); + feature_group_count, batch_group_count, precision_config); } XlaOp XlaBuilder::ConvWithGeneralPadding( const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, absl::Span> padding, - int64 feature_group_count, const PrecisionConfig* precision_config) { + int64 feature_group_count, int64 batch_group_count, + const PrecisionConfig* precision_config) { return ConvGeneral(lhs, rhs, window_strides, padding, CreateDefaultConvDimensionNumbers(window_strides.size()), - feature_group_count, precision_config); + feature_group_count, batch_group_count, precision_config); } XlaOp XlaBuilder::ConvWithGeneralDimensions( const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count, const PrecisionConfig* precision_config) { + int64 feature_group_count, int64 batch_group_count, + const PrecisionConfig* precision_config) { return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs)); @@ -1007,7 +1163,7 @@ XlaOp XlaBuilder::ConvWithGeneralDimensions( MakePadding(base_area_dimensions, window_dimensions, window_strides, padding), dimension_numbers, feature_group_count, - precision_config); + batch_group_count, precision_config); }); } @@ -1015,10 +1171,11 @@ XlaOp XlaBuilder::ConvGeneral( const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, absl::Span> padding, const ConvolutionDimensionNumbers& 
dimension_numbers, - int64 feature_group_count, const PrecisionConfig* precision_config) { + int64 feature_group_count, int64 batch_group_count, + const PrecisionConfig* precision_config) { return ConvGeneralDilated(lhs, rhs, window_strides, padding, {}, {}, dimension_numbers, feature_group_count, - precision_config); + batch_group_count, precision_config); } XlaOp XlaBuilder::ConvGeneralDilated( @@ -1026,7 +1183,8 @@ XlaOp XlaBuilder::ConvGeneralDilated( absl::Span> padding, absl::Span lhs_dilation, absl::Span rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count, const PrecisionConfig* precision_config) { + int64 feature_group_count, int64 batch_group_count, + const PrecisionConfig* precision_config) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); @@ -1045,14 +1203,15 @@ XlaOp XlaBuilder::ConvGeneralDilated( MakeWindow(window_dimensions, window_strides, padding, lhs_dilation, rhs_dilation)); - TF_ASSIGN_OR_RETURN(Shape shape, - ShapeInference::InferConvolveShape( - lhs_shape, rhs_shape, feature_group_count, - instr.window(), dimension_numbers)); + TF_ASSIGN_OR_RETURN( + Shape shape, ShapeInference::InferConvolveShape( + lhs_shape, rhs_shape, feature_group_count, + batch_group_count, instr.window(), dimension_numbers)); *instr.mutable_shape() = shape.ToProto(); *instr.mutable_convolution_dimension_numbers() = dimension_numbers; instr.set_feature_group_count(feature_group_count); + instr.set_batch_group_count(batch_group_count); if (precision_config != nullptr) { *instr.mutable_precision_config() = *precision_config; @@ -1145,7 +1304,7 @@ XlaOp XlaBuilder::Infeed(const Shape& shape, const string& config) { *instr.mutable_shape() = infeed_instruction_shape.ToProto(); instr.set_infeed_config(config); - if (ShapeUtil::IsArray(shape) && sharding() && + if (shape.IsArray() && sharding() && sharding()->type() == 
OpSharding::Type::OpSharding_Type_OTHER) { // TODO(b/110793772): Support tiled array-shaped infeeds. return InvalidArgument( @@ -1221,7 +1380,7 @@ XlaOp XlaBuilder::InfeedWithToken(const XlaOp& token, const Shape& shape, *instr.mutable_shape() = infeed_instruction_shape.ToProto(); instr.set_infeed_config(config); - if (ShapeUtil::IsArray(shape) && sharding() && + if (shape.IsArray() && sharding() && sharding()->type() == OpSharding::Type::OpSharding_Type_OTHER) { // TODO(b/110793772): Support tiled array-shaped infeeds. return InvalidArgument( @@ -1334,7 +1493,7 @@ XlaOp XlaBuilder::AfterAll(absl::Span tokens) { for (int i = 0; i < tokens.size(); ++i) { const XlaOp& operand = tokens[i]; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); - if (!ShapeUtil::IsToken(operand_shape)) { + if (!operand_shape.IsToken()) { return InvalidArgument( "All operands to AfterAll must be tokens; operand %d has shape %s", i, ShapeUtil::HumanString(operand_shape)); @@ -1577,7 +1736,7 @@ XlaOp XlaBuilder::Sort(const XlaOp& keys, absl::Span values, *instr.mutable_shape() = shape.ToProto(); if (dimension == -1) { TF_ASSIGN_OR_RETURN(const Shape& keys_shape, GetShape(keys)); - dimension = ShapeUtil::Rank(keys_shape) - 1; + dimension = keys_shape.rank() - 1; } instr.add_dimensions(dimension); std::vector operands{keys}; @@ -1647,12 +1806,12 @@ XlaOp XlaBuilder::Map(absl::Span operands, *instr.mutable_shape() = shape.ToProto(); Shape output_shape(instr.shape()); - const int64 output_rank = ShapeUtil::Rank(output_shape); + const int64 output_rank = output_shape.rank(); AddCalledComputation(computation, &instr); std::vector new_operands(operands.begin(), operands.end()); for (XlaOp& new_operand : new_operands) { TF_ASSIGN_OR_RETURN(Shape shape, GetShape(new_operand)); - const int64 rank = ShapeUtil::Rank(shape); + const int64 rank = shape.rank(); if (rank != output_rank) { TF_ASSIGN_OR_RETURN(new_operand, InDimBroadcast(output_shape, new_operand, {})); @@ -1861,7 +2020,7 @@ 
XlaOp XlaBuilder::ReduceAll(const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation) { return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); - std::vector all_dimnos(ShapeUtil::Rank(operand_shape)); + std::vector all_dimnos(operand_shape.rank()); std::iota(all_dimnos.begin(), all_dimnos.end(), 0); return Reduce(operand, init_value, computation, all_dimnos); }); @@ -2015,8 +2174,8 @@ XlaOp XlaBuilder::CrossReplicaSum( return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); - TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferCrossReplicaSumShape( - {&operand_shape})); + TF_ASSIGN_OR_RETURN(Shape shape, + ShapeInference::InferAllReduceShape({&operand_shape})); *instr.mutable_shape() = shape.ToProto(); for (const ReplicaGroup& group : replica_groups) { @@ -2029,8 +2188,7 @@ XlaOp XlaBuilder::CrossReplicaSum( AddCalledComputation(computation, &instr); - return AddInstruction(std::move(instr), HloOpcode::kCrossReplicaSum, - {operand}); + return AddInstruction(std::move(instr), HloOpcode::kAllReduce, {operand}); }); } @@ -2111,6 +2269,14 @@ XlaOp XlaBuilder::CollectivePermute( }); } +XlaOp XlaBuilder::ReplicaId() { + return ReportErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + *instr.mutable_shape() = ShapeUtil::MakeShape(U32, {}).ToProto(); + return AddInstruction(std::move(instr), HloOpcode::kReplicaId, {}); + }); +} + XlaOp XlaBuilder::SelectAndScatter(const XlaOp& operand, const XlaComputation& select, absl::Span window_dimensions, @@ -2288,7 +2454,7 @@ XlaOp XlaBuilder::SendToHost(const XlaOp& operand, const XlaOp& token, ShapeUtil::HumanStringWithLayout(operand_shape)); } // TODO(b/111544877): Support tuple shapes. 
- if (!ShapeUtil::IsArray(operand_shape)) { + if (!operand_shape.IsArray()) { return InvalidArgument("SendToHost only supports array shapes, shape: %s", ShapeUtil::HumanString(operand_shape)); } @@ -2328,7 +2494,7 @@ XlaOp XlaBuilder::RecvFromHost(const XlaOp& token, const Shape& shape, } // TODO(b/111544877): Support tuple shapes. - if (!ShapeUtil::IsArray(shape)) { + if (!shape.IsArray()) { return InvalidArgument( "RecvFromHost only supports array shapes, shape: %s", ShapeUtil::HumanString(shape)); @@ -2381,7 +2547,7 @@ StatusOr XlaBuilder::IsConstant(const XlaOp& operand) const { TF_RETURN_IF_ERROR(LookUpInstruction(operand).status()); bool is_constant = true; - std::set visited; + absl::flat_hash_set visited; IsConstantVisitor(operand.handle(), &visited, &is_constant); return is_constant; } @@ -2428,21 +2594,58 @@ StatusOr XlaBuilder::BuildConstantSubGraph( worklist.pop(); TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_proto, LookUpInstructionByHandle(handle)); - for (int64 id : instr_proto->operand_ids()) { - if (related_ops.insert(id).second) { - worklist.push(id); + + if (instr_proto->opcode() == + HloOpcodeString(HloOpcode::kGetDimensionSize)) { + // At this point, BuildConstantSubGraph should never encounter a + // GetDimensionSize with a dynamic dimension. IsConstant check would have + // failed at the beginning of this function. + // + // Replace GetDimensionSize with a Constant representing the static bound + // of the shape. 
+ int64 dimension = instr_proto->dimensions(0); + int64 operand_handle = instr_proto->operand_ids(0); + TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto, + LookUpInstructionByHandle(operand_handle)); + + TF_RET_CHECK(!operand_proto->shape().is_dynamic_dimension(dimension)); + auto constant_dimension_size = + static_cast(operand_proto->shape().dimensions(dimension)); + + Literal literal = LiteralUtil::CreateR0(constant_dimension_size); + + HloInstructionProto const_instr; + *const_instr.mutable_shape() = literal.shape().ToProto(); + *const_instr.mutable_literal() = literal.ToProto(); + *const_instr.mutable_opcode() = HloOpcodeString(HloOpcode::kConstant); + + const_instr.set_id(handle); + *const_instr.mutable_name() = + GetFullName(const_instr.opcode(), kNameSeparator, const_instr.id()); + *entry.add_instructions() = + const_instr; // Add to the result constant graph. + } else { + for (int64 id : instr_proto->operand_ids()) { + if (related_ops.insert(id).second) { + worklist.push(id); + } + } + for (int64 called_id : instr_proto->called_computation_ids()) { + related_calls.insert(called_id); } - } - for (int64 called_id : instr_proto->called_computation_ids()) { - related_calls.insert(called_id); } } // Add related ops to the computation. for (int64 id : related_ops) { - auto* instr = entry.add_instructions(); TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_src, LookUpInstructionByHandle(id)); + + if (instr_src->opcode() == HloOpcodeString(HloOpcode::kGetDimensionSize)) { + continue; + } + auto* instr = entry.add_instructions(); + *instr = *instr_src; // Ensures that the instruction names are unique among the graph. 
const string& new_name = @@ -2715,12 +2918,21 @@ XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, absl::Span slice_sizes) { return operand.builder()->DynamicSlice(operand, start_indices, slice_sizes); } +XlaOp DynamicSlice(const XlaOp& operand, absl::Span start_indices, + absl::Span slice_sizes) { + return operand.builder()->DynamicSlice(operand, start_indices, slice_sizes); +} XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, const XlaOp& start_indices) { return operand.builder()->DynamicUpdateSlice(operand, update, start_indices); } +XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, + absl::Span start_indices) { + return operand.builder()->DynamicUpdateSlice(operand, update, start_indices); +} + XlaOp ConcatInDim(XlaBuilder* builder, absl::Span operands, int64 dimension) { return builder->ConcatInDim(operands, dimension); @@ -2786,38 +2998,42 @@ XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, Padding padding, - int64 feature_group_count, const PrecisionConfig* precision_config) { + int64 feature_group_count, int64 batch_group_count, + const PrecisionConfig* precision_config) { return lhs.builder()->Conv(lhs, rhs, window_strides, padding, - feature_group_count, precision_config); + feature_group_count, batch_group_count, + precision_config); } XlaOp ConvWithGeneralPadding(const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, absl::Span> padding, - int64 feature_group_count, + int64 feature_group_count, int64 batch_group_count, const PrecisionConfig* precision_config) { return lhs.builder()->ConvWithGeneralPadding( - lhs, rhs, window_strides, padding, feature_group_count, precision_config); + lhs, rhs, window_strides, padding, feature_group_count, batch_group_count, + precision_config); } XlaOp ConvWithGeneralDimensions( const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, Padding padding, const 
ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count, const PrecisionConfig* precision_config) { + int64 feature_group_count, int64 batch_group_count, + const PrecisionConfig* precision_config) { return lhs.builder()->ConvWithGeneralDimensions( lhs, rhs, window_strides, padding, dimension_numbers, feature_group_count, - precision_config); + batch_group_count, precision_config); } XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, absl::Span> padding, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count, + int64 feature_group_count, int64 batch_group_count, const PrecisionConfig* precision_config) { return lhs.builder()->ConvGeneral(lhs, rhs, window_strides, padding, dimension_numbers, feature_group_count, - precision_config); + batch_group_count, precision_config); } XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs, @@ -2826,11 +3042,12 @@ XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs, absl::Span lhs_dilation, absl::Span rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count, + int64 feature_group_count, int64 batch_group_count, const PrecisionConfig* precision_config) { return lhs.builder()->ConvGeneralDilated( lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, - dimension_numbers, feature_group_count, precision_config); + dimension_numbers, feature_group_count, batch_group_count, + precision_config); } XlaOp Fft(const XlaOp& operand, FftType fft_type, @@ -3010,6 +3227,8 @@ XlaOp CollectivePermute( return operand.builder()->CollectivePermute(operand, source_target_pairs); } +XlaOp ReplicaId(XlaBuilder* builder) { return builder->ReplicaId(); } + XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, absl::Span window_dimensions, absl::Span window_strides, Padding padding, diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index 
098efb60f9bdca8306ff771a505f4a225dea9f7d..c429035ad0f96928525219a5506df81d64ffef95 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -197,11 +197,19 @@ class XlaBuilder { // status. Note that all ops that have been enqueued will be moved to the // computation being returned. The root of the computation will be the last // added operation. - StatusOr Build(); + // + // `remove_dynamic_dimensions` tells the builder whether to remove the + // dynamic dimensions information in all ops. + // + // TODO(b/121223198): Delete `remove_dynamic_dimensions` and keep the + // dynamic dimensions information when XLA backend can handle dynamic + // dimensions. + StatusOr Build(bool remove_dynamic_dimensions = true); // Overload of Build which specifies a particular root instruction for the // computation. - StatusOr Build(XlaOp root); + StatusOr Build(XlaOp root, + bool remove_dynamic_dimensions = true); // Builds the computation with the requested operations, or notes an error in // the parent XlaBuilder and returns an empty computation if building failed. @@ -269,6 +277,10 @@ class XlaBuilder { // and its real dynamic size is represented by `dynamic_param_index` in // parameter `dynamic_param_num`. // + // Note that this should be called before the dynamic parameters are used to + // create other operations, otherwise created operations won't have the + // dynamic dimensions information. + // // TODO(b/119520625): Remove this API once we have more dynamic shape infra // ready. Status SetDynamicBinding(int64 dynamic_size_param_num, @@ -276,9 +288,24 @@ class XlaBuilder { int64 target_param_num, ShapeIndex target_param_index, int64 target_dim_num); + // Adds a new input/output alias. Since the input/output shape information is + // not available until the computation is built, any eventual error in the + // arguments of this API will be detected only at computation Build() time.
+ void SetUpAlias(const ShapeIndex& output_index, int64 param_number, + const ShapeIndex& param_index) { + input_output_aliases_.push_back({output_index, param_number, param_index}); + } + private: + // Describes an input/output alias as inserted by the SetUpAlias() API. + struct InputOutputAlias { + ShapeIndex output_index; + int64 param_number; + ShapeIndex param_index; + }; + // Build helper which takes the id of the root operation.. - StatusOr Build(int64 root_id); + StatusOr Build(int64 root_id, bool remove_dynamic_dimensions); // Description for the methods below can be found in the corresponding public // functions section in this file. @@ -344,11 +371,18 @@ class XlaBuilder { XlaOp SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index, int64 stride, int64 dimno); + ABSL_DEPRECATED("Use span-of-indices form instead") XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, absl::Span slice_sizes); + XlaOp DynamicSlice(const XlaOp& operand, + absl::Span start_indices, + absl::Span slice_sizes); + ABSL_DEPRECATED("Use span-of-indices form instead") XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, const XlaOp& start_indices); + XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, + absl::Span start_indices); XlaOp ConcatInDim(absl::Span operands, int64 dimension); @@ -387,28 +421,28 @@ class XlaBuilder { XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, Padding padding, - int64 feature_group_count = 1, + int64 feature_group_count = 1, int64 batch_group_count = 1, const PrecisionConfig* precision_config = nullptr); XlaOp ConvWithGeneralPadding( const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, absl::Span> padding, - int64 feature_group_count = 1, + int64 feature_group_count = 1, int64 batch_group_count = 1, const PrecisionConfig* precision_config = nullptr); XlaOp ConvWithGeneralDimensions( const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, Padding 
padding, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count = 1, + int64 feature_group_count = 1, int64 batch_group_count = 1, const PrecisionConfig* precision_config = nullptr); XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, absl::Span> padding, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count = 1, + int64 feature_group_count = 1, int64 batch_group_count = 1, const PrecisionConfig* precision_config = nullptr); XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs, @@ -418,6 +452,7 @@ class XlaBuilder { absl::Span rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count = 1, + int64 batch_group_count = 1, const PrecisionConfig* precision_config = nullptr); XlaOp Fft(const XlaOp& operand, FftType fft_type, @@ -527,6 +562,8 @@ class XlaBuilder { const XlaOp& operand, const std::vector>& source_target_pairs); + XlaOp ReplicaId(); + XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, absl::Span window_dimensions, absl::Span window_strides, @@ -711,7 +748,8 @@ class XlaBuilder { // operation such as `RngNormal` or `Infeed`. The visitor walks the // computation starting at a given operation and sets is_constant to false iff // a parameter or stateful operation is encountered. - void IsConstantVisitor(const int64 op_handle, std::set* visited, + void IsConstantVisitor(const int64 op_handle, + absl::flat_hash_set* visited, bool* is_constant) const; // Checks bounds for convolution parameters. @@ -729,6 +767,12 @@ class XlaBuilder { int64 GetNextId() { return ++next_id_; } + // Populates the module with the input/output alias information stored within + // the input_output_aliases vector. + static Status PopulateInputOutputAlias( + HloModuleProto* module, const ProgramShape& program_shape, + const std::vector& input_output_aliases); + string name_; // Name to use for the built computation. 
// The next sequential ID for every instruction/computation contained within @@ -748,6 +792,9 @@ class XlaBuilder { // Dynamic parameter configuration of this computation. DynamicParameterBinding dynamic_parameter_binding_; + // Holds the input/output alias information populated by the SetUpAlias() API. + std::vector input_output_aliases_; + // A map from XlaOp::Handle to the index in the instructions_ vector where the // instruction is held. absl::flat_hash_map handle_to_index_; @@ -849,9 +896,14 @@ class XlaBuilder { friend XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, absl::Span slice_sizes); + friend XlaOp DynamicSlice(const XlaOp& operand, + absl::Span start_indices, + absl::Span slice_sizes); friend XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, const XlaOp& start_indices); + friend XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, + absl::Span start_indices); friend XlaOp ConcatInDim(XlaBuilder* builder, absl::Span operands, int64 dimension); @@ -881,23 +933,25 @@ class XlaBuilder { const PrecisionConfig* precision_config); friend XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, Padding padding, - int64 feature_group_count, + int64 feature_group_count, int64 batch_group_count, const PrecisionConfig* precision_config); friend XlaOp ConvWithGeneralPadding( const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, absl::Span> padding, - int64 feature_group_count, const PrecisionConfig* precision_config); + int64 feature_group_count, int64 batch_group_count, + const PrecisionConfig* precision_config); friend XlaOp ConvWithGeneralDimensions( const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count, const PrecisionConfig* precision_config); + int64 feature_group_count, int64 batch_group_count, + const PrecisionConfig* precision_config); friend XlaOp ConvGeneral(const 
XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, absl::Span> padding, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count, + int64 feature_group_count, int64 batch_group_count, const PrecisionConfig* precision_config); friend XlaOp ConvGeneralDilated( const XlaOp& lhs, const XlaOp& rhs, @@ -906,7 +960,8 @@ class XlaBuilder { absl::Span lhs_dilation, absl::Span rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count, const PrecisionConfig* precision_config); + int64 feature_group_count, int64 batch_group_count, + const PrecisionConfig* precision_config); friend XlaOp Fft(const XlaOp& operand, FftType fft_type, absl::Span fft_length); friend XlaOp Infeed(XlaBuilder* builder, const Shape& shape, @@ -987,6 +1042,7 @@ class XlaBuilder { friend XlaOp CollectivePermute( const XlaOp& operand, const std::vector>& source_target_pairs); + friend XlaOp ReplicaId(XlaBuilder* builder); friend XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, absl::Span window_dimensions, @@ -1290,10 +1346,15 @@ XlaOp SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index, // The size of the slice in each dimension is passed in 'slice_sizes', // which specify the end point of exclusive slice intervals in each // dimension [start, start + size). -// The shape of 'start_indices' must be rank == 1, with dimension size -// equal to the rank of the 'operand'. +// The shape of each element of 'start_indices' must be scalar, with the span +// size equal to the rank of the 'operand'. All elements of 'start_indices' must +// have the same shape. // Slice index calculations are computed modulo input dimension sizes to // prevent dynamic start indices from generating out-of-bound array accesses. 
+XlaOp DynamicSlice(const XlaOp& operand, absl::Span start_indices, + absl::Span slice_sizes); + +ABSL_DEPRECATED("Use span-of-indices form instead") XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, absl::Span slice_sizes); @@ -1309,10 +1370,15 @@ XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, // [4 5 6] => DynamicUpdateslice(data, update, start) => [4 10 11] // [7 8 9] [7 8 9 ] // -// The shape of 'start_indices' must be rank == 1, with dimension size -// equal to the rank of the 'operand'. +// The shape of each element of 'start_indices' must be scalar, with the span +// size equal to the rank of the 'operand'. All elements of 'start_indices' must +// have the same shape. // Slice index calculations are computed modulo update dimension sizes to // prevent dynamic start indices from generating out-of-bound array accesses. +XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, + absl::Span start_indices); + +ABSL_DEPRECATED("Use span-of-indices form instead") XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, const XlaOp& start_indices); @@ -1372,7 +1438,7 @@ XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, // default convolution dimension numbers. 
XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, Padding padding, - int64 feature_group_count = 1, + int64 feature_group_count = 1, int64 batch_group_count = 1, const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, with the caller @@ -1381,6 +1447,7 @@ XlaOp ConvWithGeneralPadding(const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, absl::Span> padding, int64 feature_group_count = 1, + int64 batch_group_count = 1, const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, with the caller @@ -1388,7 +1455,7 @@ XlaOp ConvWithGeneralPadding(const XlaOp& lhs, const XlaOp& rhs, XlaOp ConvWithGeneralDimensions( const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count = 1, + int64 feature_group_count = 1, int64 batch_group_count = 1, const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, with the caller @@ -1397,7 +1464,7 @@ XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, absl::Span> padding, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count = 1, + int64 feature_group_count = 1, int64 batch_group_count = 1, const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, with the caller @@ -1409,6 +1476,7 @@ XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs, absl::Span rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count = 1, + int64 batch_group_count = 1, const PrecisionConfig* precision_config = nullptr); // Enqueues an FFT instruction onto the computation, of the given type and @@ -1515,9 +1583,33 @@ XlaOp Min(const XlaOp& lhs, const XlaOp& rhs, XlaOp And(const XlaOp& lhs, const XlaOp& rhs, absl::Span 
broadcast_dimensions = {}); +// Overload to call And with 3 or more operands. We need the following somewhat +// convoluted overload set to disambiguate with the overload that takes the +// `broadcast_dimensions` optional param. +inline XlaOp And(const XlaOp& op1, const XlaOp& op2, const XlaOp& op3) { + return And(op1, And(op2, op3)); +} +template +XlaOp And(const XlaOp& op1, const XlaOp& op2, const XlaOp& op3, + const XlaOpTs&... operands) { + return And(op1, And(op2, And(op3, operands...))); +} + XlaOp Or(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); +// Overload to call Or with 3 or more operands. As with `And`, we need the +// following complicated overload set to handle the default arg in the `Or` +// overload above. +inline XlaOp Or(const XlaOp& op1, const XlaOp& op2, const XlaOp& op3) { + return Or(op1, Or(op2, op3)); +} +template +XlaOp Or(const XlaOp& op1, const XlaOp& op2, const XlaOp& op3, + const XlaOpTs&... operands) { + return Or(op1, Or(op2, Or(op3, operands...))); +} + XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); @@ -1610,6 +1702,9 @@ XlaOp CollectivePermute( const XlaOp& operand, const std::vector>& source_target_pairs); +// Enqueues an operation that returns the replica ID. +XlaOp ReplicaId(XlaBuilder* builder); + // Enqueues an operation that scatters the `source` array to the selected // indices of each window. XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, @@ -1681,10 +1776,14 @@ XlaOp Imag(const XlaOp& operand); XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); -// Enqueues an operator that tests if the operand's values are finite, i.e., -// not Inf or NaN. Defined only for floating-point types. Returns an array of -// booleans with the same shape where entries are true iff the corresponding -// entry was NaN. +// Enqueues an operator that tests if the operand's values are finite, i.e., not +// +/-Inf or NaN. 
Returns an array of booleans with the same shape where +// entries are true iff the corresponding entry was not infinite or NaN. +// +// Defined only for real-valued (i.e. not complex) floating-point types; raises +// an error for other types. +// +// See also IsInf, IsPosInf, IsNegInf, and IsNan in lib/math.h. XlaOp IsFinite(const XlaOp& operand); // Enqueues an iota operation onto the computation. diff --git a/tensorflow/compiler/xla/client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_builder_test.cc index b3f5be300d3f15397ad33858a6a9cab5f6029688..c9fa738a19d0928d56ac4b98beb5fc0ed195518b 100644 --- a/tensorflow/compiler/xla/client/xla_builder_test.cc +++ b/tensorflow/compiler/xla/client/xla_builder_test.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { @@ -39,7 +40,8 @@ using ::testing::HasSubstr; class XlaBuilderTest : public ::testing::Test { protected: StatusOr> BuildHloModule(XlaBuilder* b) { - TF_ASSIGN_OR_RETURN(XlaComputation computation, b->Build()); + TF_ASSIGN_OR_RETURN(XlaComputation computation, + b->Build(/*remove_dynamic_dimensions=*/false)); const HloModuleProto& proto = computation.proto(); TF_ASSIGN_OR_RETURN(const auto& config, HloModule::CreateModuleConfigFromProto( @@ -50,7 +52,8 @@ class XlaBuilderTest : public ::testing::Test { // Overload which explicitly specifies the root instruction. 
StatusOr> BuildHloModule(XlaBuilder* b, XlaOp root) { - TF_ASSIGN_OR_RETURN(XlaComputation computation, b->Build(root)); + TF_ASSIGN_OR_RETURN(XlaComputation computation, + b->Build(root, /*remove_dynamic_dimensions=*/false)); const HloModuleProto& proto = computation.proto(); TF_ASSIGN_OR_RETURN(const auto& config, HloModule::CreateModuleConfigFromProto( @@ -132,6 +135,38 @@ TEST_F(XlaBuilderTest, BinaryOperatorsBuildExpectedHLO) { op::ShiftRightLogical(op::Constant(), op::Constant())); } +TEST_F(XlaBuilderTest, VariadicAnd) { + XlaBuilder b(TestName()); + Shape s = ShapeUtil::MakeShape(PRED, {}); + And(Parameter(&b, 0, s, "p0"), Parameter(&b, 1, s, "p1"), + Parameter(&b, 2, s, "p2")); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + // Don't specify in the test whether And(x, y, z) is right- or + // left-associative; accept either one. + EXPECT_THAT( + module->entry_computation()->root_instruction(), + ::testing::AnyOf(op::And(op::Parameter(0), + op::And(op::Parameter(1), op::Parameter(2))), + op::And(op::And(op::Parameter(0), op::Parameter(1)), + op::Parameter(2)))); +} + +TEST_F(XlaBuilderTest, VariadicOr) { + XlaBuilder b(TestName()); + Shape s = ShapeUtil::MakeShape(PRED, {}); + Or(Parameter(&b, 0, s, "p0"), Parameter(&b, 1, s, "p1"), + Parameter(&b, 2, s, "p2")); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + // Don't specify in the test whether Or(x, y, z) is right- or + // left-associative; accept either one. 
+ EXPECT_THAT( + module->entry_computation()->root_instruction(), + ::testing::AnyOf( + op::Or(op::Parameter(0), op::Or(op::Parameter(1), op::Parameter(2))), + op::Or(op::Or(op::Parameter(0), op::Parameter(1)), + op::Parameter(2)))); +} + TEST_F(XlaBuilderTest, ShiftRightOperatorOnNonIntegerProducesError) { XlaBuilder b(TestName()); ConstantR0(&b, 1) >> ConstantR0(&b, 2); @@ -446,6 +481,461 @@ TEST_F(XlaBuilderTest, ProtoMatches) { EXPECT_EQ(c0_string, c1_string); } +TEST_F(XlaBuilderTest, DynamicParameter) { + XlaBuilder b(TestName()); + Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {5}), ShapeUtil::MakeShape(F32, {6})}); + auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); + Parameter(&b, 1, ShapeUtil::MakeShape(U32, {}), "p1"); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/1, + /*dynamic_size_param_index=*/{}, + /*target_param_num=*/0, + /*target_param_index=*/{1}, + /*target_dim_num=*/0)); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b, /*root=*/p0)); + const Shape& param_shape = module->entry_computation() + ->parameter_instruction(0) + ->shape() + .tuple_shapes(1); + EXPECT_TRUE(param_shape.is_dynamic_dimension(0)); +} + +TEST_F(XlaBuilderTest, DynamicUnary) { + XlaBuilder b(TestName()); + Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {5}), ShapeUtil::MakeShape(U32, {})}); + auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{1}, + /*target_param_num=*/0, + /*target_param_index=*/{0}, + /*target_dim_num=*/0)); + auto gte = GetTupleElement(p0, 0); + Neg(gte); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result_shape = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE(result_shape.is_dynamic_dimension(0)); +} + +TEST_F(XlaBuilderTest, DynamicBinary) { + XlaBuilder b(TestName()); + Shape tuple_param_shape = 
ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {5}), ShapeUtil::MakeShape(F32, {5}), + ShapeUtil::MakeShape(U32, {})}); + auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{2}, + /*target_param_num=*/0, + /*target_param_index=*/{0}, + /*target_dim_num=*/0)); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{2}, + /*target_param_num=*/0, + /*target_param_index=*/{1}, + /*target_dim_num=*/0)); + auto gte0 = GetTupleElement(p0, 0); + auto gte1 = GetTupleElement(p0, 1); + Add(gte0, gte1); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result_shape = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE(result_shape.is_dynamic_dimension(0)); +} + +TEST_F(XlaBuilderTest, DynamicBinaryHasBroadcast) { + XlaBuilder b(TestName()); + Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {5, 4}), ShapeUtil::MakeShape(F32, {5}), + ShapeUtil::MakeShape(U32, {})}); + auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{2}, + /*target_param_num=*/0, + /*target_param_index=*/{0}, + /*target_dim_num=*/0)); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{2}, + /*target_param_num=*/0, + /*target_param_index=*/{1}, + /*target_dim_num=*/0)); + auto gte0 = GetTupleElement(p0, 0); + auto gte1 = GetTupleElement(p0, 1); + Add(gte0, gte1, {0}); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result_shape = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(), {true, false})) + << result_shape; +} + +TEST_F(XlaBuilderTest, DynamicBroadcast) { + XlaBuilder b(TestName()); + Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + 
{ShapeUtil::MakeShape(F32, {5, 4}), ShapeUtil::MakeShape(U32, {})}); + auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{1}, + /*target_param_num=*/0, + /*target_param_index=*/{0}, + /*target_dim_num=*/0)); + auto gte = GetTupleElement(p0, 0); + BroadcastInDim(gte, /*out_dim_size=*/{3, 5, 4}, + /*broadcast_dimensions=*/{1, 2}); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result_shape = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE( + ContainersEqual(result_shape.dynamic_dimensions(), {false, true, false})) + << result_shape; +} + +TEST_F(XlaBuilderTest, DynamicBinaryHasDegenerateBroadcast) { + XlaBuilder b(TestName()); + Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {10}), ShapeUtil::MakeShape(F32, {1, 15}), + ShapeUtil::MakeShape(U32, {})}); + auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{1}, + /*target_param_num=*/0, + /*target_param_index=*/{0}, + /*target_dim_num=*/0)); + auto gte0 = GetTupleElement(p0, 0); + auto gte1 = GetTupleElement(p0, 1); + Add(gte0, gte1, /*broadcast_dimensions=*/{0}); // f32[<=10, 15] + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result_shape = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(), {true, false})) + << result_shape; +} + +TEST_F(XlaBuilderTest, DynamicSelectOnlyPredDynamic) { + XlaBuilder b(TestName()); + Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(PRED, {10}), ShapeUtil::MakeShape(F32, {10}), + ShapeUtil::MakeShape(U32, {})}); + auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{1}, + 
/*target_param_num=*/0, + /*target_param_index=*/{0}, + /*target_dim_num=*/0)); + auto gte0 = GetTupleElement(p0, 0); + auto gte1 = GetTupleElement(p0, 1); + + Select(gte0, gte1, gte1); + + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result_shape = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(), {true})) + << result_shape; +} + +TEST_F(XlaBuilderTest, DynamicPad) { + XlaBuilder b(TestName()); + Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {5, 4}), ShapeUtil::MakeShape(U32, {})}); + auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); + auto pad_val = ConstantR0(&b, -1); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{1}, + /*target_param_num=*/0, + /*target_param_index=*/{0}, + /*target_dim_num=*/0)); + auto gte = GetTupleElement(p0, 0); + PaddingConfig padding_config; + for (int i = 0; i < 2; i++) { + auto dimension = padding_config.add_dimensions(); + dimension->set_edge_padding_low(0); + dimension->set_edge_padding_high(0); + dimension->set_interior_padding(0); + } + Pad(gte, pad_val, padding_config); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result_shape = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(), {true, false})) + << result_shape; +} + +TEST_F(XlaBuilderTest, DynamicConvolution) { + XlaBuilder b(TestName()); + Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {1, 2, 2, 128}), + ShapeUtil::MakeShape(F32, {2, 2, 128, 8}), ShapeUtil::MakeShape(U32, {}), + ShapeUtil::MakeShape(U32, {})}); + auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{2}, + /*target_param_num=*/0, + /*target_param_index=*/{0}, + /*target_dim_num=*/0)); + 
ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{3}, + /*target_param_num=*/0, + /*target_param_index=*/{1}, + /*target_dim_num=*/2)); + auto input = GetTupleElement(p0, 0); + auto filter = GetTupleElement(p0, 1); + ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, + /*feature_group_count=*/1); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result_shape = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(), + {true, false, false, false})) + << result_shape; +} + +TEST_F(XlaBuilderTest, DynamicDot) { + XlaBuilder b(TestName()); + Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {2, 3, 4}), + ShapeUtil::MakeShape(F32, {2, 4, 5}), ShapeUtil::MakeShape(U32, {}), + ShapeUtil::MakeShape(U32, {})}); + auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{2}, + /*target_param_num=*/0, + /*target_param_index=*/{0}, + /*target_dim_num=*/0)); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{2}, + /*target_param_num=*/0, + /*target_param_index=*/{1}, + /*target_dim_num=*/0)); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{3}, + /*target_param_num=*/0, + 
/*target_param_index=*/{0}, + /*target_dim_num=*/1)); + + auto lhs = GetTupleElement(p0, 0); + auto rhs = GetTupleElement(p0, 1); + DotDimensionNumbers dnums; + dnums.add_lhs_contracting_dimensions(2); + dnums.add_rhs_contracting_dimensions(1); + dnums.add_lhs_batch_dimensions(0); + dnums.add_rhs_batch_dimensions(0); + DotGeneral(lhs, rhs, dnums); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result_shape = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE( + ContainersEqual(result_shape.dynamic_dimensions(), {true, true, false})) + << result_shape; +} + +TEST_F(XlaBuilderTest, DynamicReduce) { + XlaBuilder b(TestName()); + Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {5, 4, 3}), ShapeUtil::MakeShape(U32, {})}); + auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); + auto init = ConstantR0(&b, 0); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{1}, + /*target_param_num=*/0, + /*target_param_index=*/{0}, + /*target_dim_num=*/1)); + auto gte = GetTupleElement(p0, 0); + XlaBuilder bsum(TestName()); + Add(Parameter(&bsum, 0, ShapeUtil::MakeShape(F32, {}), "x"), + Parameter(&bsum, 1, ShapeUtil::MakeShape(F32, {}), "y")); + TF_ASSERT_OK_AND_ASSIGN(auto sum, bsum.Build()); + Reduce(gte, init, sum, {0}); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result_shape = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(), {true, false})) + << result_shape; +} + +TEST_F(XlaBuilderTest, DynamicReduceWindow) { + XlaBuilder b(TestName()); + Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {2, 4, 8}), ShapeUtil::MakeShape(U32, {})}); + auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); + auto init = ConstantR0(&b, 0.f); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + 
/*dynamic_size_param_index=*/{1}, + /*target_param_num=*/0, + /*target_param_index=*/{0}, + /*target_dim_num=*/0)); + auto gte = GetTupleElement(p0, 0); + XlaBuilder bsum(TestName()); + Add(Parameter(&bsum, 0, ShapeUtil::MakeShape(F32, {}), "x"), + Parameter(&bsum, 1, ShapeUtil::MakeShape(F32, {}), "y")); + TF_ASSERT_OK_AND_ASSIGN(auto sum, bsum.Build()); + ReduceWindow(gte, init, sum, /*window_dimensions=*/{1, 2, 4}, + /*window_strides=*/{1, 1, 1}, Padding::kValid); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result_shape = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE( + ContainersEqual(result_shape.dynamic_dimensions(), {true, false, false})) + << result_shape; +} + +TEST_F(XlaBuilderTest, DynamicSelectAndScatter) { + XlaBuilder b(TestName()); + Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {2, 4, 8}), + ShapeUtil::MakeShape(F32, {2, 2, 2}), ShapeUtil::MakeShape(U32, {})}); + auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); + auto init = ConstantR0(&b, 0.f); + XlaBuilder bsum(TestName()); + Add(Parameter(&bsum, 0, ShapeUtil::MakeShape(F32, {}), "x"), + Parameter(&bsum, 1, ShapeUtil::MakeShape(F32, {}), "y")); + TF_ASSERT_OK_AND_ASSIGN(auto sum, bsum.Build()); + XlaBuilder bge(TestName()); + Ge(Parameter(&bge, 0, ShapeUtil::MakeShape(F32, {}), "x"), + Parameter(&bge, 1, ShapeUtil::MakeShape(F32, {}), "y")); + TF_ASSERT_OK_AND_ASSIGN(auto ge, bge.Build()); + + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{2}, + /*target_param_num=*/0, + /*target_param_index=*/{0}, + /*target_dim_num=*/0)); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{2}, + /*target_param_num=*/0, + /*target_param_index=*/{1}, + /*target_dim_num=*/0)); + auto gte0 = GetTupleElement(p0, 0); + auto source = GetTupleElement(p0, 1); + SelectAndScatter(gte0, ge, {1, 2, 4}, {1, 2, 4}, Padding::kValid, 
source, + init, sum); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result_shape = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE( + ContainersEqual(result_shape.dynamic_dimensions(), {true, false, false})) + << result_shape; +} + +TEST_F(XlaBuilderTest, DynamicReshape) { + XlaBuilder b(TestName()); + Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {2, 3, 4, 5, 6}), + ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeShape(U32, {})}); + auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{1}, + /*target_param_num=*/0, + /*target_param_index=*/{0}, + /*target_dim_num=*/2)); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{2}, + /*target_param_num=*/0, + /*target_param_index=*/{0}, + /*target_dim_num=*/3)); + auto gte = GetTupleElement(p0, 0); // f32[2, 3, <=4, <=5, 6] + Reshape(gte, /*new_sizes=*/{6, 4, 1, 5, 2, 3}); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result_shape = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE(result_shape.is_dynamic_dimension(1)); + EXPECT_TRUE(result_shape.is_dynamic_dimension(3)); + EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(), + {false, true, false, true, false, false})) + << result_shape; +} + +TEST_F(XlaBuilderTest, DynamicSelect) { + XlaBuilder b(TestName()); + Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {4, 5, 6}), + ShapeUtil::MakeShape(F32, {4, 5, 6}), ShapeUtil::MakeShape(U32, {}), + ShapeUtil::MakeShape(U32, {})}); + auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); + auto pred = Parameter(&b, 1, ShapeUtil::MakeShape(PRED, {}), "pred"); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{2}, + /*target_param_num=*/0, + 
/*target_param_index=*/{0}, + /*target_dim_num=*/1)); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{3}, + /*target_param_num=*/0, + /*target_param_index=*/{1}, + /*target_dim_num=*/1)); + auto gte0 = GetTupleElement(p0, 0); + auto gte1 = GetTupleElement(p0, 1); + Select(pred, gte0, gte1); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result_shape = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE(result_shape.is_dynamic_dimension(1)); + EXPECT_FALSE(result_shape.is_dynamic_dimension(2)); + EXPECT_TRUE( + ContainersEqual(result_shape.dynamic_dimensions(), {false, true, false})) + << result_shape; +} + +TEST_F(XlaBuilderTest, DynamicSelectNotCompatible) { + XlaBuilder b(TestName()); + Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {4, 5, 6}), + ShapeUtil::MakeShape(F32, {4, 5, 6}), ShapeUtil::MakeShape(U32, {}), + ShapeUtil::MakeShape(U32, {})}); + auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); + auto pred = Parameter(&b, 1, ShapeUtil::MakeShape(PRED, {}), "pred"); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{2}, + /*target_param_num=*/0, + /*target_param_index=*/{0}, + /*target_dim_num=*/1)); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{3}, + /*target_param_num=*/0, + /*target_param_index=*/{1}, + /*target_dim_num=*/2)); + auto gte0 = GetTupleElement(p0, 0); // f32[4,<=5,6] + auto gte1 = GetTupleElement(p0, 1); // f32[4,5,<=6] + Select(pred, gte0, gte1); + Status status = BuildHloModule(&b).status(); + ASSERT_IS_NOT_OK(status); + EXPECT_THAT(status.error_message(), + ::testing::HasSubstr("Operands to select must be the same shape; " + "got f32[4,<=5,6] and f32[4,5,<=6]")); +} + +TEST_F(XlaBuilderTest, DynamicTranspose) { + XlaBuilder b(TestName()); + Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + 
{ShapeUtil::MakeShape(F32, {3, 5}), ShapeUtil::MakeShape(U32, {})}); + auto p0 = Parameter(&b, 0, tuple_param_shape, "p0"); + ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0, + /*dynamic_size_param_index=*/{1}, + /*target_param_num=*/0, + /*target_param_index=*/{0}, + /*target_dim_num=*/0)); + auto gte = GetTupleElement(p0, 0); + Transpose(gte, /*permutation=*/{1, 0}); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const Shape& result_shape = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(), {false, true})) + << result_shape; +} + TEST_F(XlaBuilderTest, AfterAllWithNonTokenOperands) { XlaBuilder b(TestName()); AfterAll(&b, {CreateToken(&b), ConstantR0(&b, 1.0)}); @@ -455,5 +945,31 @@ TEST_F(XlaBuilderTest, AfterAllWithNonTokenOperands) { ::testing::HasSubstr("All operands to AfterAll must be tokens")); } +TEST_F(XlaBuilderTest, CheckInputOutputAlias) { + XlaBuilder b(TestName()); + auto p0 = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {8, 4}), "p0"); + auto p1 = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {8, 4}), "p1"); + auto add = Add(p0, p1); + auto sub = Sub(p0, p1); + auto root = Tuple(&b, {add, sub}); + + b.SetUpAlias({1}, 0, {}); + b.SetUpAlias({0}, 1, {}); + + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b, root)); + + const HloInputOutputAliasConfig& config = module->input_output_alias_config(); + EXPECT_TRUE(config.ParameterHasAlias(0, {})); + EXPECT_TRUE(config.ParameterHasAlias(1, {})); + + auto alias_p0 = config.GetAliasedOutput(0, {}); + ASSERT_TRUE(alias_p0.has_value()); + EXPECT_EQ(*alias_p0, ShapeIndex({1})); + + auto alias_p1 = config.GetAliasedOutput(1, {}); + ASSERT_TRUE(alias_p1.has_value()); + EXPECT_EQ(*alias_p1, ShapeIndex({0})); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/debug_options_flags.cc b/tensorflow/compiler/xla/debug_options_flags.cc index 
a40330a9b1fe201b6ec83d1bfe1a21e294e18f55..a9a91648ac377987e7f226116e11c9c697ace103 100644 --- a/tensorflow/compiler/xla/debug_options_flags.cc +++ b/tensorflow/compiler/xla/debug_options_flags.cc @@ -22,49 +22,49 @@ limitations under the License. #include "tensorflow/compiler/xla/parse_flags_from_env.h" namespace xla { -namespace { -DebugOptions* flag_values; -std::vector* flag_objects; -std::once_flag flags_init; - -void SetDebugOptionsDefaults(DebugOptions* flags) { - flags->set_xla_llvm_enable_alias_scope_metadata(true); - flags->set_xla_llvm_enable_noalias_metadata(true); - flags->set_xla_llvm_enable_invariant_load_metadata(true); - flags->set_xla_llvm_disable_expensive_passes(false); - flags->set_xla_backend_optimization_level(3); - flags->set_xla_cpu_multi_thread_eigen(true); - flags->set_xla_gpu_cuda_data_dir("./cuda_sdk_lib"); - flags->set_xla_eliminate_hlo_implicit_broadcast(true); +DebugOptions DefaultDebugOptionsIgnoringFlags() { + DebugOptions opts; + opts.set_xla_llvm_enable_alias_scope_metadata(true); + opts.set_xla_llvm_enable_noalias_metadata(true); + opts.set_xla_llvm_enable_invariant_load_metadata(true); + opts.set_xla_llvm_disable_expensive_passes(false); + opts.set_xla_backend_optimization_level(3); + opts.set_xla_cpu_multi_thread_eigen(true); + opts.set_xla_gpu_cuda_data_dir("./cuda_sdk_lib"); + opts.set_xla_eliminate_hlo_implicit_broadcast(true); + opts.set_xla_hlo_dump_as_html(false); #ifdef INTEL_MKL - flags->set_xla_cpu_use_mkl_dnn(true); + opts.set_xla_cpu_use_mkl_dnn(true); #endif // INTEL_MKL - flags->set_xla_gpu_max_kernel_unroll_factor(4); + opts.set_xla_gpu_max_kernel_unroll_factor(4); // Set cudnn batchnorm off by default; it does not provide a performance win // on average. - flags->set_xla_gpu_use_cudnn_batchnorm(false); + opts.set_xla_gpu_use_cudnn_batchnorm(false); // Run all GPU work on one stream by default. 
Using multiple streams // increases memory usage and we lack strong motivating benchmarks for tuning // the heuristics needed to decide when to run on multiple streams. See // b/77879207. - flags->set_xla_gpu_disable_multi_streaming(true); + opts.set_xla_gpu_disable_multi_streaming(true); // TODO(jlebar): Disable fastmath once doing so is not a performance // regression. - flags->set_xla_cpu_enable_fast_math(true); - flags->set_xla_gpu_enable_fast_math(true); + opts.set_xla_cpu_enable_fast_math(true); + opts.set_xla_gpu_enable_fast_min_max(true); - flags->set_xla_force_host_platform_device_count(1); + opts.set_xla_force_host_platform_device_count(1); + return opts; } +static DebugOptions* flag_values; +static std::vector* flag_objects; +static std::once_flag flags_init; + // Allocates flag_values and flag_objects; this function must not be called more // than once - its call done via call_once. -void AllocateFlags() { - flag_values = new DebugOptions; - - SetDebugOptionsDefaults(flag_values); +static void AllocateFlags() { + flag_values = new DebugOptions(DefaultDebugOptionsIgnoringFlags()); // Returns a lambda that calls "member_setter" on "flag_values" with the // argument passed in to the lambda. 
@@ -133,6 +133,11 @@ void AllocateFlags() { bool_setter_for(&DebugOptions::set_xla_hlo_dump_as_graphdef), flag_values->xla_hlo_dump_as_graphdef(), "Dump HLO graphs as TensorFlow GraphDefs."), + tensorflow::Flag("xla_hlo_dump_as_html", + bool_setter_for(&DebugOptions::set_xla_hlo_dump_as_html), + flag_values->xla_hlo_dump_as_html(), + "Dump HLO graphs as an HTML (DOT rendered into SVG " + "inlined in HTML)."), tensorflow::Flag( "xla_hlo_graph_sharding_color", bool_setter_for(&DebugOptions::set_xla_hlo_graph_sharding_color), @@ -160,11 +165,11 @@ void AllocateFlags() { "Enable unsafe fast-math optimizations in the CPU compiler; " "this may produce faster code at the expense of some accuracy."), tensorflow::Flag( - "xla_gpu_enable_fast_math", - bool_setter_for(&DebugOptions::set_xla_cpu_enable_fast_math), - flag_values->xla_cpu_enable_fast_math(), - "Enable unsafe fast-math optimizations in the GPU compiler; " - "this may produce faster code at the expense of some accuracy."), + "xla_gpu_enable_fast_min_max", + bool_setter_for(&DebugOptions::set_xla_gpu_enable_fast_min_max), + flag_values->xla_gpu_enable_fast_min_max(), + "Enable fast floating point min/max lowering that does not propagate " + "NaNs."), tensorflow::Flag( "xla_llvm_enable_alias_scope_metadata", bool_setter_for( @@ -202,6 +207,16 @@ void AllocateFlags() { "Comma-separated list of hlo passes to be disabled. These names " "must exactly match the passes' names; no whitespace around " "commas."), + tensorflow::Flag( + "xla_disable_all_hlo_passes", + bool_setter_for(&DebugOptions::set_xla_disable_all_hlo_passes), false, + "Disables all HLO passes. Notes that some passes are necessary for " + "correctness and the invariants that must be satisfied by 'fully " + "optimized' HLO are different for different devices and may change " + "over time. 
The only 'guarantee', such as it is, is that if you " + "compile XLA and dump the optimized HLO for some graph, you should " + "be able to run it again on the same device with the same build of " + "XLA."), tensorflow::Flag( "xla_embed_ir_in_executable", bool_setter_for(&DebugOptions::set_xla_embed_ir_in_executable), @@ -334,12 +349,16 @@ void AllocateFlags() { "overhead from context switching but we let the user override this " "behavior to help run tests on the host that run models in parallel " "across multiple devices."), + tensorflow::Flag( + "xla_gpu_disable_ptxas_optimizations", + bool_setter_for( + &DebugOptions::set_xla_gpu_disable_ptxas_optimizations), + flag_values->xla_gpu_disable_ptxas_optimizations(), + "In XLA:GPU run ptxas in -O0 (default is -O3)."), }); ParseFlagsFromEnvAndDieIfUnknown("XLA_FLAGS", *flag_objects); } -} // namespace - void AppendDebugOptionsFlags(std::vector* flag_list) { std::call_once(flags_init, &AllocateFlags); flag_list->insert(flag_list->end(), flag_objects->begin(), diff --git a/tensorflow/compiler/xla/debug_options_flags.h b/tensorflow/compiler/xla/debug_options_flags.h index 60e59abc2a2e0f1cce3de1afc928f9fe36f75b33..dbf86a40f052af09c61da0e1abb3116ef5214357 100644 --- a/tensorflow/compiler/xla/debug_options_flags.h +++ b/tensorflow/compiler/xla/debug_options_flags.h @@ -29,7 +29,10 @@ void AppendDebugOptionsFlags(std::vector* flag_list); // Fetches a DebugOptions proto message from flags provided to the program. // Flags must be registered with the flags parser using AppendDebugOptionsFlags // first. -xla::DebugOptions GetDebugOptionsFromFlags(); +DebugOptions GetDebugOptionsFromFlags(); + +// Gets a DebugOptions proto that reflects the defaults as if no flags were set. 
+DebugOptions DefaultDebugOptionsIgnoringFlags(); } // namespace xla diff --git a/tensorflow/compiler/xla/error_spec.h b/tensorflow/compiler/xla/error_spec.h index a1463aa15941b9c265db94e2eb3cc176fab6695b..4359f3b7deb8e585494cb2a9c7115eac6a312c8e 100644 --- a/tensorflow/compiler/xla/error_spec.h +++ b/tensorflow/compiler/xla/error_spec.h @@ -30,6 +30,19 @@ struct ErrorSpec { // In effect, this allows the tested operation to produce incorrect results // for inputs outside its mathematical domain. bool relaxed_nans; + + // If this is true, then we treat each +/-inf in the actual result as + // equivalent to our choice of either +/-inf or the min/max floating-point + // value. + // + // If the expected result is +/-inf, the actual result must still be +/-inf. + // + // In effect, this allows the tested operation to overflow, so long as it's + // overflowing on "large" values. + // + // (We could have a symmetric more_infs_ok flag if necessary; right now it + // appears not to be.) + bool fewer_infs_ok = false; }; } // namespace xla diff --git a/tensorflow/compiler/xla/executable_run_options.h b/tensorflow/compiler/xla/executable_run_options.h index ba3217f31b55bd1428f67da6154a46c8bc304053..6f36d11dfb34eb27e79ea4ff797d35f80fb44b27 100644 --- a/tensorflow/compiler/xla/executable_run_options.h +++ b/tensorflow/compiler/xla/executable_run_options.h @@ -16,9 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_EXECUTABLE_RUN_OPTIONS_H_ #define TENSORFLOW_COMPILER_XLA_EXECUTABLE_RUN_OPTIONS_H_ -// Pulls in the ::stream_executor -> ::xla::se namespace alias. 
-#include "tensorflow/compiler/xla/types.h" - // These classes are forward declared so that ExecutableRunOptions can be linked // into an XLA-compiled binary without having to link all of the pointed-to // objects (e.g., for an ahead-of-time compiled CPU binary, the gpu tools don't @@ -28,12 +25,6 @@ class Stream; class Platform; } // namespace stream_executor -namespace tensorflow { -namespace thread { -class ThreadPool; -} // namespace thread -} // namespace tensorflow - namespace Eigen { struct ThreadPoolDevice; } // namespace Eigen diff --git a/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py b/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py index 1fea816a803bfb75b9721393cef8c4dfc249268d..c34e84efc80ba970624d80802841d6ec534b6fd0 100644 --- a/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py +++ b/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py @@ -104,9 +104,9 @@ class Sharding(object): ValueError: The tensor to split was smaller in the split dimension than the number of devices to split over. 
""" - tensor.shape.assert_is_fully_defined() shape = tensor.shape.as_list() - if shape[split_dimension] < num_devices: + if (shape[split_dimension] is not None and + shape[split_dimension] < num_devices): raise ValueError('Split dimension was smaller than the required number ' 'of splits: shape=%r, dimension=%r, num_devices=%r' % (shape, split_dimension, num_devices)) diff --git a/tensorflow/compiler/xla/g3doc/_book.yaml b/tensorflow/compiler/xla/g3doc/_book.yaml index 12b7094705e75305dc43a013576f4549dd5f4185..267701e9c0e42a21d2cda6238520f6a9692e7e76 100644 --- a/tensorflow/compiler/xla/g3doc/_book.yaml +++ b/tensorflow/compiler/xla/g3doc/_book.yaml @@ -31,3 +31,5 @@ upper_tabs: - title: XLA compile API path: /xla/tutorials/xla_compile status: experimental + +- include: /_upper_tabs_right.yaml diff --git a/tensorflow/compiler/xla/g3doc/broadcasting.md b/tensorflow/compiler/xla/g3doc/broadcasting.md index 2870869a2cef13a9105b9dc9fa4d657834288f86..5c0525c1e9adf9f37d945170d05e7c18fa3d8852 100644 --- a/tensorflow/compiler/xla/g3doc/broadcasting.md +++ b/tensorflow/compiler/xla/g3doc/broadcasting.md @@ -168,7 +168,7 @@ consult the Broadcasting of a lower-rank array to a higher-rank array **and** broadcasting using degenerate dimensions can both be performed in the same binary operation. -For example, a vector of size 4 and an matrix of size 1x2 can be added together +For example, a vector of size 4 and a matrix of size 1x2 can be added together using broadcast dimensions value of (0): |1 2 3 4| + [5 6] // [5 6] is a 1x2 matrix, not a vector. @@ -176,7 +176,7 @@ using broadcast dimensions value of (0): First the vector is broadcast up to rank 2 (matrix) using the broadcast dimensions. The single value (0) in the broadcast dimensions indicates that dimension zero of the vector matches to dimension zero of the matrix. 
This -produces an matrix of size 4xM where the value M is chosen to match the +produces a matrix of size 4xM where the value M is chosen to match the corresponding dimension size in the 1x2 array. Therefore, a 4x2 matrix is produced: diff --git a/tensorflow/compiler/xla/g3doc/operation_semantics.md b/tensorflow/compiler/xla/g3doc/operation_semantics.md index d888b1f23f36f33ef94ef0e22374e0c796e47a89..363fd17b69bfbe54d486e367d9bf5cc0eee4205e 100644 --- a/tensorflow/compiler/xla/g3doc/operation_semantics.md +++ b/tensorflow/compiler/xla/g3doc/operation_semantics.md @@ -38,25 +38,25 @@ Alltoall is a collective operation that sends data from all cores to all cores. It has two phases: 1. the scatter phase. On each core, the operand is split into `split_count` - number of blocks along the `split_dimensions`, and the blocks are scattered - to all cores, e.g., the ith block is send to the ith core. +number of blocks along the `split_dimensions`, and the blocks are scattered +to all cores, e.g., the ith block is sent to the ith core. 2. the gather phase. Each core concatenates the received blocks along the - `concat_dimension`. +`concat_dimension`. The participating cores can be configured by: - `replica_groups`: each ReplicaGroup contains a list of replica id. If empty, - all replicas belong to one group in the order of 0 - (n-1). Alltoall will be - applied within subgroups in the specified order. For example, replica - groups = {{1,2,3},{4,5,0}} means, an Alltoall will be applied within replica - 1, 2, 3, and in the gather phase, the received blocks will be concatenated - in the order of 1, 2, 3; another Alltoall will be applied within replica 4, - 5, 0, and the concatenation order is 4, 5, 0. +all replicas belong to one group in the order of 0 - (n-1). Alltoall will be +applied within subgroups in the specified order. 
For example, replica +groups = {{1,2,3},{4,5,0}} means, an Alltoall will be applied within replica +1, 2, 3, and in the gather phase, the received blocks will be concatenated +in the order of 1, 2, 3; another Alltoall will be applied within replica 4, +5, 0, and the concatenation order is 4, 5, 0. Prerequisites: - The dimension size of the operand on the split_dimension is divisible by - split_count. +split_count. - The operand's shape is not tuple. `AllToAll(operand, split_dimension, concat_dimension, split_count, @@ -93,7 +93,7 @@ AllToAll(x, /*split_dimension=*/1, /*concat_dimension=*/0, /*split_count=*/4); ```
- +
In this example, there are 4 cores participating the Alltoall. On each core, the @@ -387,34 +387,34 @@ For example, let v be an array of 24 elements: ``` let v = f32[4x2x3] {{{10, 11, 12}, {15, 16, 17}}, - {{20, 21, 22}, {25, 26, 27}}, - {{30, 31, 32}, {35, 36, 37}}, - {{40, 41, 42}, {45, 46, 47}}}; +{{20, 21, 22}, {25, 26, 27}}, +{{30, 31, 32}, {35, 36, 37}}, +{{40, 41, 42}, {45, 46, 47}}}; // Collapse to a single dimension, leaving one dimension. let v012 = Collapse(v, {0,1,2}); then v012 == f32[24] {10, 11, 12, 15, 16, 17, - 20, 21, 22, 25, 26, 27, - 30, 31, 32, 35, 36, 37, - 40, 41, 42, 45, 46, 47}; +20, 21, 22, 25, 26, 27, +30, 31, 32, 35, 36, 37, +40, 41, 42, 45, 46, 47}; // Collapse the two lower dimensions, leaving two dimensions. let v01 = Collapse(v, {0,1}); then v01 == f32[4x6] {{10, 11, 12, 15, 16, 17}, - {20, 21, 22, 25, 26, 27}, - {30, 31, 32, 35, 36, 37}, - {40, 41, 42, 45, 46, 47}}; +{20, 21, 22, 25, 26, 27}, +{30, 31, 32, 35, 36, 37}, +{40, 41, 42, 45, 46, 47}}; // Collapse the two higher dimensions, leaving two dimensions. let v12 = Collapse(v, {1,2}); then v12 == f32[8x3] {{10, 11, 12}, - {15, 16, 17}, - {20, 21, 22}, - {25, 26, 27}, - {30, 31, 32}, - {35, 36, 37}, - {40, 41, 42}, - {45, 46, 47}}; +{15, 16, 17}, +{20, 21, 22}, +{25, 26, 27}, +{30, 31, 32}, +{35, 36, 37}, +{40, 41, 42}, +{45, 46, 47}}; ``` @@ -441,9 +441,9 @@ replicas. Note that there are the following restrictions on the `source_target_pair`: - Any two pairs should not have the same target replica id, and they should - not have the same source replica id. +not have the same source replica id. - If a replica id is not a target in any pair, then the output on that replica - is a tensor consists of 0(s) with the same shape as the input. +is a tensor consists of 0(s) with the same shape as the input. 
## Concatenate @@ -480,25 +480,25 @@ Concat({{2, 3}, {4, 5}, {6, 7}}, 0) ``` let a = { - {1, 2}, - {3, 4}, - {5, 6}, +{1, 2}, +{3, 4}, +{5, 6}, }; let b = { - {7, 8}, +{7, 8}, }; Concat({a, b}, 0) >>> { - {1, 2}, - {3, 4}, - {5, 6}, - {7, 8}, +{1, 2}, +{3, 4}, +{5, 6}, +{7, 8}, } ``` Diagram:
- +
## Conditional @@ -548,17 +548,23 @@ Computes a convolution of the kind used in neural networks. Here, a convolution can be thought of as a n-dimensional window moving across a n-dimensional base area and a computation is performed for each possible position of the window. -| Arguments | Type | Semantics | -| --------------------- | -------------------- | ----------------------------- | -| `lhs` | `XlaOp` | rank n+2 array of inputs | -| `rhs` | `XlaOp` | rank n+2 array of kernel | -: : : weights : -| `window_strides` | `ArraySlice` | n-d array of kernel strides | -| `padding` | `ArraySlice< | n-d array of (low, high) | -: : pair>` : padding : -| `lhs_dilation` | `ArraySlice` | n-d lhs dilation factor array | -| `rhs_dilation` | `ArraySlice` | n-d rhs dilation factor array | -| `feature_group_count` | int64 | the number of feature groups | +| Arguments | Type | Semantics | +| --------------------- | ------------------------ | ------------------------ | +| `lhs` | `XlaOp` | rank n+2 array of inputs | +| `rhs` | `XlaOp` | rank n+2 array of kernel | +: : : weights : +| `window_strides` | `ArraySlice` | n-d array of kernel | +: : : strides : +| `padding` | `ArraySlice< pair>` : padding : +| `lhs_dilation` | `ArraySlice` | n-d lhs dilation factor | +: : : array : +| `rhs_dilation` | `ArraySlice` | n-d rhs dilation factor | +: : : array : +| `feature_group_count` | int64 | the number of feature | +: : : groups : +| `batch_group_count` | int64 | the number of batch | +: : : groups : Let n be the number of spatial dimensions. The `lhs` argument is a rank n+2 array describing the base area. This is called the input, even though of course @@ -566,20 +572,20 @@ the rhs is also an input. In a neural network, these are the input activations. The n+2 dimensions are, in this order: * `batch`: Each coordinate in this dimension represents an independent input - for which convolution is carried out. +for which convolution is carried out. 
* `z/depth/features`: Each (y,x) position in the base area has a vector - associated to it, which goes into this dimension. +associated to it, which goes into this dimension. * `spatial_dims`: Describes the `n` spatial dimensions that define the base - area that the window moves across. +area that the window moves across. The `rhs` argument is a rank n+2 array describing the convolutional filter/kernel/window. The dimensions are, in this order: * `output-z`: The `z` dimension of the output. * `input-z`: The size of this dimension times `feature_group_count` should - equal the size of the `z` dimension in lhs. +equal the size of the `z` dimension in lhs. * `spatial_dims`: Describes the `n` spatial dimensions that define the n-d - window that moves across the base area. +window that moves across the base area. The `window_strides` argument specifies the stride of the convolutional window in the spatial dimensions. For example, if the stride in the first spatial @@ -628,9 +634,22 @@ input feature dimension, and the filter would be reshaped from `[filter_height, filter_width, 1, in_channels * channel_multiplier]`. For more details, see `tf.nn.depthwise_conv2d`. +The `batch_group_count` (default value 1) argument can be used for depthwise +filters during backpropagation. `batch_group_count` needs to be a divisor of the +size of the `lhs` (input) batch dimension. If `batch_group_count` is greater +than 1, it means that the output batch dimension should be of size +`batch_group_size` where `batch_group_size = input batch / batch_group_count`. +For convolutions with `batch_group_count` greater than 1, the input batch size +must evenly divide into batch_group_size and output feature size, which implies +that the output feature size must be equal to batch_group_count. Conceptually, +this can be achieved by performing the usual convolution, and then scraping +`batch_group_size` number of elements on the diagonal of the matrix formed by +output batch and output feature. 
+ The output shape has these dimensions, in this order: -* `batch`: Same size as `batch` on the input (`lhs`). +* `batch`: The size of this dimension times `batch_group_count` should equal + the size of the `batch` dimension in lhs. * `z`: Same size as `output-z` on the kernel (`rhs`). * `spatial_dims`: One value for each valid placement of the convolutional window. @@ -658,15 +677,15 @@ Here is pseudo-code for a 2d convolution with padding and striding: ``` for (b, oz, oy, ox) { // output coordinates - value = 0; - for (iz, ky, kx) { // kernel coordinates and input z - iy = oy*stride_y + ky - pad_low_y; - ix = ox*stride_x + kx - pad_low_x; - if ((iy, ix) inside the base area considered without padding) { - value += input(b, iz, iy, ix) * kernel(oz, iz, ky, kx); - } - } - output(b, oz, oy, ox) = value; +value = 0; +for (iz, ky, kx) { // kernel coordinates and input z +iy = oy*stride_y + ky - pad_low_y; +ix = ox*stride_x + kx - pad_low_x; +if ((iy, ix) inside the base area considered without padding) { +value += input(b, iz, iy, ix) * kernel(oz, iz, ky, kx); +} +} +output(b, oz, oy, ox) = value; } ``` @@ -777,19 +796,19 @@ Here is an example of an implementation of `myfunc`: ``` extern "C" void myfunc(void* out, void** in) { - float (&x)[2] = *static_cast(in[0]); - float (&y)[2][3] = *static_cast(in[1]); - EXPECT_EQ(1, x[0]); - EXPECT_EQ(2, x[1]); - EXPECT_EQ(10, y[0][0]); - EXPECT_EQ(20, y[0][1]); - EXPECT_EQ(30, y[0][2]); - EXPECT_EQ(40, y[1][0]); - EXPECT_EQ(50, y[1][1]); - EXPECT_EQ(60, y[1][2]); - float (&z)[3][3] = *static_cast(out); - z[0][0] = x[1] + y[1][0]; - // ... +float (&x)[2] = *static_cast(in[0]); +float (&y)[2][3] = *static_cast(in[1]); +EXPECT_EQ(1, x[0]); +EXPECT_EQ(2, x[1]); +EXPECT_EQ(10, y[0][0]); +EXPECT_EQ(20, y[0][1]); +EXPECT_EQ(30, y[0][2]); +EXPECT_EQ(40, y[1][0]); +EXPECT_EQ(50, y[1][1]); +EXPECT_EQ(60, y[1][2]); +float (&z)[3][3] = *static_cast(out); +z[0][0] = x[1] + y[1][0]; +// ... 
} ``` @@ -856,44 +875,40 @@ DotGeneral performs the sum of products over contracting dimensions specified in 'dimension_numbers'. Associated contracting dimension numbers from the 'lhs' and 'rhs' do not need -to be the same, but must be listed in the same order in both -'lhs/rhs_contracting_dimensions' arrays and have the same dimension sizes. -There must be exactly one contracting dimension on both 'lhs' and 'rhs'. +to be the same but must have the same dimension sizes. Example with contracting dimension numbers: ``` lhs = { {1.0, 2.0, 3.0}, - {4.0, 5.0, 6.0} } +{4.0, 5.0, 6.0} } rhs = { {1.0, 1.0, 1.0}, - {2.0, 2.0, 2.0} } +{2.0, 2.0, 2.0} } DotDimensionNumbers dnums; dnums.add_lhs_contracting_dimensions(1); dnums.add_rhs_contracting_dimensions(1); DotGeneral(lhs, rhs, dnums) -> { {6.0, 12.0}, - {15.0, 30.0} } +{15.0, 30.0} } ``` -Associated batch dimension numbers from the 'lhs' and 'rhs' must have the same -dimension number, must be listed in the same order in both arrays, must -have the same dimension sizes, and must be ordered before contracting and -non-contracting/non-batch dimension numbers. +Associated batch dimension numbers from the 'lhs' and 'rhs' must +have the same dimension sizes. Example with batch dimension numbers (batch size 2, 2x2 matrices): ``` lhs = { { {1.0, 2.0}, - {3.0, 4.0} }, - { {5.0, 6.0}, - {7.0, 8.0} } } +{3.0, 4.0} }, +{ {5.0, 6.0}, +{7.0, 8.0} } } rhs = { { {1.0, 0.0}, - {0.0, 1.0} }, - { {1.0, 0.0}, - {0.0, 1.0} } } +{0.0, 1.0} }, +{ {1.0, 0.0}, +{0.0, 1.0} } } DotDimensionNumbers dnums; dnums.add_lhs_contracting_dimensions(2); @@ -902,9 +917,9 @@ dnums.add_lhs_batch_dimensions(0); dnums.add_rhs_batch_dimensions(0); DotGeneral(lhs, rhs, dnums) -> { { {1.0, 2.0}, - {3.0, 4.0} }, - { {5.0, 6.0}, - {7.0, 8.0} } } +{3.0, 4.0} }, +{ {5.0, 6.0}, +{7.0, 8.0} } } ``` | Input | Output | Semantics | @@ -929,21 +944,21 @@ dimension: [start, start + size). 
The shape of `start_indices` must be rank == `DynamicSlice(operand, start_indices, size_indices)` -| Arguments | Type | Semantics | -| --------------- | ------------------- | ----------------------------------- | -| `operand` | `XlaOp` | N dimensional array of type T | -| `start_indices` | `XlaOp` | Rank 1 array of N integers | -: : : containing the starting indices of : -: : : the slice for each dimension. Value : -: : : must be greater than or equal to : -: : : zero. : -| `size_indices` | `ArraySlice` | List of N integers containing the | -: : : slice size for each dimension. Each : -: : : value must be strictly greater than : -: : : zero, and start + size must be less : -: : : than or equal to the size of the : -: : : dimension to avoid wrapping modulo : -: : : dimension size. : +| Arguments | Type | Semantics | +| --------------- | --------------------- | ---------------------------------- | +| `operand` | `XlaOp` | N dimensional array of type T | +| `start_indices` | sequence of N `XlaOp` | List of N scalar integers | +: : : containing the starting indices of : +: : : the slice for each dimension. : +: : : Value must be greater than or : +: : : equal to zero. : +| `size_indices` | `ArraySlice` | List of N integers containing the | +: : : slice size for each dimension. : +: : : Each value must be strictly : +: : : greater than zero, and start + : +: : : size must be less than or equal to : +: : : the size of the dimension to avoid : +: : : wrapping modulo dimension size. 
: The effective slice indices are computed by applying the following transformation for each index `i` in `[1, N)` before performing the slice: @@ -963,22 +978,22 @@ let a = {0.0, 1.0, 2.0, 3.0, 4.0} let s = {2} DynamicSlice(a, s, {2}) produces: - {2.0, 3.0} +{2.0, 3.0} ``` 2-dimensional example: ``` let b = - { {0.0, 1.0, 2.0}, - {3.0, 4.0, 5.0}, - {6.0, 7.0, 8.0}, - {9.0, 10.0, 11.0} } +{ {0.0, 1.0, 2.0}, +{3.0, 4.0, 5.0}, +{6.0, 7.0, 8.0}, +{9.0, 10.0, 11.0} } let s = {2, 1} DynamicSlice(b, s, {2, 2}) produces: - { { 7.0, 8.0}, - {10.0, 11.0} } +{ { 7.0, 8.0}, +{10.0, 11.0} } ``` ## DynamicUpdateSlice @@ -994,19 +1009,22 @@ the rank of `operand`. `DynamicUpdateSlice(operand, update, start_indices)` -| Arguments | Type | Semantics | -| --------------- | ------- | ------------------------------------------------ | -| `operand` | `XlaOp` | N dimensional array of type T | -| `update` | `XlaOp` | N dimensional array of type T containing the | -: : : slice update. Each dimension of update shape : -: : : must be strictly greater than zero, and start + : -: : : update must be less than or equal to the operand : -: : : size for each dimension to avoid generating : -: : : out-of-bounds update indices. : -| `start_indices` | `XlaOp` | Rank 1 array of N integers containing the | -: : : starting indices of the slice for each : -: : : dimension. Value must be greater than or equal : -: : : to zero. : +| Arguments | Type | Semantics | +| --------------- | --------------------- | ---------------------------------- | +| `operand` | `XlaOp` | N dimensional array of type T | +| `update` | `XlaOp` | N dimensional array of type T | +: : : containing the slice update. Each : +: : : dimension of update shape must be : +: : : strictly greater than zero, and : +: : : start + update must be less than : +: : : or equal to the operand size for : +: : : each dimension to avoid generating : +: : : out-of-bounds update indices. 
: +| `start_indices` | sequence of N `XlaOp` | List of N scalar integers | +: : : containing the starting indices of : +: : : the slice for each dimension. : +: : : Value must be greater than or : +: : : equal to zero. : The effective slice indices are computed by applying the following transformation for each index `i` in `[1, N)` before performing the slice: @@ -1027,29 +1045,29 @@ let u = {5.0, 6.0} let s = {2} DynamicUpdateSlice(a, u, s) produces: - {0.0, 1.0, 5.0, 6.0, 4.0} +{0.0, 1.0, 5.0, 6.0, 4.0} ``` 2-dimensional example: ``` let b = - { {0.0, 1.0, 2.0}, - {3.0, 4.0, 5.0}, - {6.0, 7.0, 8.0}, - {9.0, 10.0, 11.0} } +{ {0.0, 1.0, 2.0}, +{3.0, 4.0, 5.0}, +{6.0, 7.0, 8.0}, +{9.0, 10.0, 11.0} } let u = - { {12.0, 13.0}, - {14.0, 15.0}, - {16.0, 17.0} } +{ {12.0, 13.0}, +{14.0, 15.0}, +{16.0, 17.0} } let s = {1, 1} DynamicUpdateSlice(b, u, s) produces: - { {0.0, 1.0, 2.0}, - {3.0, 12.0, 13.0}, - {6.0, 14.0, 15.0}, - {9.0, 16.0, 17.0} } +{ {0.0, 1.0, 2.0}, +{3.0, 12.0, 13.0}, +{6.0, 14.0, 15.0}, +{9.0, 16.0, 17.0} } ``` ## Element-wise binary arithmetic operations @@ -1080,7 +1098,7 @@ When `Op` is `Rem`, the sign of the result is taken from the dividend, and the absolute value of the result is always less than the divisor's absolute value. Integer division overflow (signed/unsigned division/remainder by zero or signed -divison/remainder of `INT_SMIN` with `-1`) produces an implementation defined +division/remainder of `INT_SMIN` with `-1`) produces an implementation defined value. An alternative variant with different-rank broadcasting support exists for these @@ -1235,42 +1253,42 @@ shape of `start_indices` to be `[6,7,1]`). The bounds for the output array along dimension `i` is computed as follows: - 1. If `i` is present in `batch_dims` (i.e. is equal to `batch_dims[k]` for - some `k`) then we pick the corresponding dimension bounds out of - `start_indices.shape`, skipping `index_vector_dim` (i.e. 
pick - `start_indices.shape.dims`[`k`] if `k` < `index_vector_dim` and - `start_indices.shape.dims`[`k`+`1`] otherwise). +1. If `i` is present in `batch_dims` (i.e. is equal to `batch_dims[k]` for +some `k`) then we pick the corresponding dimension bounds out of +`start_indices.shape`, skipping `index_vector_dim` (i.e. pick +`start_indices.shape.dims`[`k`] if `k` < `index_vector_dim` and +`start_indices.shape.dims`[`k`+`1`] otherwise). - 2. If `i` is present in `offset_dims` (i.e. equal to `offset_dims`[`k`] for - some `k`) then we pick the corresponding bound out of `slice_sizes` after - accounting for `collapsed_slice_dims` (i.e. we pick - `adjusted_slice_sizes`[`k`] where `adjusted_slice_sizes` is `slice_sizes` - with the bounds at indices `collapsed_slice_dims` removed). +2. If `i` is present in `offset_dims` (i.e. equal to `offset_dims`[`k`] for +some `k`) then we pick the corresponding bound out of `slice_sizes` after +accounting for `collapsed_slice_dims` (i.e. we pick +`adjusted_slice_sizes`[`k`] where `adjusted_slice_sizes` is `slice_sizes` +with the bounds at indices `collapsed_slice_dims` removed). Formally, the operand index `In` corresponding to an output index `Out` is computed as follows: - 1. Let `G` = { `Out`[`k`] for `k` in `batch_dims` }. Use `G` to slice out - vector `S` such that `S`[`i`] = `start_indices`[Combine(`G`, `i`)] where - Combine(A, b) inserts b at position `index_vector_dim` into A. Note that - this is well defined even if `G` is empty -- if `G` is empty then `S` = - `start_indices`. - - 2. Create a starting index, `S``in`, into `operand` using `S` by - scattering `S` using `start_index_map`. More precisely: - 1. `S``in`[`start_index_map`[`k`]] = `S`[`k`] if `k` < - `start_index_map.size`. - 2. `S``in`[`_`] = `0` otherwise. - - 3. Create an index `O``in` into `operand` by scattering the indices - at the offset dimensions in `Out` according to the `collapsed_slice_dims` - set. More precisely: - 1. 
`O``in`[`expand_offset_dims`(`k`)] = - `Out`[`offset_dims`[`k`]] if `k` < `offset_dims.size` - (`expand_offset_dims` is defined below). - 2. `O``in`[`_`] = `0` otherwise. - 4. `In` is `O``in` + `S``in` where + is element-wise - addition. +1. Let `G` = { `Out`[`k`] for `k` in `batch_dims` }. Use `G` to slice out +vector `S` such that `S`[`i`] = `start_indices`[Combine(`G`, `i`)] where +Combine(A, b) inserts b at position `index_vector_dim` into A. Note that +this is well defined even if `G` is empty -- if `G` is empty then `S` = +`start_indices`. + +2. Create a starting index, `S``in`, into `operand` using `S` by +scattering `S` using `start_index_map`. More precisely: +1. `S``in`[`start_index_map`[`k`]] = `S`[`k`] if `k` < +`start_index_map.size`. +2. `S``in`[`_`] = `0` otherwise. + +3. Create an index `O``in` into `operand` by scattering the indices +at the offset dimensions in `Out` according to the `collapsed_slice_dims` +set. More precisely: +1. `O``in`[`expand_offset_dims`(`k`)] = +`Out`[`offset_dims`[`k`]] if `k` < `offset_dims.size` +(`expand_offset_dims` is defined below). +2. `O``in`[`_`] = `0` otherwise. +4. `In` is `O``in` + `S``in` where + is element-wise +addition. `expand_offset_dims` is the monotonic function with domain [`0`, `offset.size`) and range [`0`, `operand.rank`) \ `collapsed_slice_dims`. So if, e.g., @@ -1282,21 +1300,21 @@ and range [`0`, `operand.rank`) \ `collapsed_slice_dims`. So if, e.g., Informally, every index `Out` in the output array corresponds to an element `E` in the operand array, computed as follows: - - We use the batch dimensions in `Out` to look up a starting index from - `start_indices`. +- We use the batch dimensions in `Out` to look up a starting index from +`start_indices`. - - We use `start_index_map` to map the starting index (which may have size less - than operand.rank) to a "full" starting index into operand. 
+- We use `start_index_map` to map the starting index (which may have size less +than operand.rank) to a "full" starting index into operand. - - We dynamic-slice out a slice with size `slice_sizes` using the full starting - index. +- We dynamic-slice out a slice with size `slice_sizes` using the full starting +index. - - We reshape the slice by collapsing the `collapsed_slice_dims` dimensions. - Since all collapsed slice dimensions have to have bound 1 this reshape is - always legal. +- We reshape the slice by collapsing the `collapsed_slice_dims` dimensions. +Since all collapsed slice dimensions have to have bound 1 this reshape is +always legal. - - We use the offset dimensions in `Out` to index into this slice to get the - input element, `E`, corresponding to output index `Out`. +- We use the offset dimensions in `Out` to index into this slice to get the +input element, `E`, corresponding to output index `Out`. `index_vector_dim` is set to `start_indices.rank` - `1` in all of the examples that follow. More interesting values for `index_vector_dim` do not @@ -1315,7 +1333,7 @@ the output shape, and maps it to an element in the input array in the following way:
- +
We first select an (`X`,`Y`) vector from the gather indices array using `G`. @@ -1334,7 +1352,7 @@ version of the example above using a "gather indices" array of shape `[4,5,2]` would translate indices like this:
- +
Again, this acts as a batch dynamic slice `G``0` and @@ -1343,27 +1361,27 @@ Again, this acts as a batch dynamic slice `G``0` and The gather operation in XLA generalizes the informal semantics outlined above in the following ways: - 1. We can configure which dimensions in the output shape are the offset - dimensions (dimensions containing `O``0`, `O``1` in - the last example). The output batch dimensions (dimensions containing - `G``0`, `G``1` in the last example) are defined to be - the output dimensions that are not offset dimensions. +1. We can configure which dimensions in the output shape are the offset +dimensions (dimensions containing `O``0`, `O``1` in +the last example). The output batch dimensions (dimensions containing +`G``0`, `G``1` in the last example) are defined to be +the output dimensions that are not offset dimensions. - 2. The number of output offset dimensions explicitly present in the output - shape may be smaller than the input rank. These "missing" dimensions, which - are listed explicitly as `collapsed_slice_dims`, must have a slice size of - `1`. Since they have a slice size of `1` the only valid index for them is - `0` and eliding them does not introduce ambiguity. +2. The number of output offset dimensions explicitly present in the output +shape may be smaller than the input rank. These "missing" dimensions, which +are listed explicitly as `collapsed_slice_dims`, must have a slice size of +`1`. Since they have a slice size of `1` the only valid index for them is +`0` and eliding them does not introduce ambiguity. - 3. The slice extracted from the "Gather Indices" array ((`X`, `Y`) in the last - example) may have fewer elements than the input array rank, and an explicit - mapping dictates how the index should be expanded to have the same rank as - the input. +3. 
The slice extracted from the "Gather Indices" array ((`X`, `Y`) in the last +example) may have fewer elements than the input array rank, and an explicit +mapping dictates how the index should be expanded to have the same rank as +the input. As a final example, we use (2) and (3) to implement `tf.gather_nd`:
- +
`G``0` and `G``1` are used to slice out a starting index @@ -1442,11 +1460,11 @@ dependency between the while loops. ``` result1 = while (condition, init = init_value) { - Infeed(shape) +Infeed(shape) } result2 = while (condition, init = result1) { - Infeed(shape) +Infeed(shape) } ``` @@ -1464,7 +1482,9 @@ Infeed of the device. Builds a constant literal on device rather than a potentially large host transfer. Creates a rank 1 array of values starting at zero and incrementing by -one. +one. For floating-point types, the produced array is equivalent to +`ConvertElementType(Iota(...))` where the `Iota` is of integral type and the +conversion is to the floating-point type. Arguments | Type | Semantics ---------------- | --------------- | ------------------------------------ @@ -1853,6 +1873,20 @@ non-deterministic. Therefore, the reduction function should not be overly sensitive to reassociation. See the discussion about associativity in the context of [`Reduce`](#reduce) for more details. +## ReplicaId + +See also +[`XlaBuilder::ReplicaId`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). + +Returns the unique ID (U32 scalar) of the replica. + + `ReplicaId()` + +The unique ID of each replica is an unsigned integer in the interval `[0, N)`, +where `N` is the number of replicas. Since all the replicas are running the same +program, a `ReplicaId()` call in the program will return a different value on +each replica. 
+ ## Reshape See also diff --git a/tensorflow/compiler/xla/index_util.cc b/tensorflow/compiler/xla/index_util.cc index 2a0241af3ef359c4d1c6c1ab9319b5b293110f7a..7e22a32e545e4155545ffcfb9582187eadec3a82 100644 --- a/tensorflow/compiler/xla/index_util.cc +++ b/tensorflow/compiler/xla/index_util.cc @@ -141,7 +141,7 @@ namespace xla { /* static */ bool IndexUtil::IndexInBounds(const Shape& shape, absl::Span index) { - int64 rank = ShapeUtil::Rank(shape); + int64 rank = shape.rank(); if (rank != index.size()) { return false; } diff --git a/tensorflow/compiler/xla/layout.cc b/tensorflow/compiler/xla/layout.cc new file mode 100644 index 0000000000000000000000000000000000000000..e3b5fcd5274881cec31ecf906e3461685f82a1f4 --- /dev/null +++ b/tensorflow/compiler/xla/layout.cc @@ -0,0 +1,96 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/layout.h" + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "tensorflow/compiler/xla/layout_util.h" + +namespace xla { + +TileProto Tile::ToProto() const { + TileProto tile_proto; + for (int64 i : dimensions()) { + tile_proto.add_dimensions(i); + } + return tile_proto; +} + +string Tile::ToString() const { + return absl::StrCat("(", absl::StrJoin(dimensions(), ","), ")"); +} + +/* static */ Layout Layout::CreateFromProto(const LayoutProto& proto) { + Layout layout; + layout.set_format(proto.format()); + layout.minor_to_major_.reserve(proto.minor_to_major_size()); + for (const int64 dimension : proto.minor_to_major()) { + layout.add_minor_to_major(dimension); + } + layout.set_max_sparse_elements(proto.max_sparse_elements()); + for (const TileProto& tile_proto : proto.tiles()) { + *layout.add_tiles() = Tile::CreateFromProto(tile_proto); + } + layout.set_element_size_in_bits(proto.element_size_in_bits()); + return layout; +} + +LayoutProto Layout::ToProto() const { + LayoutProto proto; + proto.set_format(format_); + proto.mutable_minor_to_major()->Reserve(minor_to_major_size()); + for (const int64 dimension : minor_to_major()) { + proto.add_minor_to_major(dimension); + } + proto.set_max_sparse_elements(max_sparse_elements_); + for (const Tile& tile : tiles()) { + *proto.add_tiles() = tile.ToProto(); + } + proto.set_element_size_in_bits(element_size_in_bits()); + return proto; +} + +string Layout::ToString() const { + // TODO(b/119839262): Emit tiles in string. 
+ if (format() == SPARSE) { + return absl::StrCat("sparse{", max_sparse_elements(), "}"); + } else if (format() == DENSE) { + return absl::StrCat("{", absl::StrJoin(minor_to_major(), ","), "}"); + } else { + CHECK_EQ(format(), INVALID_FORMAT); + return "invalid{}"; + } +} + +bool Layout::operator==(const Layout& other) const { + return (other.format() == format() && + other.minor_to_major() == minor_to_major() && + other.element_size_in_bits() == element_size_in_bits() && + other.max_sparse_elements() == max_sparse_elements() && + other.tiles() == tiles()); +} + +std::ostream& operator<<(std::ostream& out, const Tile& tile) { + out << tile.ToString(); + return out; +} + +std::ostream& operator<<(std::ostream& out, const Layout& layout) { + out << layout.ToString(); + return out; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/layout.h b/tensorflow/compiler/xla/layout.h new file mode 100644 index 0000000000000000000000000000000000000000..313368c39e4c976fc481941eb17325101f2ba69a --- /dev/null +++ b/tensorflow/compiler/xla/layout.h @@ -0,0 +1,187 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_LAYOUT_H_ +#define TENSORFLOW_COMPILER_XLA_LAYOUT_H_ + +#include + +#include "absl/types/span.h" + +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +// Describes a tile used in tiling-based layout. Refer to +// g3doc/third_party/tensorflow/compiler/xla/g3doc/layout_with_tiling.md for +// details. +class Tile { + public: + Tile() = default; + explicit Tile(absl::Span dimensions) + : dimensions_(dimensions.begin(), dimensions.end()) {} + + // De/Serialize a Tile to and from a TileProto. + static Tile CreateFromProto(const TileProto& tile_proto) { + return Tile(AsInt64Slice(tile_proto.dimensions())); + } + TileProto ToProto() const; + + bool operator==(const Tile& other) const { + return dimensions() == other.dimensions(); + } + bool operator!=(const Tile& other) const { return !(*this == other); } + + string ToString() const; + + // Returns the bound of the tile in the given dimension index. + int64 dimension(int i) const { return dimensions_.at(i); } + + // Returns the dimensions of the tile. + const std::vector& dimensions() const { return dimensions_; } + + private: + // The bounds of the tile. + std::vector dimensions_; +}; + +class Layout { + public: + Layout() = default; + + // Constructs a dense layout with the given minor-to-major order. + explicit Layout(absl::Span minor_to_major) + : format_(DENSE), + minor_to_major_(minor_to_major.begin(), minor_to_major.end()) {} + + // Constructs a dense tiled layout with the given minor-to-major order and + // tiles. + Layout(absl::Span minor_to_major, absl::Span tiles) + : format_(DENSE), + minor_to_major_(minor_to_major.begin(), minor_to_major.end()), + tiles_(tiles.begin(), tiles.end()) {} + + // Construct a shape from a LayoutProto. 
+ static Layout CreateFromProto(const LayoutProto& proto); + + // Returns a LayoutProto representation of the Layout. + LayoutProto ToProto() const; + + // Returns a human-readable string that represents this layout. + string ToString() const; + + bool operator==(const Layout& other) const; + bool operator!=(const Layout& other) const { return !(*this == other); } + + // The following methods mirror the protobuf generated code interface for the + // message LayoutProto. This enabled easy migration of this data structure + // from a proto to a proper C++ class. + // + // TODO(b/29771030): Replace or augment these methods with a more ergonomic + // interface. + + // Methods for accessing the format. + Format format() const { return format_; } + Layout& set_format(Format value) { + format_ = value; + return *this; + } + + // Methods for accessing the minor-to-major array. + int minor_to_major_size() const { return minor_to_major_.size(); } + int64 minor_to_major(int index) const { return minor_to_major_.at(index); } + Layout& set_minor_to_major(int index, int64 value) { + minor_to_major_.at(index) = value; + return *this; + } + Layout& add_minor_to_major(int64 value) { + minor_to_major_.push_back(value); + return *this; + } + Layout& clear_minor_to_major() { + minor_to_major_.clear(); + return *this; + } + const std::vector& minor_to_major() const { return minor_to_major_; } + std::vector* mutable_minor_to_major() { return &minor_to_major_; } + + // Methods for accessing the tile field. + int tiles_size() const { return tiles_.size(); } + const Tile& tiles(int index) const { return tiles_.at(index); } + Tile* mutable_tiles(int index) { return &tiles_.at(index); } + Tile* add_tiles() { + tiles_.push_back(Tile()); + return &tiles_.back(); + } + Layout& clear_tiles() { + tiles_.clear(); + return *this; + } + const std::vector& tiles() const { return tiles_; } + std::vector* mutable_tiles() { return &tiles_; } + + // Methods for accessing the int64 fields. 
+ int64 max_sparse_elements() const { return max_sparse_elements_; } + Layout& set_max_sparse_elements(int64 value) { + max_sparse_elements_ = value; + return *this; + } + int64 element_size_in_bits() const { return element_size_in_bits_; } + Layout& set_element_size_in_bits(int64 value) { + element_size_in_bits_ = value; + return *this; + } + + void Swap(Layout* other) { + using std::swap; + swap(*this, *other); + } + + void Clear() { + format_ = INVALID_FORMAT; + minor_to_major_.clear(); + max_sparse_elements_ = 0; + element_size_in_bits_ = 0; + } + + public: + // The format of this layout. + Format format_ = INVALID_FORMAT; + + // Sequence of dimension numbers, from minor (fastest varying index) to major + // (slowest varying index). + std::vector minor_to_major_; + + // The maximum number of elements that can be stored for SPARSE formats. This + // can be used to determine the maximum size in bytes of arrays stored in + // memory. This field must be zero unless the format is SPARSE. + int64 max_sparse_elements_ = 0; + + // The number of bits used to store an individual array element. + int64 element_size_in_bits_ = 0; + + // The tiles used in tiling-based layout. + std::vector tiles_; +}; + +std::ostream& operator<<(std::ostream& out, const Tile& Tile); +std::ostream& operator<<(std::ostream& out, const Layout& layout); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_LAYOUT_H_ diff --git a/tensorflow/compiler/xla/layout_test.cc b/tensorflow/compiler/xla/layout_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..fb6abd3f6523b978e72b21ec082ae06973e86243 --- /dev/null +++ b/tensorflow/compiler/xla/layout_test.cc @@ -0,0 +1,104 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/layout.h" + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace { + +class LayoutTest : public ::testing::Test {}; + +TEST_F(LayoutTest, ToString) { + EXPECT_EQ(Layout().ToString(), "invalid{}"); + EXPECT_EQ(Layout({4, 5, 6}).ToString(), "{4,5,6}"); + EXPECT_EQ(Layout().set_format(SPARSE).set_max_sparse_elements(123).ToString(), + "sparse{123}"); + EXPECT_EQ(Layout({4, 5, 6}).ToString(), "{4,5,6}"); + EXPECT_EQ(Layout({3, 2, 1, 0}, {Tile({42, 123}), Tile({4, 5})}).ToString(), + "{3,2,1,0}"); + EXPECT_EQ( + Layout({1, 0}, {Tile({2, 55})}).set_element_size_in_bits(42).ToString(), + "{1,0}"); +} + +TEST_F(LayoutTest, StreamOut) { + { + std::ostringstream oss; + oss << Tile({7, 8}); + EXPECT_EQ(oss.str(), "(7,8)"); + } + + { + std::ostringstream oss; + oss << Layout({0, 1, 2}); + EXPECT_EQ(oss.str(), "{0,1,2}"); + } +} + +TEST_F(LayoutTest, SparseLayoutMaxElements) { + EXPECT_EQ(LayoutUtil::MaxSparseElements(LayoutUtil::MakeSparseLayout(101)), + 101); +} + +TEST_F(LayoutTest, Equality) { + EXPECT_EQ(Layout(), Layout()); + const 
std::vector empty_dims; + EXPECT_EQ(Layout(empty_dims), Layout(empty_dims)); + EXPECT_NE(Layout(), Layout(empty_dims)); + EXPECT_EQ(Layout({0, 1, 2, 3}), Layout({0, 1, 2, 3})); + EXPECT_NE(Layout({0, 1, 2, 3}), Layout({0, 1, 2})); + EXPECT_EQ(Layout({0, 1, 2}, {Tile({42, 44})}), + Layout({0, 1, 2}, {Tile({42, 44})})); + EXPECT_NE(Layout({0, 1, 2}, {Tile({42, 44})}), + Layout({0, 1, 2}, {Tile({42, 45})})); + EXPECT_NE(Layout({0, 1, 2}, {Tile({42, 44})}), Layout({0, 1, 2, 3})); + EXPECT_EQ(Layout({0, 1, 2}).set_element_size_in_bits(33), + Layout({0, 1, 2}).set_element_size_in_bits(33)); + EXPECT_NE(Layout({0, 1, 2}).set_element_size_in_bits(33), + Layout({0, 1, 2}).set_element_size_in_bits(7)); + EXPECT_EQ(Layout().set_format(SPARSE), Layout().set_format(SPARSE)); + EXPECT_EQ(Layout().set_format(SPARSE).set_max_sparse_elements(42), + Layout().set_format(SPARSE).set_max_sparse_elements(42)); + EXPECT_NE(Layout().set_format(SPARSE).set_max_sparse_elements(42), + Layout().set_format(SPARSE).set_max_sparse_elements(24)); +} + +TEST_F(LayoutTest, LayoutToFromProto) { + // Round-trips a Layout through proto de/serialization. 
+ auto expect_unchanged = [](const Layout& layout) { + EXPECT_EQ(layout, Layout::CreateFromProto(layout.ToProto())); + }; + + expect_unchanged(Layout()); + expect_unchanged(Layout({1, 3, 2, 0})); + expect_unchanged(Layout().set_format(SPARSE)); + expect_unchanged(Layout().set_format(SPARSE).set_max_sparse_elements(123)); + expect_unchanged(Layout({0, 1}).set_element_size_in_bits(42)); + expect_unchanged(Layout({3, 2, 1, 0}, {Tile({42, 123}), Tile({4, 5})})); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc index dbb81381acde645f08639737b6e7b6f6ad971f9b..2fe9b56c6bdffb931726f60ab75081361b43ebb4 100644 --- a/tensorflow/compiler/xla/layout_util.cc +++ b/tensorflow/compiler/xla/layout_util.cc @@ -41,15 +41,13 @@ namespace { // Internal helper for GetDefaultLayoutForShape and SetToDefaultLayout. Sets // minor_to_major to the value that represents the default layout. -void SetDefaultLayoutToContainer( - tensorflow::protobuf::RepeatedField* - minor_to_major) { +void SetDefaultLayoutToContainer(std::vector* minor_to_major) { // The default XLA layout is major-to-minor (dim 0 is major). 
// For more information on XLA layouts, see: // https://www.tensorflow.org/performance/xla/shapes const int64 size = minor_to_major->size(); for (int64 i = 0; i < size; ++i) { - minor_to_major->Set(i, size - 1 - i); + (*minor_to_major)[i] = size - 1 - i; } } @@ -94,9 +92,8 @@ namespace { Layout CreateDefaultLayoutForRank(int64 rank) { Layout layout; layout.set_format(DENSE); - tensorflow::protobuf::RepeatedField* - minor_to_major = layout.mutable_minor_to_major(); - minor_to_major->Resize(rank, 0); + std::vector* minor_to_major = layout.mutable_minor_to_major(); + minor_to_major->resize(rank, 0); SetDefaultLayoutToContainer(minor_to_major); return layout; } @@ -104,13 +101,13 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } // namespace /* static */ Layout LayoutUtil::GetDefaultLayoutForShape(const Shape& shape) { - if (ShapeUtil::IsOpaque(shape) || ShapeUtil::IsToken(shape)) { + if (shape.IsOpaque() || shape.IsToken()) { // Opaque and token types have empty layouts. return Layout(); } // A Layout proto corresponds to a single array, not a tuple. - CHECK(ShapeUtil::IsArray(shape)); + CHECK(shape.IsArray()); return CreateDefaultLayoutForRank(shape.dimensions_size()); } @@ -131,17 +128,16 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } /* static */ void LayoutUtil::SetToDefaultLayout(Shape* shape) { - if (ShapeUtil::IsTuple(*shape)) { + if (shape->IsTuple()) { // Tuple shape. 
for (auto& element_shape : *shape->mutable_tuple_shapes()) { SetToDefaultLayout(&element_shape); } shape->clear_layout(); - } else if (ShapeUtil::IsArray(*shape)) { + } else if (shape->IsArray()) { shape->mutable_layout()->set_format(DENSE); - tensorflow::protobuf::RepeatedField* - minor_to_major = shape->mutable_layout()->mutable_minor_to_major(); - minor_to_major->Resize(shape->dimensions_size(), 0); + auto* minor_to_major = shape->mutable_layout()->mutable_minor_to_major(); + minor_to_major->resize(shape->dimensions_size(), 0); SetDefaultLayoutToContainer(minor_to_major); } else { // Opaque, token types etc. have no layout. @@ -164,7 +160,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { /* static */ Status LayoutUtil::ValidateLayoutInShape( const Shape& shape, bool allow_missing_layouts) { - if (ShapeUtil::IsTuple(shape)) { + if (shape.IsTuple()) { // Tuple shape. if (shape.has_layout()) { return InvalidArgument("tuple should not have a layout field"); @@ -174,7 +170,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { ValidateLayoutInShape(element_shape, allow_missing_layouts)); } return Status::OK(); - } else if (ShapeUtil::IsArray(shape)) { + } else if (shape.IsArray()) { if (!shape.has_layout()) { if (allow_missing_layouts) { return Status::OK(); @@ -196,11 +192,11 @@ Layout CreateDefaultLayoutForRank(int64 rank) { /* static */ Status LayoutUtil::ValidateLayoutForShape(const Layout& layout, const Shape& shape) { - if (ShapeUtil::IsTuple(shape)) { + if (shape.IsTuple()) { return InvalidArgument("a single Layout is not valid for tuple shapes"); } - if (!ShapeUtil::IsArray(shape)) { + if (!shape.IsArray()) { if (layout.minor_to_major_size() != 0) { return InvalidArgument( "shape of primitive type %s should not have a non-trivial layout", @@ -210,25 +206,24 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } if (layout.format() == INVALID_FORMAT || !Format_IsValid(layout.format())) { - return InvalidArgument( - "Layout has an invalid format (%d) in layout 
{%s}, shape {%s}", - layout.format(), layout.ShortDebugString(), shape.ShortDebugString()); + return InvalidArgument("Layout has an invalid format (%d)", + layout.format()); } if (layout.format() == DENSE) { - if (layout.minor_to_major_size() != ShapeUtil::Rank(shape)) { + if (layout.minor_to_major_size() != shape.rank()) { return InvalidArgument( "layout minor_to_major field contains %d elements, " "but shape is rank %d: {%s}; shape: %s", - layout.minor_to_major_size(), ShapeUtil::Rank(shape), + layout.minor_to_major_size(), shape.rank(), absl::StrJoin(layout.minor_to_major(), ", "), shape.ShortDebugString()); } - std::vector dimensions_in_layout(ShapeUtil::Rank(shape), false); - for (int64 i = 0; i < ShapeUtil::Rank(shape); ++i) { + std::vector dimensions_in_layout(shape.rank(), false); + for (int64 i = 0; i < shape.rank(); ++i) { int64 dim = layout.minor_to_major(i); - if (dim < 0 || dim >= ShapeUtil::Rank(shape)) { + if (dim < 0 || dim >= shape.rank()) { return InvalidArgument( "layout minor_to_major field has out-of-bounds value: %s", HumanString(layout)); @@ -260,8 +255,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } /* static */ bool LayoutUtil::IsDenseArray(const Shape& shape) { - return ShapeUtil::IsArray(shape) && shape.has_layout() && - IsDense(shape.layout()); + return shape.IsArray() && shape.has_layout() && IsDense(shape.layout()); } /* static */ bool LayoutUtil::IsDense(const Layout& layout) { @@ -281,8 +275,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } /* static */ bool LayoutUtil::IsSparseArray(const Shape& shape) { - return ShapeUtil::IsArray(shape) && shape.has_layout() && - IsSparse(shape.layout()); + return shape.IsArray() && shape.has_layout() && IsSparse(shape.layout()); } /* static */ bool LayoutUtil::IsSparse(const Layout& layout) { @@ -295,11 +288,11 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } /* static */ bool LayoutUtil::HasLayout(const Shape& shape) { - if (ShapeUtil::IsTuple(shape)) { + if (shape.IsTuple()) { 
// Tuple shape: all subshapes must have a layout. - return std::all_of(shape.tuple_shapes().begin(), shape.tuple_shapes().end(), - [](const Shape& s) { return HasLayout(s); }); - } else if (!ShapeUtil::IsArray(shape)) { + return absl::c_all_of(shape.tuple_shapes(), + [](const Shape& s) { return HasLayout(s); }); + } else if (!shape.IsArray()) { // Opaque, token types etc. ignore layout. return true; } @@ -316,7 +309,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } /* static */ bool LayoutUtil::Equal(const Layout& lhs, const Layout& rhs) { - return protobuf_util::ProtobufEquals(lhs, rhs); + return lhs == rhs; } /* static */ absl::Span LayoutUtil::MinorToMajor( @@ -358,22 +351,18 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } /* static */ string LayoutUtil::HumanString(const Layout& layout) { - if (IsSparse(layout)) { - return absl::StrCat("sparse{", layout.max_sparse_elements(), "}"); - } - CHECK(IsDense(layout)); - return absl::StrCat("{", absl::StrJoin(layout.minor_to_major(), ","), "}"); + return layout.ToString(); } namespace { // Internal helper for recursively copying layouts. 
Status CopyLayoutInternal(const Shape& src, Shape* dst) { - if (ShapeUtil::IsTuple(src) != ShapeUtil::IsTuple(*dst)) { + if (src.IsTuple() != dst->IsTuple()) { return InvalidArgument( "cannot copy layout from shape: shape structure differs"); } - if (ShapeUtil::IsTuple(src)) { + if (src.IsTuple()) { if (ShapeUtil::TupleElementCount(src) != ShapeUtil::TupleElementCount(*dst)) { return InvalidArgument( @@ -385,7 +374,7 @@ Status CopyLayoutInternal(const Shape& src, Shape* dst) { } } else { if (src.has_layout()) { - if (ShapeUtil::Rank(src) != ShapeUtil::Rank(*dst)) { + if (src.rank() != dst->rank()) { return InvalidArgument("cannot copy layout from shape: ranks differs"); } TF_RETURN_IF_ERROR( @@ -407,9 +396,9 @@ Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) { /* static */ bool LayoutUtil::LayoutsInShapesEqual(const Shape& lhs, const Shape& rhs) { - if (ShapeUtil::IsTuple(lhs)) { - if (!ShapeUtil::IsTuple(rhs) || ShapeUtil::TupleElementCount(lhs) != - ShapeUtil::TupleElementCount(rhs)) { + if (lhs.IsTuple()) { + if (!rhs.IsTuple() || ShapeUtil::TupleElementCount(lhs) != + ShapeUtil::TupleElementCount(rhs)) { return false; } for (int i = 0; i < ShapeUtil::TupleElementCount(lhs); ++i) { @@ -418,8 +407,8 @@ Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) { } } return true; - } else if (ShapeUtil::IsArray(lhs)) { - return ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs) && + } else if (lhs.IsArray()) { + return lhs.rank() == rhs.rank() && LayoutUtil::Equal(lhs.layout(), rhs.layout()); } else { // Layouts of non-array and non-tuple shapes is ignored. 
@@ -435,7 +424,7 @@ Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) { positions_in_layout.push_back( PositionInContainer(layout.minor_to_major(), dim)); } - std::sort(positions_in_layout.begin(), positions_in_layout.end()); + absl::c_sort(positions_in_layout); for (size_t i = 1; i < positions_in_layout.size(); ++i) { if (1 != positions_in_layout[i] - positions_in_layout[i - 1]) { return false; @@ -444,11 +433,6 @@ Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) { return true; } -std::ostream& operator<<(std::ostream& out, const Layout& layout) { - out << LayoutUtil::HumanString(layout); - return out; -} - /*static*/ size_t LayoutUtil::Hash(const Layout& layout) { using tensorflow::hash; using tensorflow::Hash64Combine; diff --git a/tensorflow/compiler/xla/layout_util.h b/tensorflow/compiler/xla/layout_util.h index 6c298e57252449ce3f1f9055436e918f2d9f17f1..609dba67bcdbcb11be0906b7d87a52a17ba0dfbd 100644 --- a/tensorflow/compiler/xla/layout_util.h +++ b/tensorflow/compiler/xla/layout_util.h @@ -21,6 +21,7 @@ limitations under the License. 
#include #include "absl/types/span.h" +#include "tensorflow/compiler/xla/layout.h" #include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/types.h" @@ -195,8 +196,6 @@ class LayoutUtil { TF_DISALLOW_COPY_AND_ASSIGN(LayoutUtil); }; -std::ostream& operator<<(std::ostream& out, const Layout& layout); - } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_LAYOUT_UTIL_H_ diff --git a/tensorflow/compiler/xla/layout_util_test.cc b/tensorflow/compiler/xla/layout_util_test.cc index 12ce2d2d7c6fa8c590035f9ff2af50001ccf80d8..4cc94c270cd64eb19761cc1044861c7d185b7888 100644 --- a/tensorflow/compiler/xla/layout_util_test.cc +++ b/tensorflow/compiler/xla/layout_util_test.cc @@ -317,17 +317,6 @@ TEST_F(LayoutUtilTest, DefaultLayoutGettersMajorToMinor) { ShapeUtil::MakeShape(F32, {10, 20, 30, 15, 25})))); } -TEST_F(LayoutUtilTest, SparseLayoutMaxElements) { - EXPECT_EQ(LayoutUtil::MaxSparseElements(LayoutUtil::MakeSparseLayout(101)), - 101); -} - -TEST_F(LayoutUtilTest, StreamOut) { - std::ostringstream oss; - oss << LayoutUtil::MakeLayout({0, 1, 2}); - EXPECT_EQ(oss.str(), "{0,1,2}"); -} - TEST_F(LayoutUtilTest, ValidateLayout_ValidArrayLayout) { Shape shape = ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {0, 1}); auto status = diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc index 8f480c1f1079b4e1a5be53958ebdf6e004ad9ebe..8600e8752cfbe072407391559d210d0b49bea511 100644 --- a/tensorflow/compiler/xla/literal.cc +++ b/tensorflow/compiler/xla/literal.cc @@ -29,10 +29,12 @@ limitations under the License. 
#include "absl/strings/str_join.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/index_util.h" +#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/logging.h" @@ -107,7 +109,7 @@ Literal::Literal(const Shape& shape) : Literal(shape, /*allocate_arrays=*/true) {} void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) { - if (ShapeUtil::IsTuple(shape)) { + if (shape.IsTuple()) { for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { const Shape& subshape = shape.tuple_shapes(i); @@ -118,7 +120,7 @@ void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) { piece->emplace_back(std::move(child_piece)); } - } else if (ShapeUtil::IsArray(shape)) { + } else if (shape.IsArray()) { if (allocate_arrays) { if (LayoutUtil::IsSparseArray(shape)) { // For sparse arrays, the buffer must be of the size of the maximum @@ -129,7 +131,7 @@ void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) { new char[max_sparse_elements * ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type())]); piece->set_sparse_indices( - new SparseIndexArray(max_sparse_elements, ShapeUtil::Rank(shape))); + new SparseIndexArray(max_sparse_elements, shape.rank())); } else { piece->set_buffer(new char[piece->size_bytes()]); } @@ -187,7 +189,7 @@ Literal LiteralBase::CreateFromShape(const Shape& shape) { Literal literal(shape); literal.root_piece_->ForEachMutableSubpiece( [&](const ShapeIndex& index, Piece* piece) { - if (ShapeUtil::IsArray(piece->subshape())) { + if (piece->subshape().IsArray()) { memset(piece->untyped_data(), 0, piece->size_bytes()); } }); @@ -208,16 
+210,15 @@ template Status MutableLiteralBase::CopySliceFromInternal( const LiteralBase& src_literal, absl::Span src_base, absl::Span dest_base, absl::Span copy_size) { - TF_RET_CHECK(ShapeUtil::Rank(src_literal.shape()) == src_base.size()); - TF_RET_CHECK(ShapeUtil::Rank(shape()) == dest_base.size()); + TF_RET_CHECK(src_literal.shape().rank() == src_base.size()); + TF_RET_CHECK(shape().rank() == dest_base.size()); auto linear_index = [](const Shape& shape, absl::Span multi_index) { return IndexUtil::MultidimensionalIndexToLinearIndex(shape, multi_index); }; - if (ShapeUtil::Rank(src_literal.shape()) == 0 || - ShapeUtil::Rank(shape()) == 0) { + if (src_literal.shape().rank() == 0 || shape().rank() == 0) { // If any of the two shapes are scalars, we can just call the StridedCopy() // directly, and we know we will be copying only one value. TF_RET_CHECK(copy_size.empty()); @@ -312,7 +313,7 @@ Status MutableLiteralBase::CopyElementFrom(const LiteralSlice& src_literal, proto_element = &proto_element->tuple_literals(i); } - if (ShapeUtil::IsTuple(piece->subshape())) { + if (piece->subshape().IsTuple()) { if (proto_element->tuple_literals_size() != ShapeUtil::TupleElementCount(piece->subshape())) { return InvalidArgument( @@ -326,7 +327,7 @@ Status MutableLiteralBase::CopyElementFrom(const LiteralSlice& src_literal, return Status::OK(); } - CHECK(ShapeUtil::IsArray(piece->subshape())); + CHECK(piece->subshape().IsArray()); TF_RETURN_IF_ERROR(piece->CopyFromProto(*proto_element)); return Status::OK(); @@ -336,7 +337,7 @@ Status MutableLiteralBase::CopyElementFrom(const LiteralSlice& src_literal, } std::vector Literal::DecomposeTuple() { - CHECK(ShapeUtil::IsTuple(shape())); + CHECK(shape().IsTuple()); std::vector elements; for (int i = 0; i < ShapeUtil::TupleElementCount(shape()); ++i) { elements.push_back(Literal(ShapeUtil::GetSubshape(shape(), {i}), @@ -375,7 +376,7 @@ void CopyElementsBetween(absl::Span dest, if (ShapeUtil::IsZeroElementArray(dest_shape)) { return; } - 
std::vector index(ShapeUtil::Rank(dest_shape)); + std::vector index(dest_shape.rank()); do { dest[IndexUtil::MultidimensionalIndexToLinearIndex(dest_shape, index)] = src[IndexUtil::MultidimensionalIndexToLinearIndex(src_shape, index)]; @@ -392,7 +393,7 @@ Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src) { memcpy(buffer(), src.buffer(), src.size_bytes()); } else { TF_RET_CHECK(ShapeUtil::Compatible(src.subshape(), subshape())); - std::vector origin(ShapeUtil::Rank(subshape()), 0); + std::vector origin(subshape().rank(), 0); switch (subshape().element_type()) { #define COPY_ELEMENTS(XLA_T, NATIVE_T) \ case (XLA_T): \ @@ -412,6 +413,7 @@ Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src) { COPY_ELEMENTS(F32, float); COPY_ELEMENTS(F64, double); COPY_ELEMENTS(C64, complex64); + COPY_ELEMENTS(C128, complex128); COPY_ELEMENTS(PRED, bool); #undef COPY_ELEMENTS default: @@ -438,7 +440,7 @@ Status MutableLiteralBase::CopyFrom(const LiteralSlice& src_literal, } return root_piece_->ForEachMutableSubpieceWithStatus( [&](const ShapeIndex& index, Piece* piece) { - if (!ShapeUtil::IsArray(piece->subshape())) { + if (!piece->subshape().IsArray()) { return Status::OK(); } @@ -477,7 +479,7 @@ Status Literal::MoveFrom(Literal&& src_literal, src_literal.root_piece_->ForEachSubpiece( [&](const ShapeIndex& src_index, const Piece& src_piece) { - if (!ShapeUtil::IsArray(src_piece.subshape())) { + if (!src_piece.subshape().IsArray()) { return; } @@ -504,8 +506,8 @@ Status MutableLiteralBase::CopySliceFrom(const LiteralSlice& src_literal, absl::Span src_base, absl::Span dest_base, absl::Span copy_size) { - TF_RET_CHECK(ShapeUtil::IsArray(shape())) << ShapeUtil::HumanString(shape()); - TF_RET_CHECK(ShapeUtil::IsArray(src_literal.shape())) + TF_RET_CHECK(shape().IsArray()) << ShapeUtil::HumanString(shape()); + TF_RET_CHECK(src_literal.shape().IsArray()) << ShapeUtil::HumanString(src_literal.shape()); TF_RET_CHECK(ShapeUtil::SameElementType(src_literal.shape(), 
shape())); @@ -549,6 +551,9 @@ Status MutableLiteralBase::CopySliceFrom(const LiteralSlice& src_literal, case C64: return CopySliceFromInternal(src_literal, src_base, dest_base, copy_size); + case C128: + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); case PRED: return CopySliceFromInternal(src_literal, src_base, dest_base, copy_size); @@ -562,8 +567,8 @@ Status MutableLiteralBase::CopySliceFrom(const LiteralSlice& src_literal, } void MutableLiteralBase::PopulateR1(const tensorflow::core::Bitmap& values) { - CHECK(ShapeUtil::IsArray(shape())); - CHECK_EQ(ShapeUtil::Rank(shape()), 1); + CHECK(shape().IsArray()); + CHECK_EQ(shape().rank(), 1); CHECK_EQ(element_count(), values.bits()); CHECK_EQ(shape().element_type(), PRED); for (int64 i = 0; i < static_cast(values.bits()); ++i) { @@ -592,7 +597,7 @@ Literal LiteralBase::Relayout(const Shape& shape_with_layout) const { ShapeUtil::ForEachSubshape( result.shape(), [this, &result](const Shape& subshape, const ShapeIndex& index) { - if (ShapeUtil::IsArray(subshape)) { + if (subshape.IsArray()) { TF_CHECK_OK(result.CopyFrom(*this, /*dest_shape_index=*/index, /*src_shape_index=*/index)); @@ -603,7 +608,7 @@ Literal LiteralBase::Relayout(const Shape& shape_with_layout) const { StatusOr LiteralBase::Broadcast( const Shape& result_shape, absl::Span dimensions) const { - if (!ShapeUtil::IsArray(shape())) { + if (!shape().IsArray()) { return InvalidArgument("Broadcast only supports arrays."); } @@ -643,13 +648,12 @@ StatusOr LiteralBase::Broadcast( StatusOr LiteralBase::Reshape( absl::Span dimensions) const { - if (!ShapeUtil::IsArray(shape())) { + if (!shape().IsArray()) { return InvalidArgument("Reshape does not support tuples."); } Literal output; if (!LayoutUtil::IsMonotonicWithDim0Major(shape().layout())) { - output = - Relayout(LayoutUtil::GetDefaultLayoutForRank(ShapeUtil::Rank(shape()))); + output = Relayout(LayoutUtil::GetDefaultLayoutForRank(shape().rank())); } else { output = Clone(); } 
@@ -671,8 +675,8 @@ StatusOr LiteralBase::Reshape( } Literal LiteralBase::Transpose(absl::Span permutation) const { - CHECK(ShapeUtil::IsArray(shape())) << "Tuple is not supported for transpose"; - CHECK(IsPermutation(permutation, ShapeUtil::Rank(shape()))) + CHECK(shape().IsArray()) << "Tuple is not supported for transpose"; + CHECK(IsPermutation(permutation, shape().rank())) << "Given permutation is not a permutation of dimension numbers"; // To transpose the array, we just permute the dimensions and layout, and // do a straight memory copy of the raw data set. @@ -711,10 +715,10 @@ template Literal LiteralBase::SliceInternal( const Shape& result_shape, absl::Span start_indices) const { Literal result_literal(result_shape); - DimensionVector new_indices(ShapeUtil::Rank(result_shape)); + DimensionVector new_indices(result_shape.rank()); result_literal.EachCell( [&](absl::Span indices, NativeT /*value*/) { - for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) { + for (int64 i = 0; i < result_shape.rank(); ++i) { new_indices[i] = indices[i] + start_indices[i]; } NativeT value = Get(new_indices); @@ -725,10 +729,10 @@ Literal LiteralBase::SliceInternal( Literal LiteralBase::Slice(absl::Span start_indices, absl::Span limit_indices) const { - CHECK(ShapeUtil::IsArray(shape())) << "tuple is not supported for slice"; + CHECK(shape().IsArray()) << "tuple is not supported for slice"; DimensionVector result_dimensions; - for (int64 dnum = 0; dnum < ShapeUtil::Rank(shape()); ++dnum) { + for (int64 dnum = 0; dnum < shape().rank(); ++dnum) { CHECK_GE(start_indices[dnum], 0); CHECK_LE(limit_indices[dnum], shape().dimensions(dnum)) << "dnum = " << dnum; @@ -768,6 +772,8 @@ Literal LiteralBase::Slice(absl::Span start_indices, return SliceInternal(result_shape, start_indices); case C64: return SliceInternal(result_shape, start_indices); + case C128: + return SliceInternal(result_shape, start_indices); default: LOG(FATAL) << "not yet implemented: " << 
PrimitiveType_Name(result_shape.element_type()); @@ -816,6 +822,10 @@ string LiteralBase::GetAsString(absl::Span multi_index, complex64 c = Get(multi_index, shape_index); return StrCat("(", c.real(), ", ", c.imag(), ")"); } + case C128: { + complex128 c = Get(multi_index, shape_index); + return StrCat("(", c.real(), ", ", c.imag(), ")"); + } default: LOG(FATAL) << PrimitiveType_Name(subshape.element_type()); } @@ -870,6 +880,11 @@ string LiteralBase::GetSparseElementAsString( GetSparseElement(sparse_element_number, shape_index); return StrCat("(", c.real(), ", ", c.imag(), ")"); } + case C128: { + complex128 c = + GetSparseElement(sparse_element_number, shape_index); + return StrCat("(", c.real(), ", ", c.imag(), ")"); + } default: LOG(FATAL) << "Invalid element type for sparse arrays: " << PrimitiveType_Name(subshape.element_type()); @@ -906,7 +921,7 @@ size_t LiteralBase::Hash() const { ShapeUtil::ForEachSubshape( shape(), [&](const Shape& subshape, const ShapeIndex& index) { - if (!ShapeUtil::IsArray(subshape)) { + if (!subshape.IsArray()) { return; } @@ -998,6 +1013,9 @@ void LiteralBase::Piece::SortSparseElements() { case C64: SortSparseElementsInternal(); break; + case C128: + SortSparseElementsInternal(); + break; case F16: SortSparseElementsInternal(); break; @@ -1028,20 +1046,21 @@ string ShapeToString(bool print_layout, const Shape& shape) { } void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, - bool print_layout, std::vector* pieces); + bool print_shape, bool print_layout, + std::vector* pieces); void TupleToStringHelper(const LiteralBase& literal, - const ShapeIndex& shape_index, bool print_layout, - std::vector* pieces) { + const ShapeIndex& shape_index, bool print_shape, + bool print_layout, std::vector* pieces) { const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index); - pieces->push_back(ShapeToString(print_layout, subshape)); - pieces->push_back(" (\n"); + pieces->push_back("(\n"); std::vector 
tuple_pieces; for (int i = 0; i < ShapeUtil::TupleElementCount(subshape); ++i) { ShapeIndex element_index = shape_index; element_index.push_back(i); std::vector element_pieces; - ToStringHelper(literal, element_index, print_layout, &element_pieces); + ToStringHelper(literal, element_index, print_shape, print_layout, + &element_pieces); tuple_pieces.push_back(absl::StrJoin(element_pieces, "")); } pieces->push_back(absl::StrJoin(tuple_pieces, ",\n")); @@ -1049,11 +1068,13 @@ void TupleToStringHelper(const LiteralBase& literal, } void SparseArrayToStringHelper(const LiteralBase& literal, - const Shape& subshape, bool print_layout, - std::vector* pieces) { - pieces->push_back(ShapeToString(print_layout, subshape)); + const Shape& subshape, bool print_shape, + bool print_layout, std::vector* pieces) { + if (print_shape) { + pieces->push_back(ShapeToString(print_layout, subshape)); + } pieces->push_back("{"); - int64 rank = ShapeUtil::Rank(subshape); + int64 rank = subshape.rank(); int64 num_elements = literal.sparse_element_count(); for (int64 i = 0; i < num_elements; ++i) { if (i > 0) { @@ -1073,10 +1094,10 @@ void SparseArrayToStringHelper(const LiteralBase& literal, } void DenseArrayToStringHelper(const LiteralBase& literal, - const ShapeIndex& shape_index, bool print_layout, - std::vector* pieces) { + const ShapeIndex& shape_index, bool print_shape, + bool print_layout, std::vector* pieces) { const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index); - int64 rank = ShapeUtil::Rank(subshape); + int64 rank = subshape.rank(); std::function dimensions, std::vector*)> to_string_recursive = [&](absl::Span dimensions, @@ -1135,7 +1156,7 @@ void DenseArrayToStringHelper(const LiteralBase& literal, } }; - if (rank > 1) { + if (print_shape) { pieces->push_back(ShapeToString(print_layout, subshape)); pieces->push_back(" "); } @@ -1146,19 +1167,23 @@ void DenseArrayToStringHelper(const LiteralBase& literal, } void ToStringHelper(const LiteralBase& literal, 
const ShapeIndex& shape_index, - bool print_layout, std::vector* pieces) { + bool print_shape, bool print_layout, + std::vector* pieces) { const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index); CHECK(LayoutUtil::HasLayout(literal.shape())); CHECK(LayoutUtil::HasLayout(subshape)); - if (ShapeUtil::IsTuple(subshape)) { - TupleToStringHelper(literal, shape_index, print_layout, pieces); - } else if (ShapeUtil::IsToken(subshape)) { + if (subshape.IsTuple()) { + TupleToStringHelper(literal, shape_index, print_shape, print_layout, + pieces); + } else if (subshape.IsToken()) { pieces->push_back("token"); } else if (LayoutUtil::IsSparseArray(subshape)) { - SparseArrayToStringHelper(literal, subshape, print_layout, pieces); + SparseArrayToStringHelper(literal, subshape, print_shape, print_layout, + pieces); } else { CHECK(LayoutUtil::IsDenseArray(subshape)); - DenseArrayToStringHelper(literal, shape_index, print_layout, pieces); + DenseArrayToStringHelper(literal, shape_index, print_shape, print_layout, + pieces); } } @@ -1169,10 +1194,27 @@ int64 LiteralBase::sparse_element_count() const { return sparse_indices()->index_count(); } -string LiteralBase::ToString(bool print_layout) const { +string LiteralBase::ToString() const { + std::vector pieces; + CHECK(LayoutUtil::HasLayout(this->shape())); + ToStringHelper(*this, {}, /*print_shape=*/true, + /*print_layout=*/false, &pieces); + return absl::StrJoin(pieces, ""); +} + +string LiteralBase::ToStringWithoutShape() const { std::vector pieces; CHECK(LayoutUtil::HasLayout(this->shape())); - ToStringHelper(*this, {}, print_layout, &pieces); + ToStringHelper(*this, {}, /*print_shape=*/false, + /*print_layout=*/false, &pieces); + return absl::StrJoin(pieces, ""); +} + +string LiteralBase::ToStringWithLayout() const { + std::vector pieces; + CHECK(LayoutUtil::HasLayout(this->shape())); + ToStringHelper(*this, {}, /*print_shape=*/true, + /*print_layout=*/true, &pieces); return absl::StrJoin(pieces, ""); } @@ 
-1193,7 +1235,7 @@ namespace { template Literal ConvertBetweenNativeTypesWithConverter(const LiteralBase& src_literal, const ConverterType& converter) { - CHECK(ShapeUtil::IsArray(src_literal.shape())); + CHECK(src_literal.shape().IsArray()); Literal result_literal(ShapeUtil::ChangeElementType( src_literal.shape(), primitive_util::NativeToPrimitiveType())); @@ -1208,7 +1250,24 @@ Literal ConvertBetweenNativeTypesWithConverter(const LiteralBase& src_literal, } template -Literal ConvertBetweenNativeTypes(const LiteralBase& src_literal) { +typename std::enable_if<(std::is_same::value) && + (std::is_same::value || + std::is_same::value), + Literal>::type +ConvertBetweenNativeTypes(const LiteralBase& src_literal) { + auto converter = [](NativeSrcT src) { + return NativeDestT(static_cast(src)); + }; + return ConvertBetweenNativeTypesWithConverter( + src_literal, converter); +} + +template +typename std::enable_if<(!std::is_same::value) || + (!std::is_same::value && + !std::is_same::value), + Literal>::type +ConvertBetweenNativeTypes(const LiteralBase& src_literal) { auto converter = [](NativeSrcT src) { return static_cast(src); }; return ConvertBetweenNativeTypesWithConverter( src_literal, converter); @@ -1252,22 +1311,6 @@ BitcastBetweenNativeTypes(const LiteralBase& src_literal) { LOG(FATAL) << "Invalid bitcast between types of different sizes."; } -template -Literal ConvertToC64(const LiteralBase& src_literal) { - CHECK(ShapeUtil::IsArray(src_literal.shape())); - Literal result_literal( - ShapeUtil::ChangeElementType(src_literal.shape(), C64)); - using NativeSrcT = - typename primitive_util::PrimitiveTypeToNative::type; - absl::Span src_data = src_literal.data(); - absl::Span dest_data = result_literal.data(); - int64 num_elements = src_literal.element_count(); - for (int64 i = 0; i < num_elements; ++i) { - dest_data[i] = complex64(static_cast(src_data[i]), 0); - } - return result_literal; -} - template Literal ConvertIfTypesMatch(const LiteralBase& src_literal, bool 
bitcast) { CHECK_EQ(primitive_src_type, src_literal.shape().element_type()); @@ -1297,9 +1340,11 @@ StatusOr ConvertIfDestTypeMatches(const LiteralBase& src_literal, bitcast); CONVERT_IF_TYPES_MATCH(PRED) CONVERT_IF_TYPES_MATCH(S8) + CONVERT_IF_TYPES_MATCH(S16) CONVERT_IF_TYPES_MATCH(S32) CONVERT_IF_TYPES_MATCH(S64) CONVERT_IF_TYPES_MATCH(U8) + CONVERT_IF_TYPES_MATCH(U16) CONVERT_IF_TYPES_MATCH(U32) CONVERT_IF_TYPES_MATCH(U64) CONVERT_IF_TYPES_MATCH(F16) @@ -1308,10 +1353,15 @@ StatusOr ConvertIfDestTypeMatches(const LiteralBase& src_literal, CONVERT_IF_TYPES_MATCH(BF16) #undef CONVERT_IF_TYPES_MATCH case C64: - if (!bitcast) { - return ConvertToC64(src_literal); + if (bitcast) { + break; } - break; + return ConvertIfTypesMatch(src_literal, false); + case C128: + if (bitcast) { + break; + } + return ConvertIfTypesMatch(src_literal, false); // Other types are not yet supported. default: break; @@ -1324,7 +1374,7 @@ StatusOr ConvertIfDestTypeMatches(const LiteralBase& src_literal, StatusOr ConvertSwitch(const LiteralBase& literal, PrimitiveType primitive_dest_type, bool bitcast) { - TF_RET_CHECK(ShapeUtil::IsArray(literal.shape())); + TF_RET_CHECK(literal.shape().IsArray()); if (literal.shape().element_type() == primitive_dest_type) { return literal.Clone(); } @@ -1335,9 +1385,11 @@ StatusOr ConvertSwitch(const LiteralBase& literal, bitcast); CONVERT_IF_DEST_TYPE_MATCHES(PRED) CONVERT_IF_DEST_TYPE_MATCHES(S8) + CONVERT_IF_DEST_TYPE_MATCHES(S16) CONVERT_IF_DEST_TYPE_MATCHES(S32) CONVERT_IF_DEST_TYPE_MATCHES(S64) CONVERT_IF_DEST_TYPE_MATCHES(U8) + CONVERT_IF_DEST_TYPE_MATCHES(U16) CONVERT_IF_DEST_TYPE_MATCHES(U32) CONVERT_IF_DEST_TYPE_MATCHES(U64) CONVERT_IF_DEST_TYPE_MATCHES(F16) @@ -1377,7 +1429,7 @@ StatusOr LiteralBase::BitcastConvert( } StatusOr LiteralBase::ConvertToShape(const Shape& dest_shape) const { - if (!ShapeUtil::IsTuple(dest_shape)) { + if (!dest_shape.IsTuple()) { return Convert(dest_shape.element_type()); } std::vector elements; @@ -1409,7 +1461,7 @@ 
StatusOr LiteralBase::ConvertToShape(const Shape& dest_shape) const { template bool LiteralBase::Piece::EqualElementsInternal( const LiteralBase::Piece& other, std::vector* multi_index) const { - if (multi_index->size() == ShapeUtil::Rank(subshape())) { + if (multi_index->size() == subshape().rank()) { return (Get(*multi_index) == other.Get(*multi_index)); } for (int64 i = 0; i < subshape().dimensions(multi_index->size()); ++i) { @@ -1459,6 +1511,8 @@ bool LiteralBase::Piece::EqualElements(const LiteralBase::Piece& other) const { return EqualElementsInternal(other, &multi_index); case C64: return EqualElementsInternal(other, &multi_index); + case C128: + return EqualElementsInternal(other, &multi_index); default: LOG(FATAL) << "Unimplemented: LiteralBase::Piece::EqualElements for type " << PrimitiveType_Name(subshape().element_type()); @@ -1472,7 +1526,7 @@ bool LiteralBase::operator==(const LiteralBase& other) const { return root_piece().ForEachSubpieceWithBool( [&](const ShapeIndex& index, const Piece& piece) { - if (!ShapeUtil::IsArray(piece.subshape())) { + if (!piece.subshape().IsArray()) { return true; } @@ -1502,7 +1556,7 @@ static bool AllElementsEqualValue(absl::Span data, bool LiteralBase::IsAll(int8 value) const { return root_piece().ForEachSubpieceWithBool([&](const ShapeIndex& index, const Piece& piece) { - if (!ShapeUtil::IsArray(piece.subshape())) { + if (!piece.subshape().IsArray()) { return true; } @@ -1570,7 +1624,7 @@ bool LiteralBase::IsAll(int8 value) const { bool LiteralBase::IsAllFloat(float value) const { return root_piece().ForEachSubpieceWithBool( [&](const ShapeIndex& index, const Piece& piece) { - if (!ShapeUtil::IsArray(piece.subshape())) { + if (!piece.subshape().IsArray()) { return true; } @@ -1602,6 +1656,9 @@ bool LiteralBase::IsAllComplex(complex64 value) const { case C64: return AllElementsEqualValue(root_piece().data(), value); + case C128: + return AllElementsEqualValue(root_piece().data(), + value); default: return false; } @@ 
-1610,7 +1667,7 @@ bool LiteralBase::IsAllComplex(complex64 value) const { bool LiteralBase::IsAllFirst() const { return root_piece().ForEachSubpieceWithBool( [&](const ShapeIndex& index, const Piece& piece) { - if (!ShapeUtil::IsArray(piece.subshape())) { + if (!piece.subshape().IsArray()) { return true; } @@ -1681,6 +1738,11 @@ bool LiteralBase::IsAllFirst() const { auto data = piece.data(); return AllElementsEqualValue(data, data[0]); } + + case C128: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } default: return false; } @@ -1694,11 +1756,11 @@ bool LiteralBase::IsAllFirst() const { } bool LiteralBase::IsR1Iota() const { - if (!ShapeUtil::IsArray(shape())) { + if (!shape().IsArray()) { return false; } - if (ShapeUtil::Rank(shape()) != 1) { + if (shape().rank() != 1) { return false; } @@ -1730,6 +1792,8 @@ bool LiteralBase::IsR1Iota() const { return Get({idx}) == static_cast(idx); case C64: return Get({idx}) == complex64(idx, 0.0f); + case C128: + return Get({idx}) == complex128(idx, 0.0f); case PRED: return Get({idx}) == idx; // token, opaque, tuple, etc. are all not iota. 
@@ -1749,7 +1813,7 @@ bool LiteralBase::IsR1Iota() const { } bool LiteralBase::IsZero(absl::Span indices) const { - CHECK(ShapeUtil::IsArray(shape())); + CHECK(shape().IsArray()); switch (shape().element_type()) { case U8: return Get(indices) == 0; @@ -1773,6 +1837,8 @@ bool LiteralBase::IsZero(absl::Span indices) const { return Get(indices) == 0.0; case C64: return Get(indices) == complex64(0.0f, 0.0f); + case C128: + return Get(indices) == complex128(0.0f, 0.0f); case F16: return Get(indices) == static_cast(0.0f); case BF16: @@ -1860,6 +1926,12 @@ void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const { proto->add_c64s(value.imag()); } break; + case C128: + for (complex128 value : data()) { + proto->add_c128s(value.real()); + proto->add_c128s(value.imag()); + } + break; case TUPLE: case TOKEN: // Nothing to do but assign the shape which is done above. @@ -1872,12 +1944,12 @@ void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const { } const void* LiteralBase::Piece::untyped_data() const { - CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); + CHECK(subshape().IsArray()) << ShapeUtil::HumanString(subshape()); return buffer(); } void* LiteralBase::Piece::untyped_data() { - CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); + CHECK(subshape().IsArray()) << ShapeUtil::HumanString(subshape()); return buffer(); } @@ -1908,14 +1980,12 @@ Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { if (LayoutUtil::IsSparseArray(subshape())) { // Compute the number of elements (indices) in the sparse shape and reserve // the necessary space in spare_indices. 
- TF_RET_CHECK(ShapeUtil::Rank(subshape()) != 0) - << "Scalar shapes cannot be sparse"; - TF_RET_CHECK(proto.sparse_indices_size() % ShapeUtil::Rank(subshape()) == 0) + TF_RET_CHECK(subshape().rank() != 0) << "Scalar shapes cannot be sparse"; + TF_RET_CHECK(proto.sparse_indices_size() % subshape().rank() == 0) << "Unexpected number of indices in proto (" << proto.sparse_indices_size() << ") for shape of rank " - << ShapeUtil::Rank(subshape()); - const int64 index_count = - proto.sparse_indices_size() / ShapeUtil::Rank(subshape()); + << subshape().rank(); + const int64 index_count = proto.sparse_indices_size() / subshape().rank(); sparse_indices()->Resize(index_count); // Copy the indices from the proto into the SparseIndexArray object. @@ -1994,7 +2064,17 @@ Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { for (int64 i = 0; i < complex_data.size(); ++i) { complex_data[i] = complex64{proto.c64s(i * 2), proto.c64s(i * 2 + 1)}; } - } break; + break; + } + case C128: { + auto complex_data = data(); + TF_RET_CHECK(proto.c128s_size() == complex_data.size() * 2); + for (int64 i = 0; i < complex_data.size(); ++i) { + complex_data[i] = + complex128{proto.c128s(i * 2), proto.c128s(i * 2 + 1)}; + } + break; + } case TUPLE: return InvalidArgument("Should not be called on tuple shapes: %s", ShapeUtil::HumanString(subshape())); @@ -2040,8 +2120,8 @@ int64 LiteralBase::size_bytes(const ShapeIndex& shape_index) const { } string LiteralBase::GetR1U8AsString() const { - CHECK(ShapeUtil::IsArray(shape())); - CHECK_EQ(ShapeUtil::Rank(shape()), 1); + CHECK(shape().IsArray()); + CHECK_EQ(shape().rank(), 1); CHECK_EQ(shape().element_type(), U8); return string(absl::bit_cast(data().data()), ShapeUtil::ElementsIn(shape())); @@ -2055,7 +2135,7 @@ void MutableBorrowingLiteral::CopyPieceSubtree(const Shape& shape, << ShapeUtil::HumanString(src_piece->subshape()) << "dest_piece has shape: " << ShapeUtil::HumanString(dest_piece->subshape()); - if (ShapeUtil::IsTuple(shape)) 
{ + if (shape.IsTuple()) { for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { const Shape& subshape = shape.tuple_shapes(i); @@ -2066,7 +2146,7 @@ void MutableBorrowingLiteral::CopyPieceSubtree(const Shape& shape, dest_piece->emplace_back(std::move(child_piece)); } - } else if (ShapeUtil::IsArray(shape)) { + } else if (shape.IsArray()) { dest_piece->set_buffer(src_piece->buffer()); } else { // If the shape is neither an array nor tuple, then it must be @@ -2142,7 +2222,7 @@ MutableBorrowingLiteral::MutableBorrowingLiteral(const char* src_buf_ptr, : MutableLiteralBase() { shape_ = absl::make_unique(shape); CHECK(LayoutUtil::HasLayout(*shape_)); - CHECK(!ShapeUtil::IsTuple(*shape_)); + CHECK(!shape_->IsTuple()); root_piece_ = new Piece(); root_piece_->set_buffer(const_cast(src_buf_ptr)); @@ -2169,14 +2249,14 @@ LiteralSlice::LiteralSlice(const LiteralBase& literal, : LiteralBase(), root_piece_(&literal.piece(view_root)) {} void BorrowingLiteral::BuildPieceSubtree(const Shape& shape, Piece* piece) { - CHECK(ShapeUtil::IsTuple(shape)); + CHECK(shape.IsTuple()); for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { const Shape& subshape = shape.tuple_shapes(i); auto child_piece = Piece(); child_piece.set_subshape(&subshape); - if (ShapeUtil::IsTuple(subshape)) { + if (subshape.IsTuple()) { BuildPieceSubtree(subshape, &child_piece); } @@ -2186,7 +2266,7 @@ void BorrowingLiteral::BuildPieceSubtree(const Shape& shape, Piece* piece) { BorrowingLiteral::BorrowingLiteral(const char* src_buf_ptr, const Shape& shape) : LiteralBase(), shape_(absl::make_unique(shape)) { - CHECK(ShapeUtil::IsArray(*shape_)); + CHECK(shape_->IsArray()); CHECK(LayoutUtil::HasLayout(*shape_)); root_piece_ = Piece(); @@ -2197,7 +2277,7 @@ BorrowingLiteral::BorrowingLiteral(const char* src_buf_ptr, const Shape& shape) BorrowingLiteral::BorrowingLiteral(absl::Span src_buf_ptrs, const Shape& shape) : LiteralBase(), shape_(absl::make_unique(shape)) { - 
CHECK(ShapeUtil::IsTuple(*shape_)); + CHECK(shape_->IsTuple()); CHECK(!ShapeUtil::IsNestedTuple(*shape_)); CHECK_EQ(src_buf_ptrs.size(), ShapeUtil::TupleElementCount(*shape_)); root_piece_ = Piece(); @@ -2206,7 +2286,7 @@ BorrowingLiteral::BorrowingLiteral(absl::Span src_buf_ptrs, for (int i = 0; i < src_buf_ptrs.size(); ++i) { const auto& src_shape = shape_->tuple_shapes(i); - CHECK(ShapeUtil::IsArray(src_shape)); + CHECK(src_shape.IsArray()); root_piece_.child(i).set_buffer(const_cast(src_buf_ptrs[i])); } } diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h index fa9a71af4ceb998a7a289443cbef70eb52cb1a11..c418be895d6c3faa6a85ca2c73c6f42b0a021104 100644 --- a/tensorflow/compiler/xla/literal.h +++ b/tensorflow/compiler/xla/literal.h @@ -92,9 +92,20 @@ class LiteralBase { // array. string GetR1U8AsString() const; - // Returns a string representation of the literal value. - // Warning: this function can take minutes for multi-million element Literals. - string ToString(bool print_layout = false) const; + // Returns a string representation of the literal value. The Shape of the + // literal is a prefix of the literal value in the string. + + // Warning: this function can take minutes for multi-million + // element Literals. + string ToString() const; + + // Returns a string representation of the literal value which does *not* + // include the shape string. + string ToStringWithoutShape() const; + + // Returns a string representation of the literal value which includes the + // shape string with its layout.does *not* include the shape string. + string ToStringWithLayout() const; // Gets an element in the literal at the given index. The multi_index is // CHECKed against the dimension sizes. 
@@ -856,7 +867,7 @@ class BorrowingLiteral : public LiteralBase { template absl::Span LiteralBase::Piece::data() const { - DCHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); + DCHECK(subshape().IsArray()) << ShapeUtil::HumanString(subshape()); DCHECK_EQ(subshape().element_type(), primitive_util::NativeToPrimitiveType()) << "Attempting to access " @@ -869,7 +880,7 @@ absl::Span LiteralBase::Piece::data() const { template absl::Span LiteralBase::Piece::data() { - DCHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); + DCHECK(subshape().IsArray()) << ShapeUtil::HumanString(subshape()); DCHECK_EQ(subshape().element_type(), primitive_util::NativeToPrimitiveType()) << "Attempting to access " @@ -950,8 +961,12 @@ void MutableLiteralBase::AppendSparseElement( Piece& p = piece(shape_index); const Shape& subshape = p.subshape(); CHECK(LayoutUtil::IsSparseArray(subshape)); - int64 rank = ShapeUtil::Rank(subshape); + int64 rank = subshape.rank(); CHECK_EQ(multi_index.size(), rank); + for (int64 i = 0; i < rank; ++i) { + CHECK_GE(multi_index[i], 0); + CHECK_LT(multi_index[i], subshape.dimensions(i)); + } int64 last_element = p.sparse_indices()->index_count(); CHECK_LT(last_element, LayoutUtil::MaxSparseElements(subshape.layout())); p.sparse_indices()->Append(multi_index); @@ -966,7 +981,7 @@ void LiteralBase::EachCell( if (ShapeUtil::IsZeroElementArray(shape())) { return; } - std::vector indices(ShapeUtil::Rank(shape()), 0); + std::vector indices(shape().rank(), 0); do { per_cell(indices, Get(indices)); } while (IndexUtil::BumpIndices(shape(), absl::MakeSpan(indices))); @@ -974,8 +989,8 @@ void LiteralBase::EachCell( template inline void MutableLiteralBase::PopulateR1(absl::Span values) { - CHECK(ShapeUtil::IsArray(shape())); - CHECK_EQ(ShapeUtil::Rank(shape()), 1); + CHECK(shape().IsArray()); + CHECK_EQ(shape().rank(), 1); CHECK_EQ(ShapeUtil::ElementsIn(shape()), values.size()); CHECK_EQ(shape().element_type(), 
primitive_util::NativeToPrimitiveType()); @@ -986,8 +1001,8 @@ inline void MutableLiteralBase::PopulateR1(absl::Span values) { template void MutableLiteralBase::PopulateR2( std::initializer_list> values) { - CHECK(ShapeUtil::IsArray(shape())); - CHECK_EQ(ShapeUtil::Rank(shape()), 2); + CHECK(shape().IsArray()); + CHECK_EQ(shape().rank(), 2); CHECK_EQ(shape().element_type(), primitive_util::NativeToPrimitiveType()); @@ -1010,10 +1025,10 @@ void MutableLiteralBase::PopulateR2( template void MutableLiteralBase::PopulateFromArray(const Array& values) { - CHECK(ShapeUtil::IsArray(shape())); + CHECK(shape().IsArray()); CHECK_EQ(shape().element_type(), primitive_util::NativeToPrimitiveType()); - CHECK_EQ(ShapeUtil::Rank(shape()), values.num_dimensions()); + CHECK_EQ(shape().rank(), values.num_dimensions()); for (int dim = 0; dim < values.num_dimensions(); ++dim) { CHECK_EQ(values.dim(dim), shape().dimensions(dim)); } @@ -1042,7 +1057,7 @@ void MutableLiteralBase::PopulateSparse(SparseIndexArray indices, absl::Span values, bool sort) { CHECK(LayoutUtil::IsSparseArray(shape())); - int rank = ShapeUtil::Rank(shape()); + int rank = shape().rank(); CHECK_EQ(indices.rank(), rank); int64 max_elements = LayoutUtil::MaxSparseElements(shape().layout()); CHECK_LE(indices.max_indices(), max_elements); @@ -1066,7 +1081,7 @@ template Status MutableLiteralBase::PopulateInternal(const FnType& generator, bool parallel) { const Shape& this_shape = shape(); - const int64 rank = ShapeUtil::Rank(this_shape); + const int64 rank = this_shape.rank(); TF_RET_CHECK(LayoutUtil::IsDenseArray(this_shape)); TF_RET_CHECK(this_shape.element_type() == primitive_util::NativeToPrimitiveType()); @@ -1118,7 +1133,7 @@ Status MutableLiteralBase::PopulateParallel(const FnType& generator) { template void MutableLiteralBase::PopulateWithValue(NativeT value) { - CHECK(ShapeUtil::IsArray(shape())); + CHECK(shape().IsArray()); CHECK_EQ(shape().element_type(), primitive_util::NativeToPrimitiveType()); for (NativeT& 
element : data()) { diff --git a/tensorflow/compiler/xla/literal_comparison.cc b/tensorflow/compiler/xla/literal_comparison.cc index b044f0ad73f13a0599e77f1f43888bc974e31f73..9b3de75dd4e9d495778af86fb8fc07909ab4ba81 100644 --- a/tensorflow/compiler/xla/literal_comparison.cc +++ b/tensorflow/compiler/xla/literal_comparison.cc @@ -46,68 +46,116 @@ uint16 GetRawValue(Eigen::half val) { return val.x; } // between the left-hand-side and right-hand-side, by bit-casting to UnsignedT // -- on miscompare, a nice error message is given in the AssertionFailure. template -Status CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs, - absl::Span multi_index) { +bool CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs, + absl::Span multi_index) { + auto ulhs = absl::bit_cast(GetRawValue(lhs)); + auto urhs = absl::bit_cast(GetRawValue(rhs)); + return ulhs == urhs; +} + +// Templated comparator that specializes for float equality comparison with the +// bitwise helper above (this is the un-specialized fallback, to just use the +// default gunit implementation). +template +bool CompareEqual(NativeT lhs, NativeT rhs, + absl::Span multi_index) { + return lhs == rhs; +} + +// Specializations for floating types that do bitwise comparisons when equality +// comparison is requested. 
+template <> +bool CompareEqual(bfloat16 lhs, bfloat16 rhs, + absl::Span multi_index) { + return CompareFloatsBitwiseEqual(lhs, rhs, multi_index); +} +template <> +bool CompareEqual(Eigen::half lhs, Eigen::half rhs, + absl::Span multi_index) { + return CompareFloatsBitwiseEqual(lhs, rhs, multi_index); +} +template <> +bool CompareEqual(float lhs, float rhs, + absl::Span multi_index) { + return CompareFloatsBitwiseEqual(lhs, rhs, multi_index); +} +template <> +bool CompareEqual(double lhs, double rhs, + absl::Span multi_index) { + return CompareFloatsBitwiseEqual(lhs, rhs, multi_index); +} +template <> +bool CompareEqual(complex64 lhs, complex64 rhs, + absl::Span multi_index) { + return CompareEqual(lhs.real(), rhs.real(), multi_index) && + CompareEqual(lhs.imag(), rhs.imag(), multi_index); +} +template <> +bool CompareEqual(complex128 lhs, complex128 rhs, + absl::Span multi_index) { + return CompareEqual(lhs.real(), rhs.real(), multi_index) && + CompareEqual(lhs.imag(), rhs.imag(), multi_index); +} + +template +Status MakeBitwiseErrorStatus(NativeT lhs, NativeT rhs, + absl::Span multi_index) { auto ulhs = absl::bit_cast(GetRawValue(lhs)); auto urhs = absl::bit_cast(GetRawValue(rhs)); auto lhs_double = static_cast(lhs); auto rhs_double = static_cast(rhs); - if (ulhs != urhs) { return InvalidArgument( "floating values are not bitwise-equal; and equality testing " "was requested: %s=%g=%a vs %s=%g=%a at array index %s", StrCat(absl::Hex(ulhs)), lhs_double, lhs_double, StrCat(absl::Hex(urhs)), rhs_double, rhs_double, LiteralUtil::MultiIndexAsString(multi_index)); - } - return Status::OK(); } -// Templated comparator that specializes for float equality comparison with the -// bitwise helper above (this is the un-specialized fallback, to just use the -// default gunit implementation). 
template -Status CompareEqual(NativeT lhs, NativeT rhs, - absl::Span multi_index) { - if (lhs == rhs) { - return Status::OK(); - } +Status MakeErrorStatus(NativeT lhs, NativeT rhs, + absl::Span multi_index) { return InvalidArgument( "first mismatch at array index %s:\n expected value: %s\n actual " "value: %s", LiteralUtil::MultiIndexAsString(multi_index), StrCat(lhs), StrCat(rhs)); } -// Specializations for floating types that do bitwise comparisons when equality -// comparison is requested. template <> -Status CompareEqual(bfloat16 lhs, bfloat16 rhs, - absl::Span multi_index) { - return CompareFloatsBitwiseEqual(lhs, rhs, multi_index); +Status MakeErrorStatus(bfloat16 lhs, bfloat16 rhs, + absl::Span multi_index) { + return MakeBitwiseErrorStatus(lhs, rhs, multi_index); } template <> -Status CompareEqual(Eigen::half lhs, Eigen::half rhs, - absl::Span multi_index) { - return CompareFloatsBitwiseEqual(lhs, rhs, multi_index); +Status MakeErrorStatus(Eigen::half lhs, Eigen::half rhs, + absl::Span multi_index) { + return MakeBitwiseErrorStatus(lhs, rhs, multi_index); } template <> -Status CompareEqual(float lhs, float rhs, - absl::Span multi_index) { - return CompareFloatsBitwiseEqual(lhs, rhs, multi_index); +Status MakeErrorStatus(float lhs, float rhs, + absl::Span multi_index) { + return MakeBitwiseErrorStatus(lhs, rhs, multi_index); } template <> -Status CompareEqual(double lhs, double rhs, - absl::Span multi_index) { - return CompareFloatsBitwiseEqual(lhs, rhs, multi_index); +Status MakeErrorStatus(double lhs, double rhs, + absl::Span multi_index) { + return MakeBitwiseErrorStatus(lhs, rhs, multi_index); } template <> -Status CompareEqual(complex64 lhs, complex64 rhs, - absl::Span multi_index) { - auto res = CompareEqual(lhs.real(), rhs.real(), multi_index); - if (!res.ok()) { - return res; +Status MakeErrorStatus(complex64 lhs, complex64 rhs, + absl::Span multi_index) { + if (!CompareEqual(lhs.real(), rhs.real(), multi_index)) { + return 
MakeErrorStatus(lhs.real(), rhs.real(), multi_index); } - return CompareEqual(lhs.imag(), rhs.imag(), multi_index); + return MakeErrorStatus(lhs.imag(), rhs.imag(), multi_index); +} +template <> +Status MakeErrorStatus(complex128 lhs, complex128 rhs, + absl::Span multi_index) { + if (!CompareEqual(lhs.real(), rhs.real(), multi_index)) { + return MakeErrorStatus(lhs.real(), rhs.real(), multi_index); + } + return MakeErrorStatus(lhs.imag(), rhs.imag(), multi_index); } // A recursive function which iterates through every index of expected and @@ -119,7 +167,11 @@ Status Equal(LiteralSlice expected, LiteralSlice actual, if (dimension == expected.shape().dimensions_size()) { NativeT expected_value = expected.Get(multi_index); NativeT actual_value = actual.Get(multi_index); - return CompareEqual(expected_value, actual_value, multi_index); + bool result = + CompareEqual(expected_value, actual_value, multi_index); + return result ? Status::OK() + : MakeErrorStatus(expected_value, actual_value, + multi_index); } Status result; @@ -134,53 +186,40 @@ Status Equal(LiteralSlice expected, LiteralSlice actual, // Gets the total element count. For tuples, this is not the count of tuple // elements, but the sum of elements of each tuple element. int64 RecursiveElementCount(const Shape& shape) { - if (ShapeUtil::IsTuple(shape)) { + if (shape.IsTuple()) { const int64 tuple_elements = ShapeUtil::TupleElementCount(shape); int64 total = 0; for (int64 i = 0; i < tuple_elements; ++i) { total += RecursiveElementCount(ShapeUtil::GetTupleElementShape(shape, i)); } return total; - } else if (ShapeUtil::IsArray(shape)) { + } else if (shape.IsArray()) { return ShapeUtil::ElementsIn(shape); } else { return 0; } } -// Returns whether the actual and expected values are mismatched with respect to -// nans. 'relaxed_nans' is interpreted as in xla::ErrorSpec. +// Returns whether the given value is infinity. 
template -bool NanMismatch(NativeT expected, NativeT actual, bool relaxed_nans) { - if (relaxed_nans) { - return !std::isnan(expected) && std::isnan(actual); - } else { - return std::isnan(expected) != std::isnan(actual); - } -} - -template <> -bool NanMismatch(complex64 expected, complex64 actual, - bool relaxed_nans) { - return NanMismatch(expected.real(), actual.real(), relaxed_nans) || - NanMismatch(expected.imag(), actual.imag(), relaxed_nans); +bool IsInf(NativeT val) { + return std::isinf(val); } template <> -bool NanMismatch(half expected, half actual, bool relaxed_nans) { - return NanMismatch(static_cast(expected), - static_cast(actual), relaxed_nans); +bool IsInf(half val) { + return std::isinf(static_cast(val)); } -// Returns whether the given value is infinity. +// Returns whether the given value is nan. template -bool IsInf(NativeT val) { - return std::isinf(val); +float IsNan(NativeT value) { + return std::isnan(value); } template <> -bool IsInf(half val) { - return std::isinf(static_cast(val)); +float IsNan(half value) { + return IsNan(static_cast(value)); } // Converts the given floating-point value to a string. @@ -194,6 +233,11 @@ string FpValueToString(complex64 value) { return absl::StrFormat("%8.4g + %8.4fi", value.real(), value.imag()); } +template <> +string FpValueToString(complex128 value) { + return absl::StrFormat("%8.4g + %8.4fi", value.real(), value.imag()); +} + // Returns the absolute value of the given floating point value. This function // is used instead of std::abs directly in order to allow type-dependent // implementations for NearComparator. @@ -273,7 +317,7 @@ class NearComparator { // If the shapes mismatch, we simply fail the expectation instead of // printing out data, as it's a type error rather than a value error. 
TF_RETURN_IF_ERROR(EqualShapes(expected_.shape(), actual_.shape())); - if (!ShapeUtil::IsArray(expected_.shape())) { + if (!expected_.shape().IsArray()) { return InvalidArgument("Expected array shape; got %s.", ShapeUtil::HumanString(expected_.shape())); } @@ -326,35 +370,59 @@ class NearComparator { // the given literal_index and keeps track of various mismatch statistics. template void CompareValues(T expected, T actual, int64 linear_index) { - const bool is_nan_mismatch = - NanMismatch(expected, actual, error_.relaxed_nans); float abs_error; float rel_error; - if (CompareEqual(expected, actual, {linear_index}).ok()) { + if (CompareEqual(expected, actual, {linear_index})) { abs_error = 0; rel_error = 0; - } else if (is_nan_mismatch) { - num_nan_mismatches_++; - // A nan mismatch is considered to have infinite error. rel_error is used - // for sorting a std::set of the top mismatchs, and a nan value here will - // result in undefined behavior because nan's do not satisfy the strict - // weak ordering requirement of std containers. - abs_error = std::numeric_limits::infinity(); - rel_error = std::numeric_limits::infinity(); + } else if (IsNan(expected) || IsNan(actual)) { + if ((!error_.relaxed_nans && IsNan(expected) != IsNan(actual)) || + (error_.relaxed_nans && !IsNan(expected) && IsNan(actual))) { + num_nan_mismatches_++; + // A nan mismatch is considered to have infinite error. rel_error is + // used for sorting a std::set of the top mismatchs, and a nan value + // here will result in undefined behavior because nan's do not satisfy + // the strict weak ordering requirement of std containers. + abs_error = std::numeric_limits::infinity(); + rel_error = std::numeric_limits::infinity(); + } else { + abs_error = 0; + rel_error = 0; + } + } else if (IsInf(actual) && !IsInf(expected) && error_.fewer_infs_ok) { + // `fewer_infs_ok` gives us the option of comparing as though `actual` + // were float_max/min rather than inf. + T actual_finite = actual > T{0} ? 
std::numeric_limits::max() + : std::numeric_limits::lowest(); + abs_error = FpAbsoluteValue(actual_finite - expected); + + // Avoid division by 0 even though it's well-defined because ubsan can be + // configured to treat this as a fatal error. + if (expected != T{0}) { + rel_error = abs_error / FpAbsoluteValue(expected); + } else { + rel_error = std::numeric_limits::infinity(); + } } else if (IsInf(expected) || IsInf(actual)) { // If either the expected or actual value is infinity but not both, // then both absolute and relative error are regarded as inifity. - CHECK(!CompareEqual(expected, actual, {linear_index}).ok()); + CHECK(!CompareEqual(expected, actual, {linear_index})); abs_error = std::numeric_limits::infinity(); rel_error = std::numeric_limits::infinity(); } else { abs_error = FpAbsoluteValue(actual - expected); - rel_error = abs_error / FpAbsoluteValue(expected); + + // Avoid division by 0 even though it's well-defined because ubsan can be + // configured to treat this as a fatal error. + if (expected != T{0}) { + rel_error = abs_error / FpAbsoluteValue(expected); + } else { + rel_error = std::numeric_limits::infinity(); + } } const bool is_abs_mismatch = abs_error > error_.abs; const bool is_rel_mismatch = rel_error > error_.rel; - const bool is_mismatch = - is_nan_mismatch || (is_abs_mismatch && is_rel_mismatch); + const bool is_mismatch = is_abs_mismatch && is_rel_mismatch; // Update the error of the relative bucket only if the *absolute* error // bound is exceeded and vice versa. @@ -389,7 +457,7 @@ class NearComparator { mismatches_.data()[linear_index] = true; } - // For complex64 types, we compare real and imaginary parts individually. + // For complex types, we compare real and imaginary parts individually. 
void CompareValues(complex64 expected, complex64 actual, int64 linear_index) { bool mismatch = false; CompareValues(expected.real(), actual.real(), linear_index); @@ -412,6 +480,29 @@ class NearComparator { mismatches_.data()[linear_index] = mismatch; } + void CompareValues(complex128 expected, complex128 actual, + int64 linear_index) { + bool mismatch = false; + CompareValues(expected.real(), actual.real(), linear_index); + if (mismatches_.data()[linear_index] == true) { + mismatch = true; + // Delay the mismatch count increase for real part, instead increase + // mismatch by 1 for the entire complex number. + num_mismatches_--; + } + CompareValues(expected.imag(), actual.imag(), linear_index); + if (mismatches_.data()[linear_index] == true) { + mismatch = true; + // Delay the mismatch count increase for imag part, instead increase + // mismatch by 1 for the entire complex number. + num_mismatches_--; + } + if (mismatch == true) { + num_mismatches_++; + } + mismatches_.data()[linear_index] = mismatch; + } + // Compares the two literals elementwise. void CompareLiterals() { // Fast path optimization for the case were layouts match. @@ -425,7 +516,7 @@ class NearComparator { } return; } - std::vector multi_index(ShapeUtil::Rank(actual_.shape()), 0); + std::vector multi_index(actual_.shape().rank(), 0); CompareLiteralsSlow(0, &multi_index); } @@ -620,6 +711,9 @@ Status EqualHelper(const LiteralSlice& expected, const LiteralSlice& actual) { case C64: result = Equal(expected, actual, index, 0); break; + case C128: + result = Equal(expected, actual, index, 0); + break; case TUPLE: { for (int i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) { result.Update(EqualHelper(LiteralSlice(expected, {i}), @@ -642,12 +736,12 @@ Status EqualHelper(const LiteralSlice& expected, const LiteralSlice& actual) { // via recursion. shape_index is the ShapeIndex of expected (or actual) // currently being compared. 
Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual, - const ErrorSpec& error, bool detailed_message, + const ErrorSpec& error, absl::optional detailed_message, const MiscompareCallback& miscompare_callback, const ShapeIndex& shape_index) { TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape())); - if (ShapeUtil::IsTuple(expected.shape())) { + if (expected.shape().IsTuple()) { Status return_status; for (int64 i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) { const auto expected_element = LiteralSlice(expected, {i}); @@ -683,26 +777,32 @@ Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual, if (ShapeUtil::ElementIsFloating(expected.shape()) || ShapeUtil::ElementIsComplex(expected.shape())) { + bool use_detailed_message = detailed_message.value_or( + ShapeUtil::ElementsIn(expected.shape()) >= 64); switch (expected.shape().element_type()) { case BF16: return NearComparator::Compare( - expected, actual, error, detailed_message, miscompare_callback); + expected, actual, error, use_detailed_message, miscompare_callback); break; case F16: return NearComparator::Compare( - expected, actual, error, detailed_message, miscompare_callback); + expected, actual, error, use_detailed_message, miscompare_callback); break; case F32: return NearComparator::Compare( - expected, actual, error, detailed_message, miscompare_callback); + expected, actual, error, use_detailed_message, miscompare_callback); break; case F64: return NearComparator::Compare( - expected, actual, error, detailed_message, miscompare_callback); + expected, actual, error, use_detailed_message, miscompare_callback); break; case C64: return NearComparator::Compare( - expected, actual, error, detailed_message, miscompare_callback); + expected, actual, error, use_detailed_message, miscompare_callback); + break; + case C128: + return NearComparator::Compare( + expected, actual, error, use_detailed_message, miscompare_callback); break; default: 
LOG(FATAL) << "Unsupported primitive type in near comparator: " @@ -723,7 +823,7 @@ Status EqualShapes(const Shape& expected, const Shape& actual) { ShapeUtil::HumanString(expected), ShapeUtil::HumanString(actual)); } - if (ShapeUtil::IsTuple(expected)) { + if (expected.IsTuple()) { if (ShapeUtil::TupleElementCount(expected) != ShapeUtil::TupleElementCount(actual)) { return InvalidArgument( @@ -738,8 +838,8 @@ Status EqualShapes(const Shape& expected, const Shape& actual) { return AppendStatus(result, StrCat("mismatch in tuple index", i)); } } - } else if (ShapeUtil::IsArray(expected)) { - if (ShapeUtil::Rank(expected) != ShapeUtil::Rank(actual)) { + } else if (expected.IsArray()) { + if (expected.rank() != actual.rank()) { return InvalidArgument("want rank of %s got rank of %s", ShapeUtil::HumanString(expected), ShapeUtil::HumanString(actual)); @@ -793,7 +893,7 @@ Status Equal(const LiteralSlice& expected, const LiteralSlice& actual) { } Status Near(const LiteralSlice& expected, const LiteralSlice& actual, - const ErrorSpec& error, bool detailed_message, + const ErrorSpec& error, absl::optional detailed_message, const MiscompareCallback& miscompare_callback) { VLOG(1) << "Expected literal:"; XLA_VLOG_LINES(1, expected.ToString()); diff --git a/tensorflow/compiler/xla/literal_comparison.h b/tensorflow/compiler/xla/literal_comparison.h index 9e5bf7c1d062ef0f25d07a80d6ded8106df5dacc..23fff3fa348f1652eaec344da4c40ccf3ad1079a 100644 --- a/tensorflow/compiler/xla/literal_comparison.h +++ b/tensorflow/compiler/xla/literal_comparison.h @@ -55,9 +55,10 @@ using MiscompareCallback = // being compared. // // If detailed_message is true, then the error message in the assertion result -// will contain a more detailed breakdown of mismatches. +// will contain a more detailed breakdown of mismatches. By default, we display +// a detailed message only for "large" inputs. 
Status Near(const LiteralSlice& expected, const LiteralSlice& actual, - const ErrorSpec& error, bool detailed_message, + const ErrorSpec& error, absl::optional detailed_message, const MiscompareCallback& miscompare_callback); // Calling ToString on a literal with over 100 million elements takes around diff --git a/tensorflow/compiler/xla/literal_test.cc b/tensorflow/compiler/xla/literal_test.cc index 49363ad802ddb9520f89b53257216bc7ddaf8ff5..b54a71ae68218ef578535a913f5867d843236e32 100644 --- a/tensorflow/compiler/xla/literal_test.cc +++ b/tensorflow/compiler/xla/literal_test.cc @@ -98,42 +98,45 @@ class LiteralUtilTest : public ::testing::Test { TEST_F(LiteralUtilTest, LiteralScalarToString) { auto true_lit = LiteralUtil::CreateR0(true); - EXPECT_EQ("true", true_lit.ToString()); + EXPECT_EQ("pred[] true", true_lit.ToString()); auto false_lit = LiteralUtil::CreateR0(false); - EXPECT_EQ("false", false_lit.ToString()); + EXPECT_EQ("pred[] false", false_lit.ToString()); auto u32_lit = LiteralUtil::CreateR0(42); - EXPECT_EQ("42", u32_lit.ToString()); + EXPECT_EQ("u32[] 42", u32_lit.ToString()); auto s32_lit = LiteralUtil::CreateR0(-999); - EXPECT_EQ("-999", s32_lit.ToString()); + EXPECT_EQ("s32[] -999", s32_lit.ToString()); auto f32_lit = LiteralUtil::CreateR0(3.14f); - EXPECT_EQ("3.14", f32_lit.ToString()); + EXPECT_EQ("f32[] 3.14", f32_lit.ToString()); auto f16_lit = LiteralUtil::CreateR0(static_cast(0.5f)); - EXPECT_EQ("0.5", f16_lit.ToString()); + EXPECT_EQ("f16[] 0.5", f16_lit.ToString()); auto c64_lit = LiteralUtil::CreateR0({3.14f, 2.78f}); - EXPECT_EQ("(3.14, 2.78)", c64_lit.ToString()); + EXPECT_EQ("c64[] (3.14, 2.78)", c64_lit.ToString()); + + auto c128_lit = LiteralUtil::CreateR0({3.14f, 2.78f}); + EXPECT_EQ("c128[] (3.14, 2.78)", c128_lit.ToString()); auto bf16_lit = LiteralUtil::CreateR0(static_cast(0.5f)); - EXPECT_EQ("0.5", bf16_lit.ToString()); + EXPECT_EQ("bf16[] 0.5", bf16_lit.ToString()); // 3.14 will be rounded to 3.14062 in bfloat16 format. 
auto bf16_lit_truncated = LiteralUtil::CreateR0(static_cast(3.14f)); - ASSERT_EQ("3.14062", bf16_lit_truncated.ToString()); + ASSERT_EQ("bf16[] 3.14062", bf16_lit_truncated.ToString()); auto bf16_lit_truncated2 = LiteralUtil::CreateR0(static_cast(9.001f)); - EXPECT_EQ("9", bf16_lit_truncated2.ToString()); + EXPECT_EQ("bf16[] 9", bf16_lit_truncated2.ToString()); } TEST_F(LiteralUtilTest, LiteralVectorToString) { auto pred_vec = LiteralUtil::CreateR1({true, false, true}); - EXPECT_EQ("{1, 0, 1}", pred_vec.ToString()); + EXPECT_EQ("pred[3] {1, 0, 1}", pred_vec.ToString()); } TEST_F(LiteralUtilTest, R2ToString) { @@ -210,8 +213,8 @@ TEST_F(LiteralUtilTest, TupleToString) { auto scalar = LiteralUtil::CreateR0(1.0); auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix}); - const string expected = R"((f32[], f32[2,2]) ( -1, + const string expected = R"(( +f32[] 1, f32[2,2] { { 1, 2 }, { 3, 4 } @@ -469,6 +472,21 @@ TEST_F(LiteralUtilTest, C64Equality) { EXPECT_NE(vector, vector_reversed); } +TEST_F(LiteralUtilTest, C128Equality) { + // Test equality with tuples. + auto vector = LiteralUtil::CreateR1({{1.0, 2.0}, {3.0, 4.0}}); + + // Tuple with the same elements. One element is shared with the original + // tuple, the other is a clone of the element in the original tuple. 
+ auto vector_clone = + LiteralUtil::CreateR1({{1.0, 2.0}, {3.0, 4.0}}); + EXPECT_EQ(vector, vector_clone); + + auto vector_reversed = + LiteralUtil::CreateR1({{3.0, 4.0}, {1.0, 2.0}}); + EXPECT_NE(vector, vector_reversed); +} + TEST_F(LiteralUtilTest, IsAllTuple) { auto element1 = LiteralUtil::CreateR0(0.0); auto element2 = LiteralUtil::CreateR2({{0.0, 0.0}, {0.0, 0.0}}); @@ -623,7 +641,7 @@ template class LiteralUtilTestTemplated : public ::testing::Test {}; using TestedTypes = ::testing::Types; -TYPED_TEST_CASE(LiteralUtilTestTemplated, TestedTypes); +TYPED_TEST_SUITE(LiteralUtilTestTemplated, TestedTypes); TYPED_TEST(LiteralUtilTestTemplated, Relayout2x2) { // Make a non-integer for floating point types. @@ -836,6 +854,13 @@ TEST_F(LiteralUtilTest, PopulateR1C64) { EXPECT_EQ(output, expected); } +TEST_F(LiteralUtilTest, PopulateR1C128) { + Literal output(ShapeUtil::MakeShape(C128, {1})); + output.PopulateR1({{77, 88}}); + auto expected = LiteralUtil::CreateR1({{77, 88}}); + EXPECT_EQ(output, expected); +} + TEST_F(LiteralUtilTest, PopulateR2C64) { Literal output(ShapeUtil::MakeShape(C64, {2, 2})); output.PopulateR2({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}}); @@ -897,6 +922,14 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2C64) { EXPECT_EQ(output, expected); } +TEST_F(LiteralUtilTest, PopulateWithValueR2C128) { + Literal output(ShapeUtil::MakeShape(C128, {2, 2})); + output.PopulateWithValue({4, 2}); + auto expected = + LiteralUtil::CreateR2({{{4, 2}, {4, 2}}, {{4, 2}, {4, 2}}}); + EXPECT_EQ(output, expected); +} + TEST_F(LiteralUtilTest, PopulateWithValueR0F16) { Literal output(ShapeUtil::MakeShape(F16, {})); half h(0.25f); @@ -1237,11 +1270,21 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) { {{0, 19, 0, 21}, {22, 0, 24, 0}}, {{26, 0, 28, 0}, {0, 31, 0, 33}}, }}, layout_r4_dim0major_); + auto s16 = LiteralUtil::CreateR4WithLayout({{ + {{10, 0, 12, 0}, {0, 15, 0, 17}}, + {{0, 19, 0, 21}, {22, 0, 24, 0}}, + {{26, 0, 28, 0}, {0, 31, 0, 33}}, + }}, 
layout_r4_dim0major_); auto s32 = LiteralUtil::CreateR4WithLayout({{ {{10, 0, 12, 0}, {0, 15, 0, 17}}, {{0, 19, 0, 21}, {22, 0, 24, 0}}, {{26, 0, 28, 0}, {0, 31, 0, 33}}, }}, layout_r4_dim0major_); + auto u16 = LiteralUtil::CreateR4WithLayout({{ + {{10, 0, 12, 0}, {0, 15, 0, 17}}, + {{0, 19, 0, 21}, {22, 0, 24, 0}}, + {{26, 0, 28, 0}, {0, 31, 0, 33}}, + }}, layout_r4_dim0major_); auto u32 = LiteralUtil::CreateR4WithLayout({{ {{10, 0, 12, 0}, {0, 15, 0, 17}}, {{0, 19, 0, 21}, {22, 0, 24, 0}}, @@ -1298,9 +1341,19 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) { {{0.0f, 19.0f, 0.0f, 21.0f}, {22.0f, 0.0f, 24.0f, 0.0f}}, {{26.0f, 0.0f, 28.0f, 0.0f}, {0.0f, 31.0f, 0.0f, 33.0f}}, }}, layout_r4_dim0major_); - // clang-format on + auto c128 = LiteralUtil::CreateR4WithLayout({{ + {{10.0, 0.0, 12.0, 0.0}, {0.0, 15.0, 0.0, 17.0}}, + {{0.0, 19.0, 0.0, 21.0}, {22.0, 0.0, 24.0, 0.0}}, + {{26.0, 0.0, 28.0, 0.0}, {0.0, 31.0, 0.0, 33.0}}, + }}, layout_r4_dim0major_); // clang-format on Literal conv; + conv = s8.Convert(U16).ConsumeValueOrDie(); + EXPECT_EQ(conv, u16); + + conv = s8.Convert(S16).ConsumeValueOrDie(); + EXPECT_EQ(conv, s16); + conv = s8.Convert(U32).ConsumeValueOrDie(); EXPECT_EQ(conv, u32); @@ -1352,12 +1405,26 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) { conv = f16.Convert(C64).ConsumeValueOrDie(); EXPECT_EQ(conv, c64); + conv = s32.Convert(S16).ConsumeValueOrDie(); + EXPECT_EQ(conv, s16); + + conv = s32.Convert(U16).ConsumeValueOrDie(); + EXPECT_EQ(conv, u16); + + conv = s32.Convert(C128).ConsumeValueOrDie(); + EXPECT_EQ(conv, c128); + + conv = f16.Convert(C128).ConsumeValueOrDie(); + EXPECT_EQ(conv, c128); + EXPECT_EQ(s32.Convert(TUPLE).status().code(), tensorflow::error::UNIMPLEMENTED); - EXPECT_EQ(s32.Convert(S16).status().code(), tensorflow::error::UNIMPLEMENTED); - EXPECT_EQ(s32.Convert(U16).status().code(), tensorflow::error::UNIMPLEMENTED); EXPECT_EQ(c64.Convert(F32).status().code(), tensorflow::error::UNIMPLEMENTED); 
EXPECT_EQ(c64.Convert(S32).status().code(), tensorflow::error::UNIMPLEMENTED); + EXPECT_EQ(c128.Convert(F32).status().code(), + tensorflow::error::UNIMPLEMENTED); + EXPECT_EQ(c128.Convert(S32).status().code(), + tensorflow::error::UNIMPLEMENTED); } TEST_F(LiteralUtilTest, BitcastConvert) { @@ -1642,7 +1709,7 @@ TEST_F(LiteralUtilTest, MoveIntoTuple) { LiteralUtil::MakeTuple({&inner_elements[0], &inner_elements[1]})); Literal literal = Literal::MoveIntoTuple(absl::MakeSpan(elements)); - ASSERT_TRUE(ShapeUtil::IsTuple(literal.shape())); + ASSERT_TRUE(literal.shape().IsTuple()); ASSERT_EQ(ShapeUtil::TupleElementCount(literal.shape()), 3); EXPECT_EQ(literal.Get({}, /*shape_index=*/{0}), 1.0); @@ -1659,7 +1726,7 @@ TEST_F(LiteralUtilTest, MoveIntoTuple) { TEST_F(LiteralUtilTest, MoveIntoEmptyTuple) { Literal literal = Literal::MoveIntoTuple({}); - ASSERT_TRUE(ShapeUtil::IsTuple(literal.shape())); + ASSERT_TRUE(literal.shape().IsTuple()); EXPECT_EQ(ShapeUtil::TupleElementCount(literal.shape()), 0); } @@ -1719,7 +1786,8 @@ TEST_F(LiteralUtilTest, CreateFromShapeZeroInitialized) { Literal tuple = Literal::CreateFromShape(ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShape(F64, {}), ShapeUtil::MakeShape(PRED, {2}), - ShapeUtil::MakeShape(U64, {2, 1}), ShapeUtil::MakeShape(C64, {})})); + ShapeUtil::MakeShape(U64, {2, 1}), ShapeUtil::MakeShape(C64, {}), + ShapeUtil::MakeShape(C128, {})})); EXPECT_EQ(tuple.Get({}, {0}), 0.0); EXPECT_EQ(tuple.Get({0}, {1}), false); @@ -1727,6 +1795,7 @@ TEST_F(LiteralUtilTest, CreateFromShapeZeroInitialized) { EXPECT_EQ(tuple.Get({0, 0}, {2}), 0); EXPECT_EQ(tuple.Get({1, 0}, {2}), 0); EXPECT_EQ(tuple.Get({}, {3}), complex64(0.0f, 0.0f)); + EXPECT_EQ(tuple.Get({}, {4}), complex128(0.0, 0.0)); } TEST_F(LiteralUtilTest, ProtoRoundTrip) { @@ -1736,6 +1805,8 @@ TEST_F(LiteralUtilTest, ProtoRoundTrip) { auto vector_int8 = LiteralUtil::CreateR1({-128, 0, 2, 4, 7, 56, 127}); auto vector_uint8 = LiteralUtil::CreateR1({128, 0, 2, 56, 127, 255}); auto 
vector_c64 = LiteralUtil::CreateR1({{1.0, 2.0}, {3.0, 4.0}}); + auto vector_c128 = + LiteralUtil::CreateR1({{1.0, 2.0}, {3.0, 4.0}}); auto vector_bfloat16 = LiteralUtil::CreateR1( {bfloat16{-1.0}, bfloat16{2.0}, bfloat16{-3.0}}); auto vector_half = @@ -1756,6 +1827,7 @@ TEST_F(LiteralUtilTest, ProtoRoundTrip) { EXPECT_EQ(vector_int8, to_from_proto(vector_int8)); EXPECT_EQ(vector_uint8, to_from_proto(vector_uint8)); EXPECT_EQ(vector_c64, to_from_proto(vector_c64)); + EXPECT_EQ(vector_c128, to_from_proto(vector_c128)); EXPECT_EQ(vector_bfloat16, to_from_proto(vector_bfloat16)); EXPECT_EQ(matrix_pred, to_from_proto(matrix_pred)); EXPECT_EQ(tuple, to_from_proto(tuple)); @@ -1890,7 +1962,7 @@ TEST_F(LiteralUtilTest, SortSparseElements) { literal.AppendSparseElement({3, 4, 5}, 3.0); literal.AppendSparseElement({1, 2, 3}, 1.0); literal.SortSparseElements(); - EXPECT_EQ(literal.ToString(false), + EXPECT_EQ(literal.ToString(), "f32[10,10,10]{[1, 2, 3]: 1, [2, 3, 4]: 2, [3, 4, 5]: 3}"); } diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index bb5e5e61000d0aca6ab052ac87d2fbcd96e55f70..26b029c8d0c52e38510f9279def7c4af2904931d 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -62,7 +62,7 @@ Literal ConvertType(LiteralSlice literal) { ShapeUtil::ForEachSubshape( literal.shape(), [&](const Shape& subshape, const ShapeIndex& shape_index) { - if (ShapeUtil::IsArray(subshape)) { + if (subshape.IsArray()) { if (subshape.element_type() == primitive_util::NativeToPrimitiveType()) { auto src = literal.data(shape_index); @@ -106,12 +106,16 @@ Literal ConvertType(LiteralSlice literal) { switch (primitive_type) { case U8: return LiteralUtil::CreateR0(0); + case U16: + return LiteralUtil::CreateR0(0); case U32: return LiteralUtil::CreateR0(0); case U64: return LiteralUtil::CreateR0(0); case S8: return LiteralUtil::CreateR0(0); + case S16: + return LiteralUtil::CreateR0(0); case S32: return 
LiteralUtil::CreateR0(0); case S64: @@ -126,11 +130,10 @@ Literal ConvertType(LiteralSlice literal) { return LiteralUtil::CreateR0(0); case C64: return LiteralUtil::CreateR0(0); + case C128: + return LiteralUtil::CreateR0(0); case PRED: return LiteralUtil::CreateR0(false); - case S16: - case U16: - LOG(FATAL) << "u16/s16 literals not yet implemented"; case TUPLE: LOG(FATAL) << "tuple element type cannot take on value of 0"; case OPAQUE: @@ -164,6 +167,8 @@ Literal ConvertType(LiteralSlice literal) { return LiteralUtil::CreateR0(1); case C64: return LiteralUtil::CreateR0(1); + case C128: + return LiteralUtil::CreateR0(1); case PRED: return LiteralUtil::CreateR0(true); case S16: @@ -200,6 +205,8 @@ Literal ConvertType(LiteralSlice literal) { -std::numeric_limits::infinity()); case C64: LOG(FATAL) << "C64 element type has no minimum value"; + case C128: + LOG(FATAL) << "C128 element type has no minimum value"; case PRED: return LiteralUtil::CreateR0(false); case S16: @@ -344,6 +351,10 @@ Literal ConvertType(LiteralSlice literal) { new_literal.Set(to_multi_index, literal.Get(from_multi_index)); break; + case C128: + new_literal.Set(to_multi_index, + literal.Get(from_multi_index)); + break; default: LOG(FATAL) << "Unhandled primitive element type: " << PrimitiveType_Name(literal.shape().element_type()); @@ -355,7 +366,7 @@ Literal ConvertType(LiteralSlice literal) { /* static */ Literal LiteralUtil::GetFirstScalarLiteral( const LiteralSlice& literal) { - CHECK(ShapeUtil::IsArray(literal.shape())); + CHECK(literal.shape().IsArray()); CHECK_GT(ShapeUtil::ElementsIn(literal.shape()), 0); switch (literal.shape().element_type()) { case PRED: @@ -392,6 +403,10 @@ Literal ConvertType(LiteralSlice literal) { return LiteralUtil::CreateR0(literal.GetFirstElement()); case U64: return LiteralUtil::CreateR0(literal.GetFirstElement()); + + case C128: + return LiteralUtil::CreateR0( + literal.GetFirstElement()); default: LOG(FATAL) << "Unhandled primitive type " << 
literal.shape().element_type(); diff --git a/tensorflow/compiler/xla/metric_table_report.cc b/tensorflow/compiler/xla/metric_table_report.cc index 4eab4fa4290c270697c00be20840cf4e85459183..bad65ac32018fafcc7634b989f1b4b0867aa5c0d 100644 --- a/tensorflow/compiler/xla/metric_table_report.cc +++ b/tensorflow/compiler/xla/metric_table_report.cc @@ -15,9 +15,9 @@ limitations under the License. #include "tensorflow/compiler/xla/metric_table_report.h" -#include #include +#include "absl/strings/ascii.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "tensorflow/core/platform/logging.h" @@ -55,7 +55,7 @@ string MetricTableReport::MakeReport(double expected_metric_sum) { const auto metric_greater = [](const Entry& a, const Entry& b) { return a.metric > b.metric; }; - std::sort(entries_.begin(), entries_.end(), metric_greater); + absl::c_sort(entries_, metric_greater); // Create the report AppendLine(); @@ -117,7 +117,7 @@ std::vector MetricTableReport::MakeCategories( auto metric_sum_greater = [](const Category& a, const Category& b) { return a.metric_sum > b.metric_sum; }; - std::sort(categories.begin(), categories.end(), metric_sum_greater); + absl::c_sort(categories, metric_sum_greater); return categories; } @@ -249,7 +249,7 @@ string MetricTableReport::MetricString(double metric) { string output; // Copy leading non-digit characters unconditionally. // This picks up the leading sign. 
- while (!sp1.empty() && !isdigit(sp1[0])) { + while (!sp1.empty() && !absl::ascii_isdigit(sp1[0])) { output.push_back(sp1[0]); sp1.remove_prefix(1); } diff --git a/tensorflow/compiler/xla/packed_literal_reader.cc b/tensorflow/compiler/xla/packed_literal_reader.cc index 0f86f9f35e105713aa3072a9ebf572d33d35d66d..339660cf44fd64fc5859e72255d63762fcf20efe 100644 --- a/tensorflow/compiler/xla/packed_literal_reader.cc +++ b/tensorflow/compiler/xla/packed_literal_reader.cc @@ -42,8 +42,7 @@ PackedLiteralReader::~PackedLiteralReader() { delete file_; } StatusOr PackedLiteralReader::Read(const Shape& shape, const Layout* layout) { VLOG(3) << "reading shape from file: " << ShapeUtil::HumanString(shape) - << " layout: " - << (layout == nullptr ? "" : layout->ShortDebugString()); + << " layout: " << (layout == nullptr ? "" : layout->ToString()); Shape literal_shape = shape; if (layout != nullptr) { TF_RETURN_IF_ERROR( diff --git a/tensorflow/compiler/xla/parse_flags_from_env.cc b/tensorflow/compiler/xla/parse_flags_from_env.cc index 5b568888d14f21c1330556d017eafba6c8dd2228..91e71f5d1d02d135158d0dffc140c21cf8ea5e3a 100644 --- a/tensorflow/compiler/xla/parse_flags_from_env.cc +++ b/tensorflow/compiler/xla/parse_flags_from_env.cc @@ -24,6 +24,7 @@ limitations under the License. #include #include +#include "absl/strings/ascii.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/types/span.h" @@ -37,7 +38,7 @@ limitations under the License. namespace xla { -static const char kWS[] = " \t\r\n"; // whitespace +static const char kWS[] = " \t\r\n"; // whitespace // The following struct represents an argv[]-style array, parsed // from data gleaned from the environment. @@ -104,7 +105,8 @@ static void ParseArgvFromString(const string& flag_str, EnvArgv* a) { // Set e to the index just past the end of the flag. 
size_t e = b; while (e != flag_str.size() && isascii(flag_str[e]) && - (strchr("-_", flag_str[e]) != nullptr || isalnum(flag_str[e]))) { + (strchr("-_", flag_str[e]) != nullptr || + absl::ascii_isalnum(flag_str[e]))) { e++; } if (e != flag_str.size() && flag_str[e] == '=' && diff --git a/tensorflow/compiler/xla/primitive_util.cc b/tensorflow/compiler/xla/primitive_util.cc index b16147e3be71771269d8b7a18528bef3a8c72d99..1eedddf72c1d393cb1b88e589881e24de02ad802 100644 --- a/tensorflow/compiler/xla/primitive_util.cc +++ b/tensorflow/compiler/xla/primitive_util.cc @@ -15,16 +15,35 @@ limitations under the License. #include "tensorflow/compiler/xla/primitive_util.h" +#include "absl/strings/ascii.h" +#include "absl/strings/numbers.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/logging.h" namespace xla { namespace primitive_util { +int SignificandWidth(PrimitiveType type) { + switch (type) { + case F32: + return std::numeric_limits::digits; + case F64: + return std::numeric_limits::digits; + case BF16: + return kBFloat16MantissaBits + 1; + case F16: + return 11; + default: + LOG(FATAL) << "Not a floating data type " << type; + } +} + bool IsFloatingPointType(PrimitiveType type) { return type == F16 || type == F32 || type == F64 || type == BF16; } -bool IsComplexType(PrimitiveType type) { return type == C64; } +bool IsComplexType(PrimitiveType type) { return type == C64 || type == C128; } bool IsSignedIntegralType(PrimitiveType type) { return type == S8 || type == S16 || type == S32 || type == S64; @@ -64,6 +83,9 @@ int BitWidth(PrimitiveType type) { case C64: return 64; + case C128: + return 128; + case TUPLE: LOG(FATAL) << "TUPLE is an invalid type for BitWidth"; @@ -75,10 +97,27 @@ int BitWidth(PrimitiveType type) { } } +xla::PrimitiveType UnsignedIntegralTypeForBitWidth(int64 src_bitwidth) { + switch (src_bitwidth) { + case 8: + return xla::U8; + case 16: + return xla::U16; + case 32: + 
return xla::U32; + case 64: + return xla::U64; + default: + return xla::PRIMITIVE_TYPE_INVALID; + } +} + PrimitiveType ComplexComponentType(PrimitiveType complex_type) { switch (complex_type) { case C64: return F32; + case C128: + return F64; default: LOG(FATAL) << "Primitive type is not complex: " << PrimitiveType_Name(complex_type); @@ -90,5 +129,65 @@ bool IsArrayType(PrimitiveType primitive_type) { primitive_type != OPAQUE && primitive_type != TOKEN; } +// Class to memoize the computation of +// absl::AsciiStrToLower(PrimitiveType_Name(p)) +// for all PrimitiveType values "p" +class PrimitiveTypeNameGenerator { + public: + PrimitiveTypeNameGenerator() { + for (int i = 0; i < PrimitiveType_ARRAYSIZE; i++) { + if (PrimitiveType_IsValid(i)) { + lowercase_name_[i] = absl::AsciiStrToLower( + PrimitiveType_Name(static_cast(i))); + } + } + } + const string& LowercaseName(PrimitiveType t) { + return lowercase_name_[static_cast(t)]; + } + + private: + string lowercase_name_[PrimitiveType_ARRAYSIZE]; +}; + +const string& LowercasePrimitiveTypeName(PrimitiveType s) { + static auto* gen = new PrimitiveTypeNameGenerator(); + return gen->LowercaseName(s); +} + +namespace { + +// Returns a map from lower-case primitive type name to primitive type. 
+const std::unordered_map& GetPrimitiveTypeStringMap() { + static std::unordered_map* name_to_type = [] { + static auto* map = new std::unordered_map; + for (int i = 0; i < PrimitiveType_ARRAYSIZE; i++) { + if (PrimitiveType_IsValid(i) && i != PRIMITIVE_TYPE_INVALID) { + auto value = static_cast(i); + (*map)[LowercasePrimitiveTypeName(value)] = value; + } + } + return map; + }(); + return *name_to_type; +} + +} // namespace + +StatusOr StringToPrimitiveType(absl::string_view name) { + const auto& map = GetPrimitiveTypeStringMap(); + auto found = map.find(string(name)); + if (found == map.end()) { + return InvalidArgument("Invalid element type string: \"%s\".", name); + } + return found->second; +} + +bool IsPrimitiveTypeName(absl::string_view name) { + const auto& map = GetPrimitiveTypeStringMap(); + auto found = map.find(string(name)); + return found != map.end(); +} + } // namespace primitive_util } // namespace xla diff --git a/tensorflow/compiler/xla/primitive_util.h b/tensorflow/compiler/xla/primitive_util.h index 889e9a1ceca675689406d255d348c82c398563aa..295d353003276b4c1731f7d6a378fd1ae0288d3c 100644 --- a/tensorflow/compiler/xla/primitive_util.h +++ b/tensorflow/compiler/xla/primitive_util.h @@ -20,12 +20,19 @@ limitations under the License. #include +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { namespace primitive_util { +// Returns the count of significand (mantissa) bits for float datatypes. +// For non-float datatypes, results in a LOG(FATAL). +int SignificandWidth(PrimitiveType type); + // The number of exponent bits in a BF16 value. 
const int kBFloat16ExponentBits = 8; @@ -123,6 +130,11 @@ inline PrimitiveType NativeToPrimitiveType() { return C64; } +template <> +inline PrimitiveType NativeToPrimitiveType() { + return C128; +} + bool IsFloatingPointType(PrimitiveType type); bool IsComplexType(PrimitiveType type); @@ -139,6 +151,8 @@ bool IsArrayType(PrimitiveType primitive_type); // Returns the number of bits in the representation for a given type. int BitWidth(PrimitiveType type); +PrimitiveType UnsignedIntegralTypeForBitWidth(int64 src_bitwidth); + // Returns the real, imag component type underlying the given complex type. // LOG(FATAL)'s if complex_type is not complex. PrimitiveType ComplexComponentType(PrimitiveType complex_type); @@ -221,6 +235,22 @@ template <> struct PrimitiveTypeToNative { using type = complex64; }; + +template <> +struct PrimitiveTypeToNative { + using type = complex128; +}; + +// Returns the lower-case name of the given primitive type. +const string& LowercasePrimitiveTypeName(PrimitiveType s); + +// Returns the PrimitiveType matching the given name. The given name is expected +// to be lower-case. +StatusOr StringToPrimitiveType(absl::string_view name); + +// Returns true if the given name is a primitive type string (lower-case). +bool IsPrimitiveTypeName(absl::string_view name); + } // namespace primitive_util } // namespace xla diff --git a/tensorflow/compiler/xla/primitive_util_test.cc b/tensorflow/compiler/xla/primitive_util_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..1f765d6da9ef65849fe8ede56ced7597d623cb59 --- /dev/null +++ b/tensorflow/compiler/xla/primitive_util_test.cc @@ -0,0 +1,46 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/primitive_util.h" + +#include +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace { + +TEST(PrimitiveUtilTest, StringToPrimitiveType) { + auto expect_ok_and_equal = [](const string& str, PrimitiveType expected) { + TF_ASSERT_OK_AND_ASSIGN(PrimitiveType actual, + primitive_util::StringToPrimitiveType(str)); + EXPECT_EQ(expected, actual); + }; + expect_ok_and_equal("f32", F32); + expect_ok_and_equal("tuple", TUPLE); + expect_ok_and_equal("pred", PRED); + expect_ok_and_equal("s32", S32); + + EXPECT_IS_NOT_OK(primitive_util::StringToPrimitiveType("F32").status()); + EXPECT_IS_NOT_OK(primitive_util::StringToPrimitiveType("Pred").status()); + EXPECT_IS_NOT_OK(primitive_util::StringToPrimitiveType("preD").status()); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index 63ac1c6649210cbae9e238a74e0a45fb8ee4da63..4afb21d5c8864c2974114af2de08df4106a13a8c 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -3,7 +3,8 @@ licenses(["notice"]) # Apache 2.0 package(default_visibility = ["//tensorflow:internal"]) load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc") 
-load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured") +load("//tensorflow/core:platform/default/build_config.bzl", "pyx_library") +load("//tensorflow/compiler/xla:xla.bzl", "xla_python_default_plugins") py_library( name = "xla_client", @@ -17,6 +18,12 @@ py_library( ], ) +pyx_library( + name = "custom_call_for_test", + testonly = True, + srcs = ["custom_call_for_test.pyx"], +) + py_test( name = "xla_client_test", srcs = ["xla_client_test.py"], @@ -24,6 +31,7 @@ py_test( srcs_version = "PY2AND3", tags = ["no_oss"], deps = [ + ":custom_call_for_test", ":xla_client", "//tensorflow/python:platform_test", ], @@ -66,13 +74,18 @@ cc_library( "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/client/lib:cholesky", "//tensorflow/compiler/xla/client/lib:math", + "//tensorflow/compiler/xla/client/lib:qr", + "//tensorflow/compiler/xla/client/lib:triangular_solve", "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/service:shaped_buffer", + "//tensorflow/compiler/xla/service/cpu:custom_call_target_registry", "//tensorflow/compiler/xrt:xrt_proto", "//tensorflow/compiler/xrt/cc:xrt_ops", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//third_party/python_runtime:headers", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:span", ], @@ -85,6 +98,11 @@ tf_py_wrap_cc( "local_computation_builder.i", "//tensorflow/python:platform/base.i", ], + version_script = select({ + "//tensorflow:darwin": "pywrap_xla_exported_symbols.lds", + "//tensorflow:windows": None, + "//conditions:default": "pywrap_xla_version_script.lds", + }), deps = [ ":local_computation_builder", ":numpy_bridge", @@ -92,7 +110,5 @@ tf_py_wrap_cc( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:cpu_plugin", - ] + if_cuda_is_configured([ - 
"//tensorflow/compiler/xla/service:gpu_plugin", - ]), + ] + xla_python_default_plugins(), ) diff --git a/tensorflow/compiler/xla/python/custom_call_for_test.pyx b/tensorflow/compiler/xla/python/custom_call_for_test.pyx new file mode 100644 index 0000000000000000000000000000000000000000..530dffd1755d8438f52569c223525000c97df6ea --- /dev/null +++ b/tensorflow/compiler/xla/python/custom_call_for_test.pyx @@ -0,0 +1,21 @@ +# distutils: language = c++ + +# Test case for defining a XLA custom call target in Cython, and registering +# it via the xla_client SWIG API. + +from cpython.pycapsule cimport PyCapsule_New + +cdef void test_subtract_f32(void* out_ptr, void** data_ptr) nogil: + cdef float a = ((data_ptr[0]))[0] + cdef float b = ((data_ptr[1]))[0] + cdef float* out = (out_ptr) + out[0] = a - b + + +cpu_custom_call_targets = {} + +cdef register_custom_call_target(fn_name, void* fn): + cdef const char* name = "xla._CPU_CUSTOM_CALL_TARGET" + cpu_custom_call_targets[fn_name] = PyCapsule_New(fn, name, NULL) + +register_custom_call_target(b"test_subtract_f32", (test_subtract_f32)) diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc index c0b57e7d26581662476fb64ddaedafe4d55d8619..ce4bd6f681b80a0c52579f62e3422be81d06076f 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.cc +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -24,12 +24,16 @@ limitations under the License. 
#include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/scope.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/xla/client/lib/cholesky.h" #include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/lib/qr.h" +#include "tensorflow/compiler/xla/client/lib/triangular_solve.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" @@ -98,7 +102,7 @@ int GetReplicaCount() { return g_replica_count; } -LocalClient* GetOrCreateLocalClient() { +StatusOr GetOrCreateLocalClient() { string* platform_name = GetPlatformNameString(); tensorflow::mutex_lock lock(g_local_client_mutex); if (g_local_client != nullptr) { @@ -107,15 +111,30 @@ LocalClient* GetOrCreateLocalClient() { LocalClientOptions options; options.set_platform(PlatformUtil::GetPlatform(*platform_name).ValueOrDie()); options.set_number_of_replicas(g_replica_count); - g_local_client = ClientLibrary::GetOrCreateLocalClient(options).ValueOrDie(); + TF_ASSIGN_OR_RETURN(g_local_client, + ClientLibrary::GetOrCreateLocalClient(options)); CHECK(g_local_client != nullptr); return g_local_client; } +Status RegisterCpuCustomCallTarget(const string& fn_name, PyObject* capsule) { + const char* name = "xla._CPU_CUSTOM_CALL_TARGET"; + if (!PyCapsule_IsValid(capsule, name)) { + return InvalidArgument( + "Argument to RegisterCpuCustomCallTargetRegistry was not a " + "xla._CPU_CUSTOM_CALL_TARGET capsule."); + } + void* fn_ptr = PyCapsule_GetPointer(capsule, name); + CHECK(fn_ptr != nullptr); + 
cpu::CustomCallTargetRegistry::Global()->Register( + std::string(fn_name.begin(), fn_name.end()), fn_ptr); + return Status::OK(); +} + Status TransferToInfeedLocal(const Literal& literal) { VLOG(1) << "Infeeding literal without replica number; shape: " << literal.shape(); - LocalClient* client = GetOrCreateLocalClient(); + TF_ASSIGN_OR_RETURN(LocalClient * client, GetOrCreateLocalClient()); return client->TransferToInfeedLocal(literal, /*device_ordinal=*/0); } @@ -123,7 +142,7 @@ Status TransferToInfeedLocalReplica(const Literal& literal, int replica_number) { VLOG(1) << "Infeeding shape " << literal.shape() << " to replica number: " << replica_number; - LocalClient* client = GetOrCreateLocalClient(); + TF_ASSIGN_OR_RETURN(LocalClient * client, GetOrCreateLocalClient()); TF_ASSIGN_OR_RETURN(int device_ordinal, client->ReplicaNumberToDeviceOrdinal(replica_number)); return client->TransferToInfeedLocal(literal, device_ordinal); @@ -133,7 +152,7 @@ StatusOr TransferFromOutfeedLocalReplica(const Shape& shape, int replica_number) { VLOG(1) << "Outfeeding literal from replica number: " << replica_number << " shape: " << shape; - LocalClient* client = GetOrCreateLocalClient(); + TF_ASSIGN_OR_RETURN(LocalClient * client, GetOrCreateLocalClient()); TF_ASSIGN_OR_RETURN(int device_ordinal, client->ReplicaNumberToDeviceOrdinal(replica_number)); return client->TransferFromOutfeedLocal(shape, device_ordinal); @@ -148,14 +167,19 @@ static StatusOr ToBuffer(LocalClient* client, /* static */ StatusOr LocalShapedBuffer::FromLiteral( - const Literal& argument, const absl::optional& shape_with_layout) { - LocalClient* client = GetOrCreateLocalClient(); + const Literal& argument, const absl::optional& shape_with_layout, + int replica_number) { + TF_ASSIGN_OR_RETURN(LocalClient * client, GetOrCreateLocalClient()); + TF_ASSIGN_OR_RETURN(int device_ordinal, + client->ReplicaNumberToDeviceOrdinal(replica_number)); + VLOG(1) << "Creating shaped buffer from literal on replica/ordinal: " + << 
replica_number << "/" << device_ordinal; StatusOr buf = [&] { if (shape_with_layout) { Literal relaid = argument.Relayout(shape_with_layout.value()); - return ToBuffer(client, /*device_ordinal=*/0, relaid); + return ToBuffer(client, device_ordinal, relaid); } - return ToBuffer(client, /*device_ordinal=*/0, argument); + return ToBuffer(client, device_ordinal, argument); }(); TF_RETURN_IF_ERROR(buf.status()); return new LocalShapedBuffer(std::move(buf).ValueOrDie()); @@ -175,7 +199,7 @@ const Shape& LocalShapedBuffer::shape() const { } StatusOr LocalShapedBuffer::ToLiteral() const { - LocalClient* client = GetOrCreateLocalClient(); + TF_ASSIGN_OR_RETURN(LocalClient * client, GetOrCreateLocalClient()); return client->ShapedBufferToLiteral(*shaped_buffer()); } @@ -237,7 +261,6 @@ XrtAllocation::~XrtAllocation() { StatusOr XrtAllocation::FromLiteral( const Literal& argument, const string& session_target) { xrt::XLAAllocation alloc; - alloc.set_device_ordinal(0); *alloc.mutable_value() = argument.ToProto(); tensorflow::Scope root = tensorflow::Scope::NewRootScope(); @@ -311,67 +334,120 @@ CompiledLocalComputation::CompiledLocalComputation( StatusOr CompiledLocalComputation::Execute( absl::Span argument_handles) { - LocalClient* client = GetOrCreateLocalClient(); + if (num_replicas() != 1) { + return InvalidArgument( + "Attempted to execute computation with %d replicas using Execute()", + num_replicas()); + } + TF_ASSIGN_OR_RETURN(LocalClient * client, GetOrCreateLocalClient()); + TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment, + client->backend().computation_placer()->AssignDevices( + 1, /*computation_count=*/1)); + StatusOr result_buffer_status; + const int device_ordinal = device_assignment(0, 0); + VLOG(3) << "Replica 0 mapped to device ordinal for execution: " + << device_ordinal; + + std::vector argument_buffers; + argument_buffers.reserve(argument_handles.size()); + for (auto& handle : argument_handles) { + 
argument_buffers.push_back(handle->shaped_buffer()); + } + + ExecutableRunOptions options; + options.set_device_ordinal(device_ordinal); + options.set_allocator(client->backend().memory_allocator()); + options.set_intra_op_thread_pool( + client->backend().eigen_intra_op_thread_pool_device()); + options.set_device_assignment(&device_assignment); + + result_buffer_status = executable_->Run(argument_buffers, options); + + if (!result_buffer_status.ok()) { + return InternalError( + "Failed running replica 0 (other replicas may have failed as well): " + "%s.", + result_buffer_status.status().ToString()); + } + return new LocalShapedBuffer(std::move(result_buffer_status).ValueOrDie()); +} - VLOG(1) << "Execution requested with " << GetReplicaCount() << " replicas."; +StatusOr CompiledLocalComputation::ExecutePerReplica( + absl::Span> argument_handles) { + TF_ASSIGN_OR_RETURN(LocalClient * client, GetOrCreateLocalClient()); + const int num_devices = client->device_count(); + + if (argument_handles.size() != num_replicas()) { + return InvalidArgument( + "Attempted to execute with %d replicas when replica count is %d", + argument_handles.size(), num_replicas()); + } + if (argument_handles.size() > num_devices) { + return InvalidArgument( + "Attempted to execute with %d replicas when device count is %d", + argument_handles.size(), num_devices); + } - // Each replica populates a StatusOr result, but only the output value of - // replica zero is returned. 
- std::vector> results(GetReplicaCount()); - { + VLOG(1) << "Executing with " << num_replicas() << " replicas."; + + TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment, + client->backend().computation_placer()->AssignDevices( + num_replicas(), /*computation_count=*/1)); + + std::vector> results(num_replicas()); + auto execute = [this, client, &device_assignment, &argument_handles, + &results](int replica) { + const int device_ordinal = device_assignment(replica, 0); + VLOG(3) << "Replica " << replica + << " mapped to device ordinal for execution: " << device_ordinal; + + std::vector argument_buffers; + argument_buffers.reserve(argument_handles[replica].size()); + for (auto& handle : argument_handles[replica]) { + argument_buffers.push_back(handle->shaped_buffer()); + } + + ExecutableRunOptions options; + options.set_device_ordinal(device_ordinal); + options.set_allocator(client->backend().memory_allocator()); + options.set_intra_op_thread_pool( + client->backend().eigen_intra_op_thread_pool_device()); + options.set_device_assignment(&device_assignment); + StatusOr result_buffer_status = + executable_->Run(argument_buffers, options); + + results[replica] = std::move(result_buffer_status); + }; + + if (num_replicas() == 1) { + // Fast-path if there is only one replica — run the computation on the + // current thread. + execute(0); + } else { + // TODO(phawkins): don't recreate the threadpool for each execution. 
tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "xlarun", - GetReplicaCount()); - - for (int replica = 0; replica < GetReplicaCount(); ++replica) { - pool.Schedule([this, client, replica, &argument_handles, &results] { - StatusOr device_ordinal_status = - client->ReplicaNumberToDeviceOrdinal(replica); - if (!device_ordinal_status.ok()) { - results[replica] = device_ordinal_status.status(); - return; - } - const int device_ordinal = device_ordinal_status.ValueOrDie(); - VLOG(3) << "Replica " << replica - << " mapped to device ordinal for execution: " - << device_ordinal; - - std::vector argument_buffers; - argument_buffers.reserve(argument_handles.size()); - for (auto& handle : argument_handles) { - argument_buffers.push_back(handle->shaped_buffer()); - } - - DeviceAssignment device_assignment = - client->backend() - .computation_placer() - ->AssignDevices(GetReplicaCount(), /*computation_count=*/1) - .ConsumeValueOrDie(); - - ExecutableRunOptions options; - options.set_device_ordinal(device_ordinal); - options.set_allocator(client->backend().memory_allocator()); - options.set_intra_op_thread_pool( - client->backend().eigen_intra_op_thread_pool_device()); - options.set_device_assignment(&device_assignment); - StatusOr result_buffer_status = - executable_->Run(argument_buffers, options); - - results[replica] = std::move(result_buffer_status); - }); + num_replicas() - 1); + + for (int replica = 0; replica < num_replicas() - 1; ++replica) { + pool.Schedule([&execute, replica] { execute(replica); }); } + execute(num_replicas() - 1); } - for (int replica = 0; replica < GetReplicaCount(); ++replica) { - const auto& statusor = results[replica]; + std::vector wrapped_results(num_replicas()); + for (int replica = 0; replica < num_replicas(); ++replica) { + auto& statusor = results[replica]; if (!statusor.ok()) { return InternalError( "Failed running replica %d (other replicas may have failed as well): " "%s.", replica, statusor.status().ToString()); } + 
wrapped_results[replica] = + new LocalShapedBuffer(std::move(statusor).ValueOrDie()); } - return new LocalShapedBuffer(std::move(results[0]).ValueOrDie()); + return new LocalShapedBufferTuple(std::move(wrapped_results)); } static StatusOr GetReturnValueShape(const XlaComputation& computation) { @@ -466,7 +542,7 @@ StatusOr LocalComputation::Compile( argument_shape_pointers.push_back(&argument_shape); } - LocalClient* client = GetOrCreateLocalClient(); + TF_ASSIGN_OR_RETURN(LocalClient * client, GetOrCreateLocalClient()); ExecutableBuildOptions options; if (build_options != nullptr) { options = *build_options; @@ -578,6 +654,15 @@ LocalOp LocalComputationBuilder::ConstantLiteral(const Literal& literal) { return xla::ConstantLiteral(&builder_, literal); } +LocalOp LocalComputationBuilder::Iota(PrimitiveType element_type, int64 size) { + return xla::Iota(&builder_, element_type, size); +} + +LocalOp LocalComputationBuilder::BroadcastedIota(const Shape& shape, + int64 dimension) { + return xla::Iota(&builder_, shape, dimension); +} + LocalOp LocalComputationBuilder::Broadcast( const LocalOp& operand, absl::Span broadcast_sizes) { return xla::Broadcast(operand.op(), broadcast_sizes); @@ -606,8 +691,20 @@ LocalOp LocalComputationBuilder::Collapse(const LocalOp& operand, return xla::Collapse(operand.op(), dimensions); } -LocalOp LocalComputationBuilder::CrossReplicaSum(const LocalOp& operand) { - return xla::CrossReplicaSum(operand.op()); +LocalOp LocalComputationBuilder::AllToAll( + const LocalOp& operand, int64 split_dimension, int64 concat_dimension, + int64 split_count, absl::Span replica_groups) { + std::vector rg(replica_groups.size()); + for (int i = 0; i < replica_groups.size(); ++i) { + rg[i] = replica_groups[i]; + } + return xla::AllToAll(operand.op(), split_dimension, concat_dimension, + split_count, rg); +} + +LocalOp LocalComputationBuilder::CrossReplicaSum( + const LocalOp& operand, absl::Span replica_groups) { + return 
xla::CrossReplicaSum(operand.op(), replica_groups); } LocalOp LocalComputationBuilder::Slice(const LocalOp& operand, @@ -714,6 +811,21 @@ LocalOp LocalComputationBuilder::Call(const LocalComputation& local_computation, return xla::Call(&builder_, local_computation.computation(), xla_ops); } +LocalOp LocalComputationBuilder::CustomCall( + const string& call_target_name, absl::Span operands, + const Shape& shape_with_layout, + const std::vector& operand_shapes_with_layout, + const string& opaque) { + std::vector xla_ops; + xla_ops.reserve(operands.size()); + for (const auto& op : operands) { + xla_ops.push_back(op.op()); + } + return xla::CustomCallWithLayout(&builder_, call_target_name, xla_ops, + shape_with_layout, + operand_shapes_with_layout, opaque); +} + LocalOp LocalComputationBuilder::Transpose( const LocalOp& operand, absl::Span permutation) { return xla::Transpose(operand.op(), permutation); @@ -799,6 +911,43 @@ LocalOp LocalComputationBuilder::SortKeyVal(const LocalOp& keys, return xla::Sort(keys.op(), {values.op()}, dimension); } +LocalOp LocalComputationBuilder::Cholesky(const LocalOp& a) { + return xla::Cholesky(a.op()); +} + +LocalOp LocalComputationBuilder::QR(const LocalOp& a, bool full_matrices) { + XlaBuilder* builder = a.op().builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(auto qr, xla::QRDecomposition(a.op(), full_matrices)); + return xla::Tuple(builder, {qr.q, qr.r}); + }); +} + +LocalOp LocalComputationBuilder::TriangularSolve(const LocalOp& a, + const LocalOp& b, + bool left_side, bool lower, + bool transpose_a, + bool conjugate_a) { + return xla::TriangularSolve(a.op(), b.op(), left_side, lower, transpose_a, + conjugate_a); +} + +LocalOp LocalComputationBuilder::Gather( + const LocalOp& input, const LocalOp& start_indices, + const GatherDimensionNumbers& dimension_numbers, + absl::Span slice_sizes) { + return xla::Gather(input.op(), start_indices.op(), dimension_numbers, + slice_sizes); +} + +LocalOp 
LocalComputationBuilder::Scatter( + const LocalOp& input, const LocalOp& scatter_indices, + const LocalOp& updates, const LocalComputation& update_computation, + const ScatterDimensionNumbers& dimension_numbers) { + return xla::Scatter(input.op(), scatter_indices.op(), updates.op(), + update_computation.computation(), dimension_numbers); +} + StatusOr LocalComputationBuilder::BuildConstantSubGraph( const LocalOp& operand) { TF_ASSIGN_OR_RETURN(XlaComputation computation, @@ -913,7 +1062,7 @@ StatusOr DestructureLocalShapedBufferTuple( LocalShapedBuffer* local_shaped_buffer) { const Shape tuple_shape = local_shaped_buffer->shape(); - if (!ShapeUtil::IsTuple(tuple_shape)) { + if (!tuple_shape.IsTuple()) { return InvalidArgument( "Attemped to destructure a LocalShapedBuffer that did not have a tuple " "shape; shape: %s", @@ -960,7 +1109,7 @@ StatusOr DestructureXrtAllocationTuple( XrtAllocation* allocation, const string& session_target) { const Shape& tuple_shape = allocation->shape(); - if (!ShapeUtil::IsTuple(tuple_shape)) { + if (!tuple_shape.IsTuple()) { return InvalidArgument( "Attemped to destructure a LocalShapedBuffer that did not have a tuple " "shape; shape: %s", diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h index c9b7ae824a4e5dac3360de0f95859d7c1deb360f..e3af88f82559c32a7267a56c87d3bafda01b934d 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.h +++ b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -19,6 +19,8 @@ limitations under the License. #include #include +#include + #include "absl/types/span.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/scope.h" @@ -49,6 +51,11 @@ Status InitializePlatformName(const string& platform_name); // local XLA service has been instantiated yet or not. int GetReplicaCount(); +// Registers a 'fn_capsule' as a CPU custom call target. 
+// 'fn_capsule' is a void* pointer encapsulated in a PyCapsule object, with name +// "xla._CPU_CUSTOM_CALL_TARGET". +Status RegisterCpuCustomCallTarget(const string& name, PyObject* fn_capsule); + // Wraps the local client's infeed-transfer function. // // The default device ordinal (0) is used. @@ -71,7 +78,8 @@ StatusOr TransferFromOutfeedLocalReplica(const Shape& shape, class LocalShapedBuffer { public: static StatusOr FromLiteral( - const Literal& argument, const absl::optional& shape_with_layout); + const Literal& argument, const absl::optional& shape_with_layout, + int replica_number); LocalShapedBuffer(ScopedShapedBuffer shaped_buffer); StatusOr ToLiteral() const; @@ -172,9 +180,19 @@ class CompiledLocalComputation { public: CompiledLocalComputation(std::unique_ptr executable); + int num_replicas() const { + return executable_->build_options().num_replicas(); + } + StatusOr Execute( absl::Span argument_handles); + // Execute on many replicas. Takes a sequence of argument lists (one argument + // list per replica) and returns a tuple of results (one result per replica). + // The number of argument lists must be equal to the replica count. 
+ StatusOr ExecutePerReplica( + absl::Span > argument_handles); + private: std::unique_ptr executable_; }; @@ -279,6 +297,10 @@ class LocalComputationBuilder { LocalOp ConstantLiteral(const Literal& literal); + LocalOp Iota(PrimitiveType element_type, int64 size); + + LocalOp BroadcastedIota(const Shape& shape, int64 dimension); + LocalOp Broadcast(const LocalOp& operand, absl::Span broadcast_sizes); @@ -294,7 +316,12 @@ class LocalComputationBuilder { LocalOp Collapse(const LocalOp& operand, absl::Span dimensions); - LocalOp CrossReplicaSum(const LocalOp& operand); + LocalOp AllToAll(const LocalOp& operand, int64 split_dimension, + int64 concat_dimension, int64 split_count, + absl::Span replica_groups); + + LocalOp CrossReplicaSum(const LocalOp& operand, + absl::Span replica_groups); LocalOp Slice(const LocalOp& operand, absl::Span start_indices, absl::Span limit_indices, @@ -345,6 +372,12 @@ class LocalComputationBuilder { LocalOp Call(const LocalComputation& local_computation, absl::Span operands); + LocalOp CustomCall(const string& call_target_name, + absl::Span operands, + const Shape& shape_with_layout, + const std::vector& operand_shapes_with_layout, + const string& opaque); + LocalOp Transpose(const LocalOp& operand, absl::Span permutation); @@ -387,6 +420,22 @@ class LocalComputationBuilder { LocalOp SortKeyVal(const LocalOp& keys, const LocalOp& values, int64 dimension); + LocalOp QR(const LocalOp& a, bool full_matrices); + + LocalOp Cholesky(const LocalOp& a); + + LocalOp TriangularSolve(const LocalOp& a, const LocalOp& b, bool left_side, + bool lower, bool transpose_a, bool conjugate_a); + + LocalOp Gather(const LocalOp& input, const LocalOp& start_indices, + const GatherDimensionNumbers& dimension_numbers, + absl::Span slice_sizes); + + LocalOp Scatter(const LocalOp& input, const LocalOp& scatter_indices, + const LocalOp& updates, + const LocalComputation& update_computation, + const ScatterDimensionNumbers& dimension_numbers); + StatusOr 
BuildConstantSubGraph(const LocalOp& operand); #define _FORWARD(method_name, return_sig, args_sig) \ diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index 5c2538dcc36d93008382a517fd4dc680caaa4347..7b2f69d6ecf44f492f70351b38997530567b5277 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -34,6 +34,9 @@ limitations under the License. // PaddingConfig proto <- corresponding Python proto // ConvolutionDimensionNumbers proto <- corresponding Python proto // DotDimensionNumbers proto <- corresponding Python proto +// GatherDimensionNumbers proto <- corresponding Python proto +// ScatterDimensionNumbers proto <- corresponding Python proto +// Span <- sequence of ReplicaGroup Python proto // // Arrows indicate whether a conversion only ever occurs in one // direction, or whether it is maintained bidirectionally. @@ -167,8 +170,41 @@ bool HandleStringAttribute(PyObject* o, return true; // Handled string attribute, ok! } +bool HandleRepeatedInt64Attribute( + PyObject* o, const char* attr_name, + tensorflow::protobuf::RepeatedField* field) { + PyObject* seq = PyObject_GetAttrString(o, attr_name); + if (!seq) { + return false; + } + + int length = PySequence_Size(seq); + if (length == -1) { + Py_DECREF(seq); + return false; + } + + for (int i = 0; i < length; ++i) { + PyObject* item = PySequence_GetItem(seq, i); + if (!item) { + Py_DECREF(seq); + return false; + } + const int64 dimension = numpy::PyIntOrPyLongToLong(item); + if (dimension == -1 && PyErr_Occurred()) { + Py_DECREF(item); + Py_DECREF(seq); + return false; + } + *field->Add() = dimension; + Py_DECREF(item); + } + Py_DECREF(seq); + return true; } -} + +} // namespace swig +} // namespace xla %} // Required to use PyArray_* functions. 
@@ -363,6 +399,37 @@ tensorflow::ImportNumpy(); $1 = temps; } +%typemap(in) absl::Span > + (std::vector > temps) { + if (!PySequence_Check($input)) { + PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); + SWIG_fail; + } + const int size = PySequence_Size($input); + temps.reserve(size); + for (int i = 0; i < size; ++i) { + PyObject* o = PySequence_GetItem($input, i); + std::vector vec; + const int vec_size = PySequence_Size(o); + vec.reserve(vec_size); + for (int j = 0; j < vec_size; ++j) { + PyObject* vec_elt = PySequence_GetItem(o, j); + LocalShapedBuffer* lsbp; + if ((SWIG_ConvertPtr(vec_elt, (void**) &lsbp, $descriptor(xla::swig::LocalShapedBuffer*), + SWIG_POINTER_EXCEPTION)) == -1) { + Py_DECREF(vec_elt); + Py_DECREF(o); + SWIG_fail; + } + vec.push_back(lsbp); + Py_DECREF(vec_elt); + } + temps.push_back(vec); + Py_DECREF(o); + } + $1 = temps; +} + %typemap(in) absl::Span (std::vector temps) { if (!PySequence_Check($input)) { @@ -626,128 +693,27 @@ tensorflow::ImportNumpy(); %typemap(in) const DotDimensionNumbers& (DotDimensionNumbers dimension_numbers) { - int length; - - /* lhs_contracting_dimensions */ - PyObject* lhs_contracting_dimensions = PyObject_GetAttrString( - $input, "lhs_contracting_dimensions"); - if (!lhs_contracting_dimensions) { + if (!HandleRepeatedInt64Attribute( + $input, "lhs_contracting_dimensions", + dimension_numbers.mutable_lhs_contracting_dimensions())) { SWIG_fail; } - - length = PySequence_Size(lhs_contracting_dimensions); - if (length == -1) { - Py_DECREF(lhs_contracting_dimensions); + if (!HandleRepeatedInt64Attribute( + $input, "rhs_contracting_dimensions", + dimension_numbers.mutable_rhs_contracting_dimensions())) { SWIG_fail; } - - for (int i = 0; i < length; ++i) { - PyObject* item = PySequence_GetItem(lhs_contracting_dimensions, i); - if (!item) { - Py_DECREF(lhs_contracting_dimensions); - SWIG_fail; - } - const int64 dimension = numpy::PyIntOrPyLongToLong(item); - if (dimension == -1 && PyErr_Occurred()) { - 
Py_DECREF(item); - Py_DECREF(lhs_contracting_dimensions); - SWIG_fail; - } - dimension_numbers.add_lhs_contracting_dimensions(dimension); - Py_DECREF(item); - } - Py_DECREF(lhs_contracting_dimensions); - - /* rhs_contracting_dimensions */ - PyObject* rhs_contracting_dimensions = PyObject_GetAttrString( - $input, "rhs_contracting_dimensions"); - if (!lhs_contracting_dimensions) { - SWIG_fail; - } - - length = PySequence_Size(rhs_contracting_dimensions); - if (length == -1) { - Py_DECREF(rhs_contracting_dimensions); - SWIG_fail; - } - - for (int i = 0; i < length; ++i) { - PyObject* item = PySequence_GetItem(rhs_contracting_dimensions, i); - if (!item) { - Py_DECREF(rhs_contracting_dimensions); - SWIG_fail; - } - const int64 dimension = numpy::PyIntOrPyLongToLong(item); - if (dimension == -1 && PyErr_Occurred()) { - Py_DECREF(item); - Py_DECREF(rhs_contracting_dimensions); - SWIG_fail; - } - dimension_numbers.add_rhs_contracting_dimensions(dimension); - Py_DECREF(item); - } - Py_DECREF(rhs_contracting_dimensions); - - /* lhs_batch_dimensions */ - PyObject* lhs_batch_dimensions = PyObject_GetAttrString( - $input, "lhs_batch_dimensions"); - if (!lhs_batch_dimensions) { + if (!HandleRepeatedInt64Attribute( + $input, "lhs_batch_dimensions", + dimension_numbers.mutable_lhs_batch_dimensions())) { SWIG_fail; } - - length = PySequence_Size(lhs_batch_dimensions); - if (length == -1) { - Py_DECREF(lhs_batch_dimensions); + if (!HandleRepeatedInt64Attribute( + $input, "rhs_batch_dimensions", + dimension_numbers.mutable_rhs_batch_dimensions())) { SWIG_fail; } - for (int i = 0; i < length; ++i) { - PyObject* item = PySequence_GetItem(lhs_batch_dimensions, i); - if (!item) { - Py_DECREF(lhs_batch_dimensions); - SWIG_fail; - } - const int64 dimension = numpy::PyIntOrPyLongToLong(item); - if (dimension == -1 && PyErr_Occurred()) { - Py_DECREF(item); - Py_DECREF(lhs_batch_dimensions); - SWIG_fail; - } - dimension_numbers.add_lhs_batch_dimensions(dimension); - Py_DECREF(item); - } - 
Py_DECREF(lhs_batch_dimensions); - - /* rhs_batch_dimensions */ - PyObject* rhs_batch_dimensions = PyObject_GetAttrString( - $input, "rhs_batch_dimensions"); - if (!rhs_batch_dimensions) { - SWIG_fail; - } - - length = PySequence_Size(rhs_batch_dimensions); - if (length == -1) { - Py_DECREF(rhs_batch_dimensions); - SWIG_fail; - } - - for (int i = 0; i < length; ++i) { - PyObject* item = PySequence_GetItem(rhs_batch_dimensions, i); - if (!item) { - Py_DECREF(rhs_batch_dimensions); - SWIG_fail; - } - const int64 dimension = numpy::PyIntOrPyLongToLong(item); - if (dimension == -1 && PyErr_Occurred()) { - Py_DECREF(item); - Py_DECREF(rhs_batch_dimensions); - SWIG_fail; - } - dimension_numbers.add_rhs_batch_dimensions(dimension); - Py_DECREF(item); - } - Py_DECREF(rhs_batch_dimensions); - $1 = &dimension_numbers; } @@ -829,90 +795,108 @@ tensorflow::ImportNumpy(); } dimension_numbers.set_kernel_input_feature_dimension(value); - PyObject* o; - int length; - - o = PyObject_GetAttrString($input, "input_spatial_dimensions"); - if (!o) { + if (!HandleRepeatedInt64Attribute( + $input, "input_spatial_dimensions", + dimension_numbers.mutable_input_spatial_dimensions())) { SWIG_fail; } - length = PySequence_Size(o); - if (length == -1) { - Py_DECREF(o); + if (!HandleRepeatedInt64Attribute( + $input, "kernel_spatial_dimensions", + dimension_numbers.mutable_kernel_spatial_dimensions())) { SWIG_fail; } - for (int i = 0; i < length; ++i) { - PyObject* item = PySequence_GetItem(o, i); - if (!item) { - Py_DECREF(o); - SWIG_fail; - } - const int64 dimension = numpy::PyIntOrPyLongToLong(item); - if (dimension == -1 && PyErr_Occurred()) { - Py_DECREF(item); - Py_DECREF(o); - SWIG_fail; - } - dimension_numbers.add_input_spatial_dimensions(dimension); - Py_DECREF(item); + if (!HandleRepeatedInt64Attribute( + $input, "output_spatial_dimensions", + dimension_numbers.mutable_output_spatial_dimensions())) { + SWIG_fail; } - Py_DECREF(o); - o = PyObject_GetAttrString($input, 
"kernel_spatial_dimensions"); - if (!o) { + $1 = &dimension_numbers; +} + +// GatherDimensionNumbers + +%typemap(in) const GatherDimensionNumbers& + (GatherDimensionNumbers dimension_numbers) { + if (!HandleRepeatedInt64Attribute( + $input, "offset_dims", + dimension_numbers.mutable_offset_dims())) { SWIG_fail; } - length = PySequence_Size(o); - if (length == -1) { - Py_DECREF(o); + if (!HandleRepeatedInt64Attribute( + $input, "collapsed_slice_dims", + dimension_numbers.mutable_collapsed_slice_dims())) { SWIG_fail; } - for (int i = 0; i < length; ++i) { - PyObject* item = PySequence_GetItem(o, i); - if (!item) { - Py_DECREF(o); - SWIG_fail; - } - const int64 dimension = numpy::PyIntOrPyLongToLong(item); - if (dimension == -1 && PyErr_Occurred()) { - Py_DECREF(item); - Py_DECREF(o); - SWIG_fail; - } - dimension_numbers.add_kernel_spatial_dimensions(dimension); - Py_DECREF(item); + if (!HandleRepeatedInt64Attribute( + $input, "start_index_map", + dimension_numbers.mutable_start_index_map())) { + SWIG_fail; } - Py_DECREF(o); - o = PyObject_GetAttrString($input, "output_spatial_dimensions"); - if (!o) { + int64 value; + if (!GetIntAttr($input, "index_vector_dim", &value)) { SWIG_fail; } - length = PySequence_Size(o); - if (length == -1) { - Py_DECREF(o); + dimension_numbers.set_index_vector_dim(value); + + $1 = &dimension_numbers; +} + +// ScatterDimensionNumbers + +%typemap(in) const ScatterDimensionNumbers& + (ScatterDimensionNumbers dimension_numbers) { + if (!HandleRepeatedInt64Attribute( + $input, "update_window_dims", + dimension_numbers.mutable_update_window_dims())) { SWIG_fail; } - for (int i = 0; i < length; ++i) { - PyObject* item = PySequence_GetItem(o, i); - if (!item) { - Py_DECREF(o); - SWIG_fail; - } - const int64 dimension = numpy::PyIntOrPyLongToLong(item); - if (dimension == -1 && PyErr_Occurred()) { - Py_DECREF(item); - Py_DECREF(o); - SWIG_fail; - } - dimension_numbers.add_output_spatial_dimensions(dimension); - Py_DECREF(item); + if 
(!HandleRepeatedInt64Attribute( + $input, "inserted_window_dims", + dimension_numbers.mutable_inserted_window_dims())) { + SWIG_fail; + } + if (!HandleRepeatedInt64Attribute( + $input, "scatter_dims_to_operand_dims", + dimension_numbers.mutable_scatter_dims_to_operand_dims())) { + SWIG_fail; + } + + int64 value; + if (!GetIntAttr($input, "index_vector_dim", &value)) { + SWIG_fail; } - Py_DECREF(o); + dimension_numbers.set_index_vector_dim(value); $1 = &dimension_numbers; } +// Span + +%typemap(in) absl::Span + (std::vector temps) { + if (!PySequence_Check($input)) { + PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); + SWIG_fail; + } + const int size = PySequence_Size($input); + temps.reserve(size); + for (int i = 0; i < size; ++i) { + PyObject* o = PySequence_GetItem($input, i); + ReplicaGroup rgrp; + if (!HandleRepeatedInt64Attribute( + o, "replica_ids", + rgrp.mutable_replica_ids())) { + SWIG_fail; + } + temps.push_back(rgrp); + Py_DECREF(o); + } + $1 = temps; +} + + // ExecutableBuildOptions %typemap(in) const ExecutableBuildOptions* @@ -969,6 +953,12 @@ tensorflow::ImportNumpy(); } Py_DECREF(o); + int64 num_replicas; + if (!GetIntAttr($input, "num_replicas", &num_replicas)) { + SWIG_fail; + } + build_options.set_num_replicas(num_replicas); + $1 = &build_options; } } @@ -979,6 +969,7 @@ tensorflow::ImportNumpy(); %unignore xla::swig::InitializeReplicaCount; %unignore xla::swig::InitializePlatformName; %unignore xla::swig::GetReplicaCount; +%unignore xla::swig::RegisterCpuCustomCallTarget; %unignore xla::swig::TransferToInfeedLocal; %unignore xla::swig::TransferToInfeedLocalReplica; %unignore xla::swig::TransferFromOutfeedLocalReplica; @@ -998,6 +989,7 @@ tensorflow::ImportNumpy(); %unignore xla::swig::XrtAllocationTuple::size; %unignore xla::swig::CompiledLocalComputation; %unignore xla::swig::CompiledLocalComputation::Execute; +%unignore xla::swig::CompiledLocalComputation::ExecutePerReplica; %unignore xla::swig::CompiledXrtComputation; %unignore 
xla::swig::CompiledXrtComputation::Execute; %unignore xla::swig::LocalComputation; @@ -1019,11 +1011,14 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalComputationBuilder::Outfeed; %unignore xla::swig::LocalComputationBuilder::ConstantLiteral; %unignore xla::swig::LocalComputationBuilder::ConstantR0; +%unignore xla::swig::LocalComputationBuilder::Iota; +%unignore xla::swig::LocalComputationBuilder::BroadcastedIota; %unignore xla::swig::LocalComputationBuilder::Broadcast; %unignore xla::swig::LocalComputationBuilder::BroadcastInDim; %unignore xla::swig::LocalComputationBuilder::Pad; %unignore xla::swig::LocalComputationBuilder::Reshape; %unignore xla::swig::LocalComputationBuilder::Collapse; +%unignore xla::swig::LocalComputationBuilder::AllToAll; %unignore xla::swig::LocalComputationBuilder::CrossReplicaSum; %unignore xla::swig::LocalComputationBuilder::Slice; %unignore xla::swig::LocalComputationBuilder::SliceInDim; @@ -1112,6 +1107,12 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalComputationBuilder::Imag; %unignore xla::swig::LocalComputationBuilder::Conj; %unignore xla::swig::LocalComputationBuilder::Complex; +%unignore xla::swig::LocalComputationBuilder::Cholesky; +%unignore xla::swig::LocalComputationBuilder::QR; +%unignore xla::swig::LocalComputationBuilder::TriangularSolve; +%unignore xla::swig::LocalComputationBuilder::CustomCall; +%unignore xla::swig::LocalComputationBuilder::Gather; +%unignore xla::swig::LocalComputationBuilder::Scatter; %unignore xla::swig::DeleteLocalComputation; %unignore xla::swig::DestructureLocalShapedBufferTuple; %unignore xla::swig::DestructureXrtAllocationTuple; diff --git a/tensorflow/compiler/xla/python/numpy_bridge.cc b/tensorflow/compiler/xla/python/numpy_bridge.cc index b0aa024c7474cf8e6934432b2f364be464714999..52c5c621f7294c5da341879d15b77559fe870551 100644 --- a/tensorflow/compiler/xla/python/numpy_bridge.cc +++ b/tensorflow/compiler/xla/python/numpy_bridge.cc @@ -54,6 +54,8 @@ int 
PrimitiveTypeToNumpyType(PrimitiveType primitive_type) { return NPY_FLOAT64; case C64: return NPY_COMPLEX64; + case C128: + return NPY_COMPLEX128; case TUPLE: return NPY_OBJECT; default: @@ -89,6 +91,8 @@ PrimitiveType NumpyTypeToPrimitiveType(int np_type) { return F64; case NPY_COMPLEX64: return C64; + case NPY_COMPLEX128: + return C128; case NPY_OBJECT: return TUPLE; default: @@ -111,6 +115,7 @@ bool NumpyTypeIsValid(int np_type) { case NPY_FLOAT32: case NPY_FLOAT64: case NPY_COMPLEX64: + case NPY_COMPLEX128: case NPY_OBJECT: return true; default: @@ -123,7 +128,7 @@ PyObject* PyShapeInfoFromXlaShape(const Shape& shape) { PyArray_Descr* np_dtype = PyArray_DescrFromType(np_typenum); PyObject* dimensions; - if (ShapeUtil::IsTuple(shape)) { + if (shape.IsTuple()) { int num_elements = ShapeUtil::TupleElementCount(shape); dimensions = PyTuple_New(ShapeUtil::TupleElementCount(shape)); for (int i = 0; i < num_elements; ++i) { @@ -132,7 +137,7 @@ PyObject* PyShapeInfoFromXlaShape(const Shape& shape) { PyShapeInfoFromXlaShape(ShapeUtil::GetTupleElementShape(shape, i))); } } else { - int rank = ShapeUtil::Rank(shape); + int rank = shape.rank(); dimensions = PyTuple_New(rank); for (int i = 0; i < rank; ++i) { PyTuple_SET_ITEM(dimensions, i, @@ -345,7 +350,7 @@ StatusOr OpMetadataFromPyObject(PyObject* o) { } PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal) { - if (ShapeUtil::IsTuple(literal.shape())) { + if (literal.shape().IsTuple()) { int num_elements = ShapeUtil::TupleElementCount(literal.shape()); PyObject* tuple = PyTuple_New(num_elements); for (int i = 0; i < num_elements; i++) { @@ -354,7 +359,7 @@ PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal) { } return tuple; } else { - int rank = ShapeUtil::Rank(literal.shape()); + int rank = literal.shape().rank(); std::vector dimensions(rank); // NOLINT - PyArray requires a long* for (int i = 0; i < rank; i++) { dimensions[i] = ShapeUtil::GetDimension(literal.shape(), i); @@ -430,6 +435,9 @@ Status 
CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array, case NPY_COMPLEX64: CopyNumpyArrayToLiteral(py_array, literal); break; + case NPY_COMPLEX128: + CopyNumpyArrayToLiteral(py_array, literal); + break; default: return InvalidArgument( "No XLA literal container for Numpy type number: %d", np_type); @@ -470,6 +478,9 @@ void CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal, case NPY_COMPLEX64: CopyLiteralToNumpyArray(literal, py_array); break; + case NPY_COMPLEX128: + CopyLiteralToNumpyArray(literal, py_array); + break; default: LOG(FATAL) << "No XLA literal container for Numpy type" << np_type; } diff --git a/tensorflow/compiler/xla/python/pywrap_xla_exported_symbols.lds b/tensorflow/compiler/xla/python/pywrap_xla_exported_symbols.lds new file mode 100644 index 0000000000000000000000000000000000000000..bce6c1acf8a1cc0005ca93e0466c5a0e29d880de --- /dev/null +++ b/tensorflow/compiler/xla/python/pywrap_xla_exported_symbols.lds @@ -0,0 +1 @@ +_PyInit__pywrap_xla diff --git a/tensorflow/compiler/xla/python/pywrap_xla_version_script.lds b/tensorflow/compiler/xla/python/pywrap_xla_version_script.lds new file mode 100644 index 0000000000000000000000000000000000000000..d31cfce7be7b6accf05ef77f3485904099965afc --- /dev/null +++ b/tensorflow/compiler/xla/python/pywrap_xla_version_script.lds @@ -0,0 +1,6 @@ +xla { + global: + PyInit_*; + local: + *; +}; diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index e5fba0d7acb838788f8e7e05a4634e807d9d21d0..37cae0e3b6b8635ca53e282994f0d078974df5a9 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -199,6 +199,7 @@ XLA_ELEMENT_TYPE_TO_DTYPE = { xla_data_pb2.F32: np.dtype('float32'), xla_data_pb2.F64: np.dtype('float64'), xla_data_pb2.C64: np.dtype('complex64'), + xla_data_pb2.C128: np.dtype('complex128'), xla_data_pb2.TUPLE: np.dtype(np.object), } @@ -222,24 +223,33 @@ class LocalBuffer(object): 
means the referent is in device memory. """ - def __init__(self, c_buffer, backend): + def __init__(self, c_buffer, backend, replica): self.c_buffer = c_buffer self._backend = backend + self._replica = replica if backend.backend_type == BackendType.XRT: self._delete = c_api.DeleteXrtAllocation else: self._delete = c_api.DeleteLocalShapedBuffer @staticmethod - def from_pyval(pyval, backend=XLA_LOCAL_BACKEND): + def from_pyval(pyval, replica=0, backend=XLA_LOCAL_BACKEND): """Allocate and copy to XLA the given python value.""" pyval = require_numpy_array_layout(pyval) + num_replicas = get_replica_count() + if not 0 <= replica < num_replicas: + raise ValueError( + 'Attempt to place buffer on replica {} when the replica count is {}' + .format(replica, num_replicas)) if backend.backend_type == BackendType.XRT: + if replica != 0: + raise NotImplementedError( + 'Multi-replica execution is not yet supported via the XRT backend.') cbuf = c_api.XrtAllocation.FromLiteral( pyval, _maybe_encode_string(backend.target)) else: - cbuf = c_api.LocalShapedBuffer.FromLiteral(pyval, None) - return LocalBuffer(cbuf, backend) + cbuf = c_api.LocalShapedBuffer.FromLiteral(pyval, None, replica) + return LocalBuffer(cbuf, backend, replica) def to_py(self): return self.c_buffer.ToLiteral() @@ -247,6 +257,9 @@ class LocalBuffer(object): def shape(self): return _wrap_shape(self.c_buffer.shape()) + def replica(self): + return self._replica + def delete(self): if self.c_buffer is not None: self._delete(self.c_buffer) @@ -263,7 +276,8 @@ class LocalBuffer(object): self.delete() size = result.size() destructured = tuple( - LocalBuffer(result.Release(i), backend=self._backend) + LocalBuffer( + result.Release(i), replica=self._replica, backend=self._backend) for i in xrange(size)) return destructured @@ -402,7 +416,7 @@ class Shape(object): assert mtm is None, self if mtm is not None: assert self.rank() == len(mtm), self - assert sorted(mtm) == range(len(mtm)), self + assert sorted(mtm) == 
list(range(len(mtm))), self def update_minor_to_major(self, minor_to_major): if not self.is_array(): @@ -445,6 +459,7 @@ class CompileOptions(object): self.dump_unoptimized_hlo_proto_to = None self.dump_per_pass_hlo_proto_to = None self.hlo_profile = False + self.num_replicas = get_replica_count() def transfer_to_infeed(value, replica_number=None): @@ -575,27 +590,97 @@ class LocalComputation(object): compile_options=compile_options, layout_fn=layout_fn) - def Execute(self, arguments=()): - """Execute with LocalBuffer arguments and return value.""" + def GetReturnValueShape(self): + return _wrap_shape(self._c_computation.GetReturnValueShape()) + + def Execute(self, arguments=(), check_for_deleted_args=True): + """Execute on one replica with LocalBuffer arguments and return value.""" + if check_for_deleted_args and any(arg.is_deleted() for arg in arguments): + raise ValueError('Executing with deleted local buffer argument') + raw_args = [arg.c_buffer for arg in arguments] + output_buffer = self._c_computation.Execute(raw_args) + return LocalBuffer(output_buffer, backend=self._backend, replica=0) + + def ExecutePerReplica(self, arguments=None): + """Execute on many replicas with LocalBuffer arguments and return value. + + Args: + arguments: A sequence of sequences of LocalBuffers. The i'th inner + sequence comprises the arguments for execution on the i'th replica. + + Returns: + A list of the computation's outputs on each replica, as a LocalBuffer. If + a shallow sequence of arguments was passed in for `arguments`, then the + sole, zero'th replica's output is returned instead, as a LocalBuffer. 
+ """ if not self._is_compiled: raise ValueError('Cannot execute an uncompiled local XLA computation.') - arguments = tuple(arguments) - if any(arg.is_deleted() for arg in arguments): - raise ValueError('Executing with deleted local buffer argument') - return LocalBuffer( - self._c_computation.Execute([arg.c_buffer for arg in arguments]), - backend=self._backend) + if arguments is None: + arguments = ((),) * get_replica_count() + else: + arguments = [list(replica_args) for replica_args in arguments] + + # Check arguments + for replica, replica_args in enumerate(arguments): + for arg in replica_args: + if arg.is_deleted(): + raise ValueError('Executing with deleted local buffer argument') + if arg.replica() != replica: + raise ValueError( + 'Executing on replica {} with argument from replica {}'.format( + replica, arg.replica())) + + # Pull out argument buffer handles + stripped_args = [ + [arg.c_buffer for arg in replica_args] for replica_args in arguments + ] + + # Execute + if self._backend.backend_type == BackendType.XRT: + if len(stripped_args) > 1: + raise NotImplementedError( + 'Multi-replica execution is not yet supported via the XRT backend.') + output_buffers = [self._c_computation.Execute(stripped_args[0])] + else: + output_buffer_tup = self._c_computation.ExecutePerReplica(stripped_args) + size = output_buffer_tup.size() + output_buffers = [output_buffer_tup.Release(i) for i in xrange(size)] + + # Wrap output handles in LocalBuffer instances + return tuple( + LocalBuffer(output_buffer, backend=self._backend, replica=replica) + for replica, output_buffer in enumerate(output_buffers)) def ExecuteWithPythonValues(self, arguments=()): - """Execute with Python values as arguments and return value.""" - arguments = tuple( - LocalBuffer.from_pyval(arg, backend=self._backend) for arg in arguments) + """Execute on one replica with Python values as arguments and output.""" + + def put(arg): + return LocalBuffer.from_pyval(arg, backend=self._backend) + + arguments 
= [put(arg) for arg in arguments] return self.Execute(arguments).to_py() + def ExecuteWithPythonValuesPerReplica(self, arguments): + """Execute on many replicas with Python values as arguments and output.""" + + def put(arg, replica): + return LocalBuffer.from_pyval(arg, replica, backend=self._backend) + + arguments = [[put(arg, replica) + for arg in replica_args] + for replica, replica_args in enumerate(arguments)] + return [out.to_py() for out in self.ExecutePerReplica(arguments)] + def __del__(self): self._delete(self._c_computation) +def _make_replica_group_proto(replica_group): + replica_group_proto = xla_data_pb2.ReplicaGroup() + replica_group_proto.replica_ids.extend(replica_group) + return replica_group_proto + + class ComputationBuilder(object): """XLA computation builder. @@ -754,6 +839,33 @@ class ComputationBuilder(object): return self.ParameterWithShape( Shape.from_pyval(value), name=name, parameter_num=parameter_num) + def Iota(self, dtype, size): + """Enqueues an iota constant onto the computation. + + Args: + dtype: expected numpy dtype of the output. + size: integer, the number of elements in the array. + + Returns: + A LocalOp representing the added iota constant. + """ + element_type = DTYPE_TO_XLA_ELEMENT_TYPE[str(np.dtype(dtype))] + return self._client.Iota(element_type, size) + + def BroadcastedIota(self, dtype, shape, dimension): + """Enqueues a broadcasted iota constant onto the computation. + + Args: + dtype: expected numpy dtype of the output. + shape: tuple of integers, the expected output shape (dimensions). + dimension: positive integer, dimension along which to increment values. + + Returns: + A LocalOp representing the added broadcasted iota constant. + """ + xla_shape = Shape.array_shape(dtype, shape) + return self._client.BroadcastedIota(xla_shape, dimension) + def Broadcast(self, operand, sizes): """Enqueues a broadcast operation onto the computation. 
@@ -859,16 +971,60 @@ class ComputationBuilder(object): dimensions = tuple(range(ndim)) return self._client.Reshape(operand, dimensions, new_sizes) - def CrossReplicaSum(self, operand): + def AllToAll(self, + operand, + split_dimension, + concat_dimension, + replica_groups=None): + """AllToAll op. + + Args: + operand: LocalOp representing the input array + split_dimension: the dimension along which the operand is split + concat_dimension: the dimension along which the split blocks are + concatenated + replica_groups: optional, list of lists of ints encoding a partition of + the set {0, 1, ..., num_replicas} into equally-sized replica groups + within which the all-to-all is performed. If not supplied or None (the + default), all replicas belong to the same group. + + Returns: + A LocalOp that represents the all-to-all concatenation. + """ + if replica_groups is None: + replica_groups_protos = [] # special value for XLA API + else: + replica_groups = list(replica_groups) + replica_groups_protos = [ + _make_replica_group_proto(group) for group in replica_groups] + if not replica_groups: + split_count = get_replica_count() + else: + split_count = len(replica_groups[0]) + if not all(split_count == len(g) for g in replica_groups): + raise ValueError('Replica groups must be equally sized') + return self._client.AllToAll(operand, split_dimension, concat_dimension, + split_count, replica_groups_protos) + + def CrossReplicaSum(self, operand, replica_groups=None): """CrossReplicaSum op. Args: operand: the operand to sum across replica instances. + replica_groups: optional, list of lists of ints encoding a partition of + the set {0, 1, ..., num_replicas} into equally-sized replica groups + within which the cross-replica sum is performed. If not supplied or None + (the default), all replicas belong to the same group. Returns: - A LocalOp that has the sum of the value among all replicas. + A LocalOp that represents on each replica the sum of its group's values. 
""" - return self._client.CrossReplicaSum(operand) + if replica_groups is None: + replica_groups = [] # special value for XLA API + else: + replica_groups = [ + _make_replica_group_proto(group) for group in replica_groups] + return self._client.CrossReplicaSum(operand, replica_groups) def Collapse(self, operand, dimensions): """Collapse op.""" @@ -1025,6 +1181,31 @@ class ComputationBuilder(object): """ return self._client.Call(computation_to_apply.computation, operands) + def CustomCall(self, + call_target_name, + operands, + shape_with_layout, + operand_shapes_with_layout, + opaque=None): + """Enqueues a custom call operation onto the computation. + + Args: + call_target_name: the name of the function to call. + operands: an iterable of LocalOp. The number and types of operands must + match the arity of `operand_shapes_with_layout`. + shape_with_layout: the shape of the operator's output, with layout. + operand_shapes_with_layout: the shapes of `operands`, including the + expected layouts. + opaque: an opaque string passed to the backend. + + Returns: + A LocalOp representing the added custom call op. + """ + opaque = opaque or b'' + return self._client.CustomCall(call_target_name, operands, + shape_with_layout, + operand_shapes_with_layout, opaque) + def Map(self, operands, computation_to_apply, dimensions): """Enqueues a map operation onto the computation. 
@@ -1334,6 +1515,32 @@ class ComputationBuilder(object): """Enqueues a key-value sort operation onto the computation.""" return self._client.SortKeyVal(keys, values, dimension) + def Cholesky(self, a): + """Enqueues a Cholesky decomposition onto the computation.""" + return self._client.Cholesky(a) + + def QR(self, a, full_matrices=True): + """Enqueues a QR decomposition onto the computation.""" + return self._client.QR(a, full_matrices) + + def TriangularSolve(self, a, b, left_side=False, lower=False, + transpose_a=False, conjugate_a=False): + """Enqueues a triangular-solve operation onto the computation.""" + return self._client.TriangularSolve( + a, b, left_side, lower, transpose_a, conjugate_a) + + def Gather(self, a, start_indices, dimension_numbers, slice_sizes): + """Enqueues a Gather operation onto the computation.""" + return self._client.Gather(a, start_indices, dimension_numbers, + slice_sizes) + + def Scatter(self, a, scatter_indices, updates, update_computation, + dimension_numbers): + """Enqueues a Scatter operation onto the computation.""" + return self._client.Scatter( + a, scatter_indices, updates, update_computation.computation, + dimension_numbers,) + def _forward_methods_to_local_builder(): """Forward remaining ComputationBuilder methods to the C API. @@ -1409,6 +1616,16 @@ def get_replica_count(): return c_api.GetReplicaCount() +def register_cpu_custom_call_target(name, fn): + """Registers a CPU custom call target. + + Args: + name: bytes containing the name of the function. + fn: a PyCapsule object containing the function pointer. 
+ """ + c_api.RegisterCpuCustomCallTarget(name, fn) + + def GetPaddingConfigFromTriples(triples): """Create PaddingConfig proto from list of triples of integers.""" padding_config = xla_data_pb2.PaddingConfig() diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index 21b5c93b615ec429a5da0b4ffe89e8f75f59ef1b..c80e792464560f4722b657694d8eb6f5e03956a9 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -18,11 +18,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools import itertools import threading import numpy as np +from tensorflow.compiler.xla.python import custom_call_for_test from tensorflow.compiler.xla.python import xla_client import unittest @@ -51,9 +53,11 @@ class LocalComputationTest(unittest.TestCase): def _ExecuteAndCompareExact(self, c, arguments=(), expected=None): self._ExecuteAndAssertWith(np.testing.assert_equal, c, arguments, expected) - def _ExecuteAndCompareClose(self, c, arguments=(), expected=None): - self._ExecuteAndAssertWith(np.testing.assert_allclose, c, arguments, - expected) + def _ExecuteAndCompareClose(self, c, arguments=(), expected=None, rtol=1e-7, + atol=0): + self._ExecuteAndAssertWith( + functools.partial(np.testing.assert_allclose, rtol=rtol, atol=atol), + c, arguments, expected) def NumpyArrayF32(*args, **kwargs): @@ -143,6 +147,17 @@ class ComputationsWithConstantsTest(LocalComputationTest): c.Pow(c.Constant(NumpyArrayF64([1.5, 2.5, 3.0])), c.ConstantF64Scalar(2.)) self._ExecuteAndCompareClose(c, expected=[2.25, 6.25, 9.]) + def testIota(self): + c = self._NewComputation() + c.Iota(np.float32, 10) + self._ExecuteAndCompareExact(c, expected=np.arange(10, dtype=np.float32)) + + def testBroadcastedIota(self): + c = self._NewComputation() + c.BroadcastedIota(np.int64, (2, 3), 1) + expected = np.array([[0, 1, 2], [0, 1, 2]], 
dtype=np.int64) + self._ExecuteAndCompareExact(c, expected=expected) + def testBooleanAnd(self): c = self._NewComputation() c.And( @@ -268,6 +283,20 @@ class ComputationsWithConstantsTest(LocalComputationTest): c.Constant(NumpyArrayF64([100, -100, 200, -200]))) self._ExecuteAndCompareClose(c, expected=[104.4, -93.4, 208.8, -189]) + def testCustomCall(self): + c = self._NewComputation() + for name, fn in custom_call_for_test.cpu_custom_call_targets.items(): + xla_client.register_cpu_custom_call_target(name, fn) + c.CustomCall( + b"test_subtract_f32", + operands=(c.ConstantF32Scalar(1.25), c.ConstantF32Scalar(0.5)), + shape_with_layout=xla_client.Shape.array_shape(np.float32, (), ()), + operand_shapes_with_layout=( + xla_client.Shape.array_shape(np.float32, (), ()), + xla_client.Shape.array_shape(np.float32, (), ()), + )) + self._ExecuteAndCompareClose(c, expected=0.75) + class ParametersTest(LocalComputationTest): """Tests focusing on Parameter ops and argument-passing.""" @@ -524,6 +553,18 @@ class SingleOpTest(LocalComputationTest): for src_dtype, dst_dtype in itertools.product(xla_types, xla_types): _ConvertAndTest(x, src_dtype, dst_dtype, xla_types[dst_dtype]) + # TODO(b/123523486): re-enable when shape check is resolved + def DISABLED_testAllToAllOneReplica(self): + samples = [ + NumpyArrayF32([97.0]), + NumpyArrayF32([64.0, 117.0]), + NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]), + ] + for lhs in samples[:1]: + c = self._NewComputation() + c.AllToAll(c.Constant(lhs), 0, 0) + self._ExecuteAndCompareExact(c, expected=lhs) + def testCrossReplicaSumOneReplica(self): samples = [ NumpyArrayF32(42.0), @@ -536,6 +577,18 @@ class SingleOpTest(LocalComputationTest): c.CrossReplicaSum(c.Constant(lhs)) self._ExecuteAndCompareExact(c, expected=lhs) + def testCrossReplicaSumOneReplicaWithSingletonGroup(self): + samples = [ + NumpyArrayF32(42.0), + NumpyArrayF32([97.0]), + NumpyArrayF32([64.0, 117.0]), + NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]), + ] + for lhs in samples: + c = 
self._NewComputation() + c.CrossReplicaSum(c.Constant(lhs), [[0]]) + self._ExecuteAndCompareExact(c, expected=lhs) + def testDotMatrixVectorF32(self): c = self._NewComputation() lhs = NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]) @@ -1057,6 +1110,38 @@ class SingleOpTest(LocalComputationTest): self.assertTrue(np.all(lo <= result)) self.assertTrue(np.all(result < hi)) + def testCholesky(self): + l = np.array([[4, 0, 0, 0], [6, 5, 0, 0], [2, 14, 16, 0], [3, 6, 1, 4]], + dtype=np.float32) + c = self._NewComputation() + c.Cholesky(c.Constant(np.dot(l, l.T))) + self._ExecuteAndCompareClose(c, expected=l, rtol=1e-4) + + def testQR(self): + a = np.array( + [[4, 6, 8, 10], [6, 45, 54, 63], [8, 54, 146, 166], [10, 63, 166, 310]], + dtype=np.float32) + c = self._NewComputation() + c.QR(c.Constant(a), full_matrices=True) + q, r = self._Execute(c, ()) + np.testing.assert_allclose(np.dot(q, r), a, rtol=1e-4) + + def testTriangularSolve(self): + a_vals = np.array( + [[2, 0, 0, 0], [3, 6, 0, 0], [4, 7, 9, 0], [5, 8, 10, 11]], + dtype=np.float32) + b_vals = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], + dtype=np.float32) + + c = self._NewComputation() + c.TriangularSolve(c.Constant(a_vals), c.Constant(b_vals), left_side=False, + lower=True, transpose_a=True) + self._ExecuteAndCompareClose(c, expected=np.array([ + [0.5, 0.08333334, 0.04629629, 0.03367003], + [2.5, -0.25, -0.1388889, -0.1010101], + [4.5, -0.58333331, -0.32407406, -0.23569024], + ], dtype=np.float32), rtol=1e-4) + def testIsConstant(self): c = self._NewComputation() a = c.ConstantS32Scalar(3) @@ -1068,6 +1153,21 @@ class SingleOpTest(LocalComputationTest): self.assertFalse(c.IsConstant(non_const_expr)) # self.assertTrue(c.IsConstant(c.Sub(c.Add(x, a), x))) # TODO(b/77245564) + def testGather(self): + a = np.arange(9).astype(np.int32).reshape((3, 3)) + indices = np.array([[[0, 2], [2, 1]], [[1, 2], [2, 0]]], dtype=np.int32) + dnums = xla_client.xla_data_pb2.GatherDimensionNumbers() + dnums.offset_dims.append(1) + 
dnums.offset_dims.append(2) + dnums.start_index_map.append(0) + dnums.start_index_map.append(1) + dnums.index_vector_dim = 2 + c = self._NewComputation() + c.Gather(c.Constant(a), c.Constant(indices), dnums, slice_sizes=[1, 1]) + g = self._Execute(c, ()) + expected = np.array([[[[2, 7]]], [[[5, 6]]]], dtype=np.int32) + np.testing.assert_allclose(g, expected, rtol=1e-4) + class EmbeddedComputationsTest(LocalComputationTest): """Tests for XLA graphs with embedded computations (such as maps).""" @@ -1125,6 +1225,14 @@ class EmbeddedComputationsTest(LocalComputationTest): c.Mul(c.ParameterFromNumpy(NumpyArrayF64(0)), c.ConstantF64Scalar(2.0)) return c.Build() + def _CreateBinaryAddS32Computation(self): + """Computation (s32, s32) -> s32 that adds its two parameters.""" + c = self._NewComputation("add_param0_by_param1") + c.Add( + c.ParameterFromNumpy(NumpyArrayS32(0)), + c.ParameterFromNumpy(NumpyArrayS32(0))) + return c.Build() + def _CreateBinaryAddF32Computation(self): """Computation (f32, f32) -> f32 that adds its two parameters.""" c = self._NewComputation("add_param0_by_param1") @@ -1507,6 +1615,23 @@ class EmbeddedComputationsTest(LocalComputationTest): execution.join() self.assertEqual(want, got) + def testScatter(self): + a = np.arange(9).astype(np.int32).reshape((3, 3)) + scatter_indices = np.array([0, 2], dtype=np.int32) + updates = np.array([[10, 20, 30], [70, 80, 90]], dtype=np.int32) + + dnums = xla_client.xla_data_pb2.ScatterDimensionNumbers() + dnums.update_window_dims.append(1) + dnums.inserted_window_dims.append(0) + dnums.scatter_dims_to_operand_dims.append(0) + dnums.index_vector_dim = 1 + + c = self._NewComputation() + c.Scatter(c.Constant(a), c.Constant(scatter_indices), c.Constant(updates), + self._CreateBinaryAddS32Computation(), dnums) + expected = np.array([[10, 21, 32], [3, 4, 5], [76, 87, 98]], dtype=np.int32) + self._ExecuteAndCompareClose(c, expected=expected) + class ErrorTest(LocalComputationTest): diff --git 
a/tensorflow/compiler/xla/python_api/xla_literal.py b/tensorflow/compiler/xla/python_api/xla_literal.py index 757e41a78ad2b57d2ef6e1f3055160be22c7b3ed..19bd685ab2260485d2a86f0a682d0cdd36712fdb 100644 --- a/tensorflow/compiler/xla/python_api/xla_literal.py +++ b/tensorflow/compiler/xla/python_api/xla_literal.py @@ -69,7 +69,7 @@ def _ConvertNumpyArrayToLiteral(ndarray): if ndarray.ndim == 0: getattr(literal, type_record.literal_field_name).append( - _np.asscalar(ndarray.astype(type_record.literal_field_type))) + ndarray.astype(type_record.literal_field_type).item()) else: # Ndarrays with boolean dtypes need special type conversion with protobufs if ndarray.dtype in {_np.bool_, _np.dtype('bool')}: diff --git a/tensorflow/compiler/xla/python_api/xla_shape.py b/tensorflow/compiler/xla/python_api/xla_shape.py index 95b2bf300ec67e9f034f77450416544cb088ae55..bdcd4abd6cc708795416b15412f37dde10d7fe97 100644 --- a/tensorflow/compiler/xla/python_api/xla_shape.py +++ b/tensorflow/compiler/xla/python_api/xla_shape.py @@ -20,6 +20,8 @@ from __future__ import print_function import numpy as _np # Avoids becoming a part of public Tensorflow API. +from six.moves import xrange + from tensorflow.compiler.xla import xla_data_pb2 from tensorflow.compiler.xla.python_api import types diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc index ceb5e74db7c3b9305e9d77068df9ae0a3690af8a..08b78ee244844f41d551d7e249cec0cbf157d639 100644 --- a/tensorflow/compiler/xla/reference_util.cc +++ b/tensorflow/compiler/xla/reference_util.cc @@ -18,10 +18,10 @@ limitations under the License. 
#include #include +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/shape_inference.h" @@ -32,48 +32,19 @@ limitations under the License. namespace xla { -namespace { - -template -std::unique_ptr> MatmulArray2DImpl( - const Array2D& lhs, const Array2D& rhs, - const std::function& impl_fn) { - CHECK_EQ(lhs.width(), rhs.height()); - int m = lhs.height(); - int n = rhs.width(); - int k = lhs.width(); - auto result = absl::make_unique>(m, n); - // Because Eigen is a header-oriented library, make sure that the Eigen code - // is the same as the code used by the CPU backend (otherwise the linker will - // randomly pick *some* definition). 
- impl_fn( - /*run_options_ptr=*/nullptr, result->data(), rhs.data(), lhs.data(), n, m, - k, - /*transpose_lhs=*/0, - /*transpose_rhs=*/0); - return result; -} - -} // namespace - /* static */ std::unique_ptr> ReferenceUtil::MatmulArray2D( const Array2D& lhs, const Array2D& rhs) { - return MatmulArray2DImpl( - lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF16); + return HloEvaluator::MatmulArray2D(lhs, rhs); } /* static */ std::unique_ptr> ReferenceUtil::MatmulArray2D( const Array2D& lhs, const Array2D& rhs) { - return MatmulArray2DImpl( - lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF32); + return HloEvaluator::MatmulArray2D(lhs, rhs); } /* static */ std::unique_ptr> ReferenceUtil::MatmulArray2D( const Array2D& lhs, const Array2D& rhs) { - return MatmulArray2DImpl( - lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF64); + return HloEvaluator::MatmulArray2D(lhs, rhs); } /* static */ std::unique_ptr> ReferenceUtil::Array2DF32ToF64( @@ -557,10 +528,11 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated( dim2.set_base_dilation(lhs_dilation.second); *window.add_dimensions() = dim2; - const Shape& shape = ShapeInference::InferConvolveShape( - lhs_literal.shape(), rhs_literal.shape(), - /*feature_group_count=*/1, window, dnums) - .ConsumeValueOrDie(); + const Shape& shape = + ShapeInference::InferConvolveShape( + lhs_literal.shape(), rhs_literal.shape(), + /*feature_group_count=*/1, /*batch_group_count=*/1, window, dnums) + .ConsumeValueOrDie(); HloInstruction* lhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal))); @@ -572,16 +544,16 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated( /*new_size=*/2, PrecisionConfig::DEFAULT); b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, - window, dnums, precision_config)); + /*batch_group_count=*/1, window, dnums, precision_config)); HloModuleConfig config; HloModule module("ReferenceUtil", config); auto 
computation = module.AddEntryComputation(b.Build()); HloEvaluator evaluator; Literal result_literal = - evaluator.Evaluate(*computation, {}).ConsumeValueOrDie(); + evaluator.Evaluate(*computation, {}).ConsumeValueOrDie(); - CHECK_EQ(ShapeUtil::Rank(result_literal.shape()), 4); + CHECK_EQ(result_literal.shape().rank(), 4); auto result = absl::make_unique>(result_literal.shape().dimensions(0), result_literal.shape().dimensions(1), @@ -634,24 +606,26 @@ ReferenceUtil::ReduceToRowArray2D( const std::function& reduce_function) { std::vector result; CHECK_EQ(dims.size(), 3); - const std::set dim_set(dims.begin(), dims.end()); + const absl::flat_hash_set dim_set(dims.begin(), dims.end()); CHECK_EQ(dim_set.size(), 3); - for (int64 a0 = 0; a0 == 0 || (!dim_set.count(0) && a0 < array.n1()); ++a0) { - for (int64 a1 = 0; a1 == 0 || (!dim_set.count(1) && a1 < array.n2()); + for (int64 a0 = 0; a0 == 0 || (!dim_set.contains(0) && a0 < array.n1()); + ++a0) { + for (int64 a1 = 0; a1 == 0 || (!dim_set.contains(1) && a1 < array.n2()); ++a1) { - for (int64 a2 = 0; a2 == 0 || (!dim_set.count(2) && a2 < array.n3()); + for (int64 a2 = 0; a2 == 0 || (!dim_set.contains(2) && a2 < array.n3()); ++a2) { - for (int64 a3 = 0; a3 == 0 || (!dim_set.count(3) && a3 < array.n4()); + for (int64 a3 = 0; a3 == 0 || (!dim_set.contains(3) && a3 < array.n4()); ++a3) { float accumulator = init; - for (int64 i0 = 0; i0 == 0 || (dim_set.count(0) && i0 < array.n1()); - ++i0) { - for (int64 i1 = 0; i1 == 0 || (dim_set.count(1) && i1 < array.n2()); - ++i1) { + for (int64 i0 = 0; + i0 == 0 || (dim_set.contains(0) && i0 < array.n1()); ++i0) { + for (int64 i1 = 0; + i1 == 0 || (dim_set.contains(1) && i1 < array.n2()); ++i1) { for (int64 i2 = 0; - i2 == 0 || (dim_set.count(2) && i2 < array.n3()); ++i2) { + i2 == 0 || (dim_set.contains(2) && i2 < array.n3()); ++i2) { for (int64 i3 = 0; - i3 == 0 || (dim_set.count(3) && i3 < array.n4()); ++i3) { + i3 == 0 || (dim_set.contains(3) && i3 < array.n4()); + ++i3) { // 
Handle zero-sized arrays. if (array.n1() > 0 && array.n2() > 0 && array.n3() > 0 && array.n4() > 0) { diff --git a/tensorflow/compiler/xla/rpc/grpc_service.cc b/tensorflow/compiler/xla/rpc/grpc_service.cc index d8123a6de28ca532819ece4a75cd0b725f8c1bbd..22b4218fbd5e9bc59a0de22735eb51db46670f09 100644 --- a/tensorflow/compiler/xla/rpc/grpc_service.cc +++ b/tensorflow/compiler/xla/rpc/grpc_service.cc @@ -47,6 +47,14 @@ namespace xla { }); } +::grpc::Status GRPCService::GetDeviceHandles(::grpc::ServerContext* context, + const GetDeviceHandlesRequest* arg, + GetDeviceHandlesResponse* result) { + return DelegateRPC([this, arg, result]() { + return service_->GetDeviceHandles(arg, result); + }); +} + ::grpc::Status GRPCService::Compile(::grpc::ServerContext* /*context*/, const CompileRequest* arg, CompileResponse* result) { @@ -61,6 +69,14 @@ namespace xla { [this, arg, result]() { return service_->Execute(arg, result); }); } +::grpc::Status GRPCService::ExecuteGraphParallel( + ::grpc::ServerContext* /*context*/, const ExecuteGraphParallelRequest* arg, + ExecuteParallelResponse* result) { + return DelegateRPC([this, arg, result]() { + return service_->ExecuteGraphParallel(arg, result); + }); +} + ::grpc::Status GRPCService::WaitForExecution(::grpc::ServerContext* context, const WaitForExecutionRequest* arg, WaitForExecutionResponse* result) { diff --git a/tensorflow/compiler/xla/rpc/grpc_service.h b/tensorflow/compiler/xla/rpc/grpc_service.h index 3e586b288a56a22573d0c3b9ae7b2f25fdbf851a..b546704f73e34941cbf7bc2fe08062aa438039f7 100644 --- a/tensorflow/compiler/xla/rpc/grpc_service.h +++ b/tensorflow/compiler/xla/rpc/grpc_service.h @@ -39,6 +39,10 @@ class GRPCService : public grpc::XlaService::Service { const DeconstructTupleRequest* arg, DeconstructTupleResponse* result) override; + ::grpc::Status GetDeviceHandles(::grpc::ServerContext* context, + const GetDeviceHandlesRequest* arg, + GetDeviceHandlesResponse* result) override; + ::grpc::Status 
Compile(::grpc::ServerContext* context, const CompileRequest* arg, CompileResponse* result) override; @@ -46,6 +50,9 @@ class GRPCService : public grpc::XlaService::Service { ::grpc::Status Execute(::grpc::ServerContext* context, const ExecuteRequest* arg, ExecuteResponse* result) override; + ::grpc::Status ExecuteGraphParallel(::grpc::ServerContext* context, + const ExecuteGraphParallelRequest* arg, + ExecuteParallelResponse* result) override; ::grpc::Status WaitForExecution(::grpc::ServerContext* context, const WaitForExecutionRequest* arg, diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 429b4e490cc2f1ab894924e95db3ad7e80342a72..4f6509c1cb9dddac3f90cb8bea9b8ee989e4da4b 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -1,6 +1,14 @@ # Description: # XLA service implementation. +load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") +load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library") +load( + "//tensorflow/core:platform/default/build_config.bzl", + "tf_proto_library_py", +) +load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test") + licenses(["notice"]) # Apache 2.0 package(default_visibility = [":friends"]) @@ -12,15 +20,6 @@ package_group( ], ) -load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") -load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library") -load("//tensorflow:tensorflow.bzl", "tf_cc_test") -load("//tensorflow:tensorflow.bzl", "tf_cc_binary") -load( - "//tensorflow/core:platform/default/build_config.bzl", - "tf_proto_library_py", -) - xla_proto_library( name = "hlo_proto", srcs = ["hlo.proto"], @@ -224,23 +223,28 @@ cc_library( "hlo_evaluator_typed_visitor.h", "hlo_evaluator_typed_visitor_bfloat16.cc", "hlo_evaluator_typed_visitor_bool.cc", + "hlo_evaluator_typed_visitor_complex128.cc", "hlo_evaluator_typed_visitor_complex64.cc", "hlo_evaluator_typed_visitor_double.cc", 
"hlo_evaluator_typed_visitor_float.cc", "hlo_evaluator_typed_visitor_half.cc", + "hlo_evaluator_typed_visitor_int16.cc", "hlo_evaluator_typed_visitor_int32.cc", "hlo_evaluator_typed_visitor_int64.cc", "hlo_evaluator_typed_visitor_int8.cc", + "hlo_evaluator_typed_visitor_uint16.cc", "hlo_evaluator_typed_visitor_uint32.cc", "hlo_evaluator_typed_visitor_uint64.cc", "hlo_evaluator_typed_visitor_uint8.cc", ], hdrs = ["hlo_evaluator.h"], deps = [ + ":dynamic_dimension_inference", ":hlo", ":hlo_casting_utils", ":hlo_query", ":shape_inference", + "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", @@ -249,12 +253,14 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/memory", + "@com_google_absl//absl/meta:type_traits", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", @@ -266,6 +272,7 @@ tf_cc_test( srcs = ["hlo_evaluator_test.cc"], deps = [ ":hlo", + ":hlo_element_type_converter", ":hlo_evaluator", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:reference_util", @@ -278,13 +285,14 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/service:hlo_element_type_converter", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep 
"//tensorflow/core:lib", "//tensorflow/core:test", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings:str_format", ], ) @@ -512,6 +520,7 @@ tf_cc_test( "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "@com_google_absl//absl/container:flat_hash_map", ], ) @@ -674,6 +683,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//third_party/eigen3", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -692,6 +702,7 @@ cc_library( ":compiler", ":computation_layout", ":device_memory_allocator", + ":dynamic_dimension_inference", ":executable", ":execution_tracker", ":hlo", @@ -999,6 +1010,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -1010,6 +1022,7 @@ cc_library( srcs = ["name_uniquer.cc"], hdrs = ["name_uniquer.h"], deps = [ + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/core:lib", "@com_google_absl//absl/container:flat_hash_map", @@ -1049,7 +1062,6 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", - "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -1087,7 +1099,6 @@ cc_library( ":buffer_value_containers", ":heap_simulator", ":hlo", - ":hlo_memory_scheduler", ":hlo_proto", ":logical_buffer", ":tuple_points_to_analysis", @@ -1132,6 +1143,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", 
"//tensorflow/core:test", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", ], ) @@ -1226,7 +1238,6 @@ cc_library( deps = [ ":hlo", ":hlo_proto", - "//tensorflow/compiler/xla:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", ], @@ -1410,6 +1421,7 @@ cc_library( "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", ], ) @@ -1493,7 +1505,6 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", - "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", ], @@ -1574,6 +1585,9 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", @@ -1589,7 +1603,10 @@ tf_cc_test( ":hlo", ":hlo_casting_utils", ":hlo_matchers", + ":hlo_parser", ":hlo_pass", + ":pattern_matcher", + ":pattern_matcher_gmock", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", @@ -1690,9 +1707,9 @@ tf_cc_test( ) cc_library( - name = "convolution_feature_group_converter", - srcs = ["convolution_feature_group_converter.cc"], - hdrs = ["convolution_feature_group_converter.h"], + name = "convolution_group_converter", + srcs = ["convolution_group_converter.cc"], + hdrs = ["convolution_group_converter.h"], deps = [ ":hlo", ":hlo_pass", @@ -1710,11 +1727,11 @@ cc_library( ) tf_cc_test( - name = "convolution_feature_group_converter_test", + name = 
"convolution_group_converter_test", size = "small", - srcs = ["convolution_feature_group_converter_test.cc"], + srcs = ["convolution_group_converter_test.cc"], deps = [ - ":convolution_feature_group_converter", + ":convolution_group_converter", ":hlo", ":hlo_matchers", ":hlo_parser", @@ -1777,6 +1794,7 @@ tf_cc_test( ":hlo_cse", ":hlo_dce", ":hlo_matchers", + ":hlo_parser", ":hlo_pass", ":hlo_pass_pipeline", ":tuple_simplifier", @@ -1855,8 +1873,9 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", ], ) @@ -1905,6 +1924,82 @@ cc_library( ], ) +cc_library( + name = "dynamic_dimension_inference", + srcs = ["dynamic_dimension_inference.cc"], + hdrs = ["dynamic_dimension_inference.h"], + deps = [ + ":hlo", + ":while_util", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "dynamic_padder", + srcs = ["dynamic_padder.cc"], + hdrs = ["dynamic_padder.h"], + deps = [ + ":dynamic_dimension_inference", + ":hlo_dce", + ":hlo_pass", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + ], +) + +tf_cc_test( + name = "dynamic_padder_test", + srcs = ["dynamic_padder_test.cc"], + deps = [ + ":dynamic_padder", + "//tensorflow/compiler/xla:debug_options_flags", + "//tensorflow/compiler/xla:literal", + 
"//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_runner", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:test", + ], +) + +tf_cc_test( + name = "dynamic_dimension_inference_test", + srcs = ["dynamic_dimension_inference_test.cc"], + deps = [ + ":dynamic_dimension_inference", + "//tensorflow/compiler/xla:debug_options_flags", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_runner", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:test", + ], +) + tf_cc_test( name = "reshape_mover_test", srcs = ["reshape_mover_test.cc"], @@ -1971,7 +2066,6 @@ cc_library( "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", @@ -2012,6 +2106,7 @@ tf_cc_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", @@ -2062,13 +2157,16 @@ tf_cc_test( srcs = ["hlo_computation_test.cc"], deps = [ ":hlo", - ":hlo_matchers", + ":pattern_matcher", + 
":pattern_matcher_gmock", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", ], ) @@ -2241,6 +2339,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", ], @@ -2501,6 +2600,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_set", ], ) @@ -2545,6 +2645,7 @@ tf_cc_test( srcs = ["hlo_verifier_test.cc"], deps = [ ":hlo", + ":hlo_module_config", ":hlo_parser", ":hlo_verifier", ":layout_assignment", @@ -2552,6 +2653,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla:xla_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", @@ -2656,7 +2758,6 @@ tf_cc_test( ":algebraic_simplifier", ":computation_layout", ":hlo", - ":hlo_matchers", ":layout_assignment", ":pattern_matcher", ":pattern_matcher_gmock", @@ -2670,6 +2771,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", "@com_google_absl//absl/types:span", @@ -2922,13 +3024,11 @@ cc_library( srcs = ["hlo_get_dimension_size_rewriter.cc"], hdrs = ["hlo_get_dimension_size_rewriter.h"], deps = [ + ":dynamic_dimension_inference", ":hlo", 
":hlo_pass", ":shape_inference", - "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", @@ -3122,6 +3222,7 @@ cc_library( name = "hlo_graph_dumper", srcs = [ "hlo_graph_dumper.cc", + "hlo_graph_html_renderer.cc", ], hdrs = ["hlo_graph_dumper.h"], deps = [ @@ -3129,6 +3230,7 @@ cc_library( ":hlo_casting_utils", ":hlo_execution_profile", ":hlo_tfgraph_builder", + ":pattern_matcher", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", @@ -3137,6 +3239,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:regexp_internal", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:optional", @@ -3297,7 +3400,6 @@ cc_library( ":hlo_pass_pipeline", "//tensorflow/compiler/xla:shape_util", "//tensorflow/core:lib", - "@com_google_absl//absl/container:flat_hash_map", ], ) @@ -3354,10 +3456,39 @@ cc_library( ":hlo_profile_printer_data", ":human_readable_profile_builder", "//tensorflow/compiler/xla:types", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings", ], ) +cc_library( + name = "sort_simplifier", + srcs = ["sort_simplifier.cc"], + hdrs = ["sort_simplifier.h"], + deps = [ + ":hlo", + ":hlo_pass", + "//tensorflow/compiler/xla:statusor", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + ], +) + +tf_cc_test( + name = "sort_simplifier_test", + srcs = ["sort_simplifier_test.cc"], + deps = [ + ":hlo_matchers", + ":hlo_parser", + ":pattern_matcher", + ":pattern_matcher_gmock", + ":sort_simplifier", + "//tensorflow/compiler/xla:test", + 
"//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:test", + ], +) + cc_library( name = "tuple_util", srcs = ["tuple_util.cc"], @@ -3456,7 +3587,6 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", ], ) @@ -3525,14 +3655,16 @@ cc_library( tf_cc_test( name = "indexed_array_analysis_test", srcs = ["indexed_array_analysis_test.cc"], + extra_copts = ["-Wno-string-plus-int"], deps = [ ":hlo_matchers", + ":hlo_parser", ":indexed_array_analysis", "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:test", + "@com_google_absl//absl/strings", ], ) @@ -3554,6 +3686,7 @@ cc_library( "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:variant", ], ) @@ -3582,7 +3715,6 @@ cc_library( srcs = ["hlo_lexer.cc"], hdrs = [ "hlo_lexer.h", - "hlo_token.h", ], deps = [ "//tensorflow/compiler/xla:shape_util", @@ -3618,6 +3750,47 @@ cc_library( ], ) +cc_library( + name = "optimize_input_output_buffer_alias", + srcs = ["optimize_input_output_buffer_alias.cc"], + hdrs = ["optimize_input_output_buffer_alias.h"], + deps = [ + ":hlo", + ":hlo_pass", + "//tensorflow/compiler/xla:shape_tree", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings:str_format", + ], +) + +tf_cc_test( + name = "optimize_input_output_buffer_alias_test", + srcs = ["optimize_input_output_buffer_alias_test.cc"], 
+ deps = [ + ":optimize_input_output_buffer_alias", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings:str_format", + ], +) + cc_library( name = "ar_crs_combiner", srcs = ["ar_crs_combiner.cc"], @@ -3627,6 +3800,7 @@ cc_library( ":pattern_matcher", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", @@ -3638,6 +3812,38 @@ cc_library( ], ) +cc_library( + name = "dynamic_index_splitter", + srcs = ["dynamic_index_splitter.cc"], + hdrs = ["dynamic_index_splitter.h"], + deps = [ + ":hlo_casting_utils", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_pass", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/strings", + ], +) + +tf_cc_test( + name = "dynamic_index_splitter_test", + srcs = ["dynamic_index_splitter_test.cc"], + deps = [ + ":dynamic_index_splitter", + ":hlo", + ":hlo_matchers", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + tf_cc_test( name = 
"ar_crs_combiner_test", srcs = ["ar_crs_combiner_test.cc"], diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index a348bcf0a232994a046df51563a9167faac08190..cd06cfcdd38d56a43def8a531fb7f018b22ed888 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -16,6 +16,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" #include +#include +#include #include #include #include @@ -24,6 +26,9 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "absl/types/optional.h" @@ -31,6 +36,7 @@ limitations under the License. #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -40,12 +46,14 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_query.h" #include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/bits.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" @@ -68,6 +76,45 @@ bool IsAll(const HloInstruction* op, int8 value) { } } +// Checks whether `op` is a floating-point constant or broadcast of a constant +// of the form +/- 2^k for some integer k positive, negative, or zero. Such +// values are interesting because multiplying by a power of 2 just moves the +// exponent. +bool IsAllFpConstantPowerOf2(const HloInstruction* op) { + // Unwrap the broadcast if necessary. + const HloInstruction* c; + if (!Match(op, m::ConstantEffectiveScalar(&c)) && + !Match(op, m::Broadcast(m::Constant(&c).WithShape( + m::Shape().IsEffectiveScalar())))) { + return false; + } + auto val = [&]() -> absl::optional { + switch (c->shape().element_type()) { + case BF16: + return static_cast(c->literal().GetFirstElement()); + case F16: + return static_cast(c->literal().GetFirstElement()); + case F32: + return c->literal().GetFirstElement(); + case F64: + return c->literal().GetFirstElement(); + default: + // Cowardly refuse to consider complex types. + return absl::nullopt; + } + }(); + if (!val) { + return false; + } + + int exp; + double mantissa = std::frexp(*val, &exp); + // frexp returns a value in the range (-1; -0.5] U [0.5, 1). A return value + // of +/-0.5 therefore indicates that the floating point value is a power of + // 2. 
+ return mantissa == 0.5 || mantissa == -0.5; +} + // Returns whether the given transpose produces a result which is bit-wise // identical to its operand and thus may be replaced with a bitcast. bool TransposeIsBitcast(const HloInstruction* transpose) { @@ -77,23 +124,37 @@ bool TransposeIsBitcast(const HloInstruction* transpose) { transpose->dimensions()); } -// Returns true if the given reshape/copy produces a result which is bit-wise -// identical to its operand and thus may be replaced with a bitcast. -// -// This function is conservative -- even if this function returns false, the -// reshape may still be a bitcast. For example, a reshape from [28x28] to [784]. -bool ReshapeOrCopyIsBitcast( - const HloInstruction* instr, - const AlgebraicSimplifierOptions::ValidBitcastCallback& - valid_bitcast_callback) { +// Recursive helper for method below. +HloInstruction* BitcastingOperandOfReshapeOrCopyChainHelper( + HloInstruction* instr, HloInstruction* operand, + const AlgebraicSimplifierOptions& options) { + // Can't replace chain of copies and reshapes with bitcasts if the compiler + // used a memory layout which isn't compatible. + if (options.ReshapeIsBitcast(operand->shape(), instr->shape())) { + return operand; + } + + // If the operand is a copy or reshape try to see if the operand's operand + // would produce a bitcast with initial instruction. + if (HloOpcode::kReshape == operand->opcode() || + HloOpcode::kCopy == operand->opcode()) { + return BitcastingOperandOfReshapeOrCopyChainHelper( + instr, operand->mutable_operand(0), options); + } + return nullptr; +} + +// Returns an operand of a chain of reshapes and copies that is bit-wise +// identical to first reshape or copy in the chain. 
+HloInstruction* BitcastingOperandOfReshapeOrCopyChain( + HloInstruction* instr, const AlgebraicSimplifierOptions& options) { + if (!options.is_layout_sensitive()) { + return nullptr; + } CHECK(HloOpcode::kReshape == instr->opcode() || HloOpcode::kCopy == instr->opcode()); - - const HloInstruction* operand = instr->operand(0); - // Can't insert bitcasts if the compiler used a memory layout which isn't - // compatible. - return ShapeUtil::ReshapeIsBitcast(operand->shape(), instr->shape()) && - valid_bitcast_callback(operand->shape(), instr->shape()); + return BitcastingOperandOfReshapeOrCopyChainHelper( + instr, instr->mutable_operand(0), options); } bool IsUnstridedSlice(const HloInstruction* hlo) { @@ -160,6 +221,8 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { Status HandlePower(HloInstruction* power) override; + Status HandleRemainder(HloInstruction* remainder) override; + Status HandleReshape(HloInstruction* reshape) override; Status HandleReduce(HloInstruction* reduce) override; @@ -199,9 +262,16 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { // more fusion than leaving the nodes as Dot operations. StatusOr HandleDotStrengthReduction(HloInstruction* dot); + // Removes dimension dim from hlo. + HloInstruction* StripDim(HloInstruction* hlo, int64 dim) { + CHECK_EQ(hlo->shape().dimensions(dim), 1); + return computation_->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::DeleteDimension(dim, hlo->shape()), hlo)); + } + // Reshapes an instruction to rank 1 if it is not already rank 1. HloInstruction* Flatten(HloInstruction* hlo) { - if (ShapeUtil::Rank(hlo->shape()) == 1) { + if (hlo->shape().rank() == 1) { return hlo; } return computation_->AddInstruction(HloInstruction::CreateReshape( @@ -221,8 +291,11 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { shape, hlo, zero, {dim}, AddReduce_computation)); } - // Convenience method for replacing an instruction with a bitcast. 
- void ReplaceWithBitcast(HloInstruction* instruction); + // Convenience method for replacing an instruction with a bitcast. If operand + // is not null, then the bitcast will use the specified operand instead of the + // operand of the instruction. + void ReplaceWithBitcast(HloInstruction* instruction, + HloInstruction* operand = nullptr); // Replace old instruction with new instruction if old and new instructions // have the same shape. Updates uses and root instruction. Returns whether a @@ -351,17 +424,19 @@ bool AlgebraicSimplifierVisitor::SameShape(const HloInstruction* lhs, } } -void AlgebraicSimplifierVisitor::ReplaceWithBitcast( - HloInstruction* instruction) { +void AlgebraicSimplifierVisitor::ReplaceWithBitcast(HloInstruction* instruction, + HloInstruction* operand) { CHECK_EQ(1, instruction->operand_count()); + if (operand == nullptr) { + operand = instruction->mutable_operand(0); + } CHECK_EQ(ShapeUtil::ElementsIn(instruction->shape()), - ShapeUtil::ElementsIn(instruction->operand(0)->shape())); + ShapeUtil::ElementsIn(operand->shape())); CHECK_EQ(ShapeUtil::ByteSizeOf(instruction->shape()), - ShapeUtil::ByteSizeOf(instruction->operand(0)->shape())); + ShapeUtil::ByteSizeOf(operand->shape())); - auto bitcast = computation_->AddInstruction( - HloInstruction::CreateUnary(instruction->shape(), HloOpcode::kBitcast, - instruction->mutable_operand(0))); + auto bitcast = computation_->AddInstruction(HloInstruction::CreateUnary( + instruction->shape(), HloOpcode::kBitcast, operand)); TF_CHECK_OK(ReplaceInstruction(instruction, bitcast)); } @@ -415,6 +490,40 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) { sum_of_constants)); } + // A*C + B*C => (A+B)*C + // + // - If A, B, and C are integers, do this unconditionally. Proof of + // correctness: https://rise4fun.com/Alive/u9X. 
+ // + // - If A, B, and C are floating point, do this if C is a scalar constant or + // broadcast of scalar constant and is equal to +/- 2^k for some (possibly + // negative) integer k. + // + // Multiplying by a power of 2 just moves the exponent, so our answer is + // exact modulo rounding of intermediate results so long as + // + // - none of the three products has an exponent which underflows (so the + // result is 0 or denormal), and + // - none of the three products overflows to inf. + // + // Proof: See algebraic_simplifier_proof_distributive_property.py. + // + // We deem these differences in rounding, underflow, and overflow + // acceptable in the ML context. + HloInstruction *b, *c; + if (((Match(lhs, m::Multiply(m::Op(&a), m::Op(&c))) && + Match(rhs, m::MultiplyAnyOrder(m::Op().Is(c), m::Op(&b)))) || + (Match(lhs, m::Multiply(m::Op(&c), m::Op(&a))) && + Match(rhs, m::MultiplyAnyOrder(m::Op().Is(c), m::Op(&b))))) && + (ShapeUtil::ElementIsIntegral(add->shape()) || + IsAllFpConstantPowerOf2(c))) { + return ReplaceWithNewInstruction( + add, HloInstruction::CreateBinary( + add->shape(), HloOpcode::kMultiply, + computation_->AddInstruction(HloInstruction::CreateBinary( + add->shape(), HloOpcode::kAdd, a, b)), + c)); + } return Status::OK(); } @@ -488,9 +597,9 @@ Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy) { return Status::OK(); } - if (options_.is_layout_sensitive() && - ReshapeOrCopyIsBitcast(copy, options_.valid_bitcast_callback())) { - ReplaceWithBitcast(copy); + if (HloInstruction* bitcast_operand = + BitcastingOperandOfReshapeOrCopyChain(copy, options_)) { + ReplaceWithBitcast(copy, bitcast_operand); } return Status::OK(); @@ -603,7 +712,7 @@ Status AlgebraicSimplifierVisitor::HandleConcatenate( return Status::OK(); } PaddingConfig padding_config; - for (int64 dim = 0; dim < ShapeUtil::Rank(operands[0]->shape()); ++dim) { + for (int64 dim = 0; dim < operands[0]->shape().rank(); ++dim) { auto padding_config_dim = 
padding_config.add_dimensions(); padding_config_dim->set_edge_padding_high(0); padding_config_dim->set_edge_padding_low(0); @@ -631,7 +740,7 @@ Status AlgebraicSimplifierVisitor::HandleConcatenate( static HloInstruction* BuildTupleConstant(HloComputation* computation, const LiteralSlice& literal) { - if (ShapeUtil::IsTuple(literal.shape())) { + if (literal.shape().IsTuple()) { std::vector elems; elems.reserve(ShapeUtil::TupleElementCount(literal.shape())); for (int i = 0; i < ShapeUtil::TupleElementCount(literal.shape()); ++i) { @@ -648,7 +757,7 @@ static HloInstruction* BuildTupleConstant(HloComputation* computation, Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) { // Tuple constants aren't directly supported by any backend. Expand them into // explicit Tuple instructions. - if (ShapeUtil::IsTuple(constant->shape())) { + if (constant->shape().IsTuple()) { return ReplaceInstruction( constant, BuildTupleConstant(computation_, constant->literal())); } @@ -670,7 +779,7 @@ Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) { } // If a literal is an increasing sequence from zero, replace it with an iota. 
- if (ShapeUtil::Rank(constant->shape()) == 1 && + if (constant->shape().rank() == 1 && ShapeUtil::ElementsIn(constant->shape()) > 1 && constant->literal().IsR1Iota()) { return ReplaceWithNewInstruction( @@ -707,6 +816,79 @@ Status InvertConstant(const HloInstruction& constant, Literal* result) { return T{1.0} / constant.literal().Get(indices); }); } + +template +std::unique_ptr TryDivideToShift(HloInstruction* divide, + HloComputation* computation) { + HloInstruction *a, *b, *c; + CHECK(Match(divide, m::Divide(m::Op(&a), m::Op(&b)))); + + if (ShapeUtil::ElementIsIntegral(divide->shape()) && + !Match(b, m::ConstantEffectiveScalar(&c)) && + !Match(b, m::Broadcast(m::ConstantEffectiveScalar(&c)))) { + return nullptr; + } + + if (ShapeUtil::ElementIsSigned(divide->shape())) { + int64 b_value = c->literal().GetFirstElement(); + if (b_value > 0 && IsPowerOfTwo(static_cast(b_value))) { + // Handle negative dividends by negating the result of the division. + HloInstruction* zero_like_a = BroadcastZeros( + computation, a->shape().element_type(), a->shape().dimensions()); + + auto* dividend_is_negative = + computation->AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::ChangeElementType(a->shape(), PRED), HloOpcode::kLt, a, + zero_like_a)); + + auto* negated_dividend = computation->AddInstruction( + HloInstruction::CreateUnary(a->shape(), HloOpcode::kNegate, a)); + + auto* abs_dividend = + computation->AddInstruction(HloInstruction::CreateTernary( + a->shape(), HloOpcode::kSelect, dividend_is_negative, + negated_dividend, a)); + + int log2_abs_b_value = tensorflow::Log2Floor64(b_value); + + auto* shift_amount = + computation->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(log2_abs_b_value))); + if (!ShapeUtil::IsScalar(b->shape())) { + shift_amount = computation->AddInstruction( + HloInstruction::CreateBroadcast(b->shape(), shift_amount, {})); + } + + auto* quotient = computation->AddInstruction(HloInstruction::CreateBinary( + divide->shape(), 
HloOpcode::kShiftRightLogical, abs_dividend, + shift_amount)); + + auto* neqated_quotient = + computation->AddInstruction(HloInstruction::CreateUnary( + quotient->shape(), HloOpcode::kNegate, quotient)); + + return HloInstruction::CreateTernary(divide->shape(), HloOpcode::kSelect, + dividend_is_negative, + neqated_quotient, quotient); + } + } else { + uint64 b_value = c->literal().GetFirstElement(); + if (IsPowerOfTwo(b_value)) { + int log2_abs_b_value = tensorflow::Log2Floor64(b_value); + HloInstruction* shift_amount = + computation->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(log2_abs_b_value))); + if (!ShapeUtil::IsScalar(b->shape())) { + shift_amount = computation->AddInstruction( + HloInstruction::CreateBroadcast(b->shape(), shift_amount, {})); + } + return HloInstruction::CreateBinary( + divide->shape(), HloOpcode::kShiftRightLogical, a, shift_amount); + } + } + + return nullptr; +} } // namespace Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { @@ -719,6 +901,60 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { return Status::OK(); } + // A / B => A >> log2(B) if B is a power of 2. 
+ switch (divide->shape().element_type()) { + case S8: + if (std::unique_ptr shift = + TryDivideToShift(divide, computation_)) { + return ReplaceWithNewInstruction(divide, std::move(shift)); + } + break; + case S16: + if (std::unique_ptr shift = + TryDivideToShift(divide, computation_)) { + return ReplaceWithNewInstruction(divide, std::move(shift)); + } + break; + case S32: + if (std::unique_ptr shift = + TryDivideToShift(divide, computation_)) { + return ReplaceWithNewInstruction(divide, std::move(shift)); + } + break; + case S64: + if (std::unique_ptr shift = + TryDivideToShift(divide, computation_)) { + return ReplaceWithNewInstruction(divide, std::move(shift)); + } + break; + case U8: + if (std::unique_ptr shift = + TryDivideToShift(divide, computation_)) { + return ReplaceWithNewInstruction(divide, std::move(shift)); + } + break; + case U16: + if (std::unique_ptr shift = + TryDivideToShift(divide, computation_)) { + return ReplaceWithNewInstruction(divide, std::move(shift)); + } + break; + case U32: + if (std::unique_ptr shift = + TryDivideToShift(divide, computation_)) { + return ReplaceWithNewInstruction(divide, std::move(shift)); + } + break; + case U64: + if (std::unique_ptr shift = + TryDivideToShift(divide, computation_)) { + return ReplaceWithNewInstruction(divide, std::move(shift)); + } + break; + default: + break; + } + // exp(A)/exp(B) => exp(A-B) if (Match(divide, m::Divide(m::Exp(m::Op(&a)), m::Exp(m::Op(&b))) .WithShape(m::Shape(&shape)))) { @@ -786,6 +1022,9 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { case C64: TF_RETURN_IF_ERROR(InvertConstant(*b, &new_literal)); break; + case C128: + TF_RETURN_IF_ERROR(InvertConstant(*b, &new_literal)); + break; default: return Status::OK(); } @@ -834,21 +1073,51 @@ StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( HloInstruction* dot) { HloInstruction *lhs, *rhs; CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs)))); - int64 lhs_collapsing_dim = - 
dot->dot_dimension_numbers().lhs_contracting_dimensions(0); + + const auto kept_dim = [](int64 rank, int64 contracting_dimension, + absl::Span batch_dimensions) -> int64 { + for (int64 i = 0; i < rank; ++i) { + if (i != contracting_dimension && + !absl::c_linear_search(batch_dimensions, i)) { + return i; + } + } + return -1; + }; + + const int64 dot_rank = dot->shape().rank(); + const int64 rhs_rank = rhs->shape().rank(); + const int64 lhs_rank = lhs->shape().rank(); + const auto& dnums = dot->dot_dimension_numbers(); + if (dnums.rhs_contracting_dimensions_size() > 1) { + return false; + } + if (dot_rank > 2 && (lhs_rank != rhs_rank || lhs_rank != dot_rank)) { + return false; + } + int64 lhs_collapsing_dim = dnums.lhs_contracting_dimensions(0); + int64 lhs_kept_dim = kept_dim(lhs_rank, lhs_collapsing_dim, + AsInt64Slice(dnums.lhs_batch_dimensions())); + // If there is no non-contracting dimension in rank 2, do not strength reduce. + if (lhs_kept_dim == -1 && lhs_rank > 1) { + return false; + } if (lhs->IsRank2Transpose()) { lhs = lhs->mutable_operand(0); - lhs_collapsing_dim = 1 - lhs_collapsing_dim; + std::swap(lhs_collapsing_dim, lhs_kept_dim); } - const int64 lhs_kept_dim = 1 - lhs_collapsing_dim; - int64 rhs_collapsing_dim = - dot->dot_dimension_numbers().rhs_contracting_dimensions(0); + int64 rhs_collapsing_dim = dnums.rhs_contracting_dimensions(0); + int64 rhs_kept_dim = kept_dim(rhs_rank, rhs_collapsing_dim, + AsInt64Slice(dnums.rhs_batch_dimensions())); + // If there is no non-contracting dimension in rank 2, do not strength reduce. 
+ if (rhs_kept_dim == -1 && rhs_rank > 1) { + return false; + } if (rhs->IsRank2Transpose()) { rhs = rhs->mutable_operand(0); - rhs_collapsing_dim = 1 - rhs_collapsing_dim; + std::swap(rhs_collapsing_dim, rhs_kept_dim); } - const int64 rhs_kept_dim = 1 - rhs_collapsing_dim; auto as_type = [&](HloInstruction* hlo, const PrimitiveType element_type) { if (hlo->shape().element_type() == element_type) { @@ -871,10 +1140,15 @@ StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( return AddReduce(as_type(hlo, F32), dim); }; + auto broadcast = [&](HloInstruction* hlo, const Shape& shape, + absl::Span dims) { + return computation_->AddInstruction( + HloInstruction::CreateBroadcast(shape, hlo, dims)); + }; + auto broadcast_to_dim = [&](HloInstruction* hlo, const Shape& shape, int64 dim) { - return computation_->AddInstruction( - HloInstruction::CreateBroadcast(shape, hlo, {dim})); + return broadcast(hlo, shape, {dim}); }; auto multiply = [&](HloInstruction* local_lhs, HloInstruction* local_rhs) { @@ -885,11 +1159,9 @@ StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( // Strength reduce dot(a[K] , b[K]) = // reshape(result.shape, // reduce_sum(multiply(a, b), {0})) - if (ShapeUtil::Rank(rhs->shape()) == 1 && - ShapeUtil::Rank(lhs->shape()) == 1) { - TF_RETURN_IF_ERROR( - ReplaceInstruction(dot, reshape_if_necessary(add_reduce_in_f32( - multiply(Flatten(lhs), Flatten(rhs)), 0)))); + if (rhs_rank == 1 && lhs_rank == 1) { + TF_RETURN_IF_ERROR(ReplaceInstruction( + dot, reshape_if_necessary(add_reduce_in_f32(multiply(lhs, rhs), 0)))); return true; } @@ -903,8 +1175,7 @@ StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( // Simplify outer product into multiply with implicit broadcasting. 
// // A dot(a[M, 1], b[1, N]) = multiply(a [M,1], b [1, N]) - if (ShapeUtil::Rank(rhs->shape()) == 2 && - rhs->shape().dimensions(rhs_collapsing_dim) == 1) { + if (rhs_rank == 2 && rhs->shape().dimensions(rhs_collapsing_dim) == 1) { TF_RETURN_IF_ERROR(ReplaceInstruction( dot, multiply(broadcast_to_dim(Flatten(lhs), dot->shape(), 0), broadcast_to_dim(Flatten(rhs), dot->shape(), 1)))); @@ -918,10 +1189,9 @@ StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( // {0}) // ) // ) - if (ShapeUtil::Rank(lhs->shape()) == 1 || - (ShapeUtil::Rank(lhs->shape()) == 2 && - lhs->shape().dimensions(lhs_kept_dim) == 1)) { - if (ShapeUtil::Rank(rhs->shape()) == 1) { + if (lhs_rank == 1 || + (lhs_rank == 2 && lhs->shape().dimensions(lhs_kept_dim) == 1)) { + if (rhs->shape().rank() == 1) { TF_RETURN_IF_ERROR( ReplaceInstruction(dot, reshape_if_necessary(add_reduce_in_f32( multiply(Flatten(lhs), rhs), 0)))); @@ -940,9 +1210,8 @@ StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( // reshape(result.shape, // reduce_sum(multiply(a, broadcast(reshape([K],b), {1})), {0}) // ) - if (ShapeUtil::Rank(rhs->shape()) == 1 || - (ShapeUtil::Rank(rhs->shape()) == 2 && - rhs->shape().dimensions(rhs_kept_dim) == 1)) { + if (rhs_rank == 1 || + (rhs_rank == 2 && rhs->shape().dimensions(rhs_kept_dim) == 1)) { TF_RETURN_IF_ERROR(ReplaceInstruction( dot, reshape_if_necessary(add_reduce_in_f32( multiply(lhs, broadcast_to_dim(Flatten(rhs), lhs->shape(), @@ -950,6 +1219,97 @@ StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( lhs_collapsing_dim)))); return true; } + + // Only consider kDot with batch dimension. + if (dot_rank <= 2) { + return false; + } + + CHECK_EQ(rhs_rank, lhs_rank); + CHECK_EQ(dot_rank, lhs_rank); + // If there is more than one non-contracting dimension or the batch dimensions + // are not equal, bail out since transposes may be required to do a strength + // reduction. 
+ if (dnums.rhs_batch_dimensions_size() + 2 != dot_rank || + !absl::c_equal(dnums.lhs_batch_dimensions(), + dnums.rhs_batch_dimensions())) { + return false; + } + + auto broadcast_dims = [](int64 rank, int64 non_broadcast_dim) { + absl::InlinedVector dims; + for (int64 i = 0; i < rank; ++i) { + if (i != non_broadcast_dim) { + dims.push_back(i); + } + } + return dims; + }; + + // If the contracting dimension is 1, remove the degenerate dimensions from the + // lhs and rhs, broadcast each to the result shape and multiply. + if (lhs->shape().dimensions(lhs_collapsing_dim) == 1 && + (rhs_kept_dim == rhs_rank - 1 || + (rhs_collapsing_dim == rhs_rank - 1 && rhs_kept_dim == rhs_rank - 2))) { + CHECK_EQ(rhs->shape().dimensions(rhs_collapsing_dim), 1); + const int64 lhs_kept_dim_in_output = + lhs_kept_dim > lhs_collapsing_dim ? (lhs_kept_dim - 1) : lhs_kept_dim; + absl::InlinedVector lhs_broadcast_dims; + for (const int64 dim : dnums.lhs_batch_dimensions()) { + lhs_broadcast_dims.push_back(dim > lhs_collapsing_dim ? (dim - 1) : dim); + } + absl::InlinedVector rhs_broadcast_dims = lhs_broadcast_dims; + lhs_broadcast_dims.push_back(lhs_kept_dim_in_output); + absl::c_sort(lhs_broadcast_dims); + rhs_broadcast_dims.push_back(dot_rank - 1); + absl::c_sort(rhs_broadcast_dims); + TF_RETURN_IF_ERROR(ReplaceInstruction( + dot, reshape_if_necessary( + multiply(broadcast(StripDim(lhs, lhs_collapsing_dim), + dot->shape(), lhs_broadcast_dims), + broadcast(StripDim(rhs, rhs_collapsing_dim), + dot->shape(), rhs_broadcast_dims))))); + return true; + } + + // If the lhs and rhs non-contracting dimensions are both one, strip each one, + // multiply and then reduce the collapsing dimension + if (lhs->shape().dimensions(lhs_kept_dim) == 1 && + rhs->shape().dimensions(rhs_kept_dim) == 1 && + lhs_kept_dim == rhs_kept_dim) { + auto new_lhs = StripDim(lhs, lhs_kept_dim); + auto new_rhs = StripDim(rhs, rhs_kept_dim); + const int64 reduce_dim = rhs_kept_dim < rhs_collapsing_dim + ? 
(rhs_collapsing_dim - 1) + : rhs_collapsing_dim; + TF_RETURN_IF_ERROR( + ReplaceInstruction(dot, reshape_if_necessary(add_reduce_in_f32( + multiply(new_lhs, new_rhs), reduce_dim)))); + return true; + } + + // If the lhs non-contracting dimensions is one, strip the one, broadcast to + // the rhs shape, multiply and then reduce the collapsing dimension + if (lhs->shape().dimensions(lhs_kept_dim) == 1) { + auto new_lhs = broadcast(StripDim(lhs, lhs_kept_dim), rhs->shape(), + broadcast_dims(rhs_rank, rhs_kept_dim)); + TF_RETURN_IF_ERROR(ReplaceInstruction( + dot, reshape_if_necessary(add_reduce_in_f32(multiply(new_lhs, rhs), + rhs_collapsing_dim)))); + return true; + } + + // If the rhs non-contracting dimensions is one, strip the one, broadcast to + // the lhs shape, multiply and then reduce the collapsing dimension + if (rhs->shape().dimensions(rhs_kept_dim) == 1) { + auto new_rhs = broadcast(StripDim(rhs, rhs_kept_dim), lhs->shape(), + broadcast_dims(lhs_rank, lhs_kept_dim)); + TF_RETURN_IF_ERROR(ReplaceInstruction( + dot, reshape_if_necessary(add_reduce_in_f32(multiply(lhs, new_rhs), + lhs_collapsing_dim)))); + return true; + } + return false; } @@ -1168,6 +1528,9 @@ StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfGather( // => output dimensions: DS ({M x N}, {0, start}, {M, 1}) => {M x 1}. bool lhs_is_dynamic_slice = lhs->opcode() == HloOpcode::kDynamicSlice; + HloDynamicSliceInstruction* dynamic_slice = + lhs_is_dynamic_slice ? Cast(lhs) + : Cast(rhs); // ctA: HloInstruction* left_operand = @@ -1185,8 +1548,6 @@ StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfGather( HloInstruction::CreateDot(memoized_shape, left_operand, right_operand, dnums, dot->precision_config())); // Get pair {start, 0} or {0, start}. - HloInstruction* original_start_indices = - lhs_is_dynamic_slice ? lhs->mutable_operand(1) : rhs->mutable_operand(1); // Position of start: int index_of_non_zero_start = lhs_is_dynamic_slice ? 
1 - lhs_contracting_dimension @@ -1195,23 +1556,19 @@ StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfGather( int index_of_zero_start = 1 - index_of_non_zero_start; // Slice out start and 0 components and reorder if necessary. - auto indices_type = original_start_indices->shape().element_type(); + auto indices_type = dynamic_slice->operand(1)->shape().element_type(); Shape s_shape = ShapeUtil::MakeShape(indices_type, {1}); Shape d_shape = ShapeUtil::MakeShape(indices_type, {2}); HloInstruction* non_zero_start = - computation_->AddInstruction(HloInstruction::CreateSlice( - s_shape, original_start_indices, {index_of_non_zero_start}, - {index_of_non_zero_start + 1}, {1})); + dynamic_slice->mutable_operand(1 + index_of_non_zero_start); HloInstruction* zero_start = - computation_->AddInstruction(HloInstruction::CreateSlice( - s_shape, original_start_indices, {index_of_zero_start}, - {index_of_zero_start + 1}, {1})); - HloInstruction* new_start_indices = - lhs_is_dynamic_slice - ? computation_->AddInstruction(HloInstruction::CreateConcatenate( - d_shape, {non_zero_start, zero_start}, 0)) - : computation_->AddInstruction(HloInstruction::CreateConcatenate( - d_shape, {zero_start, non_zero_start}, 0)); + dynamic_slice->mutable_operand(1 + index_of_zero_start); + std::vector new_start_indices; + if (lhs_is_dynamic_slice) { + new_start_indices = {non_zero_start, zero_start}; + } else { + new_start_indices = {zero_start, non_zero_start}; + } // Build DynamicSlice(ctA x ctB). const int new_slice_m = lhs_is_dynamic_slice ? 1 : m; @@ -1228,25 +1585,31 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { HloInstruction *lhs, *rhs; CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs)))); - // Only optimize F32 or BF16 dot operations where the dot, rhs and lhs are - // rank 2 or below. 
- if ((dot->shape().element_type() != F32 && - dot->shape().element_type() != BF16) || - ShapeUtil::Rank(lhs->shape()) > 2 || ShapeUtil::Rank(rhs->shape()) > 2 || - ShapeUtil::Rank(dot->shape()) > 2) { - return Status::OK(); - } - // Replace a zero element dot with a broadcast of the constant 0. if (ShapeUtil::IsZeroElementArray(dot->shape()) || ShapeUtil::IsZeroElementArray(lhs->shape()) || ShapeUtil::IsZeroElementArray(rhs->shape())) { - auto zero = computation_->AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + auto zero = computation_->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(dot->shape().element_type()))); return ReplaceWithNewInstruction( dot, HloInstruction::CreateBroadcast(dot->shape(), zero, {})); } + // Only optimize F32 or BF16 dot operations where the dot, rhs and lhs are + // rank 2 or below. + if (dot->shape().element_type() != F32 && + dot->shape().element_type() != BF16) { + return Status::OK(); + } + if (lhs->shape().rank() > 2 || rhs->shape().rank() > 2 || + dot->shape().rank() > 2) { + if (options_.enable_dot_strength_reduction() && + !options_.is_layout_sensitive()) { + TF_RETURN_IF_ERROR(HandleDotStrengthReduction(dot).status()); + } + return Status::OK(); + } + TF_ASSIGN_OR_RETURN(HloInstruction * dot_of_concat_optimized, OptimizeDotOfConcat(dot)); if (dot_of_concat_optimized) { @@ -1475,7 +1838,7 @@ bool OutputIsPermutationOfOperandElements(HloInstruction* instruction, case HloOpcode::kTranspose: return true; case HloOpcode::kSort: - return (!ShapeUtil::IsTuple(instruction->shape())); + return (!instruction->shape().IsTuple()); default: return false; } @@ -1521,8 +1884,7 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) { // A degenerate broadcast that has the same input and output rank can be // converted into a transpose. 
- if (ShapeUtil::Rank(broadcast->shape()) == - ShapeUtil::Rank(operand->shape()) && + if (broadcast->shape().rank() == operand->shape().rank() && ShapeUtil::ElementsIn(broadcast->shape()) == ShapeUtil::ElementsIn(operand->shape())) { VLOG(10) << "transform broadcast(X) -> transpose(X) where " @@ -1677,7 +2039,7 @@ Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) { if (HasInteriorPadding(pad->padding_config())) { PaddingConfig padding_config = pad->padding_config(); bool cleared_interior_padding = false; - for (int64 i = 0; i < ShapeUtil::Rank(pad->shape()); ++i) { + for (int64 i = 0; i < pad->shape().rank(); ++i) { if (padding_config.dimensions(i).interior_padding() > 0 && pad->operand(0)->shape().dimensions(i) == 1) { cleared_interior_padding = true; @@ -1928,14 +2290,151 @@ AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand( return changed; } +namespace { +template +std::unique_ptr TryRemainderToAnd(HloInstruction* remainder, + HloComputation* computation) { + HloInstruction *a, *b, *c; + CHECK(Match(remainder, m::Remainder(m::Op(&a), m::Op(&b)))); + + if (ShapeUtil::ElementIsIntegral(remainder->shape()) && + !Match(b, m::ConstantEffectiveScalar(&c)) && + !Match(b, m::Broadcast(m::ConstantEffectiveScalar(&c)))) { + return nullptr; + } + + if (ShapeUtil::ElementIsSigned(remainder->shape())) { + int64 b_value = c->literal().GetFirstElement(); + if (b_value > 0 && IsPowerOfTwo(static_cast(b_value))) { + // Handle negative dividends by negating the result of the division. 
+ HloInstruction* zero_like_a = BroadcastZeros( + computation, a->shape().element_type(), a->shape().dimensions()); + + auto* dividend_is_negative = + computation->AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::ChangeElementType(a->shape(), PRED), HloOpcode::kLt, a, + zero_like_a)); + + auto* negated_dividend = computation->AddInstruction( + HloInstruction::CreateUnary(a->shape(), HloOpcode::kNegate, a)); + + auto* abs_dividend = + computation->AddInstruction(HloInstruction::CreateTernary( + a->shape(), HloOpcode::kSelect, dividend_is_negative, + negated_dividend, a)); + + auto* mask_amount = + computation->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(b_value - 1))); + if (!ShapeUtil::IsScalar(b->shape())) { + mask_amount = computation->AddInstruction( + HloInstruction::CreateBroadcast(b->shape(), mask_amount, {})); + } + + auto* quotient = computation->AddInstruction(HloInstruction::CreateBinary( + remainder->shape(), HloOpcode::kAnd, abs_dividend, mask_amount)); + + auto* neqated_quotient = + computation->AddInstruction(HloInstruction::CreateUnary( + quotient->shape(), HloOpcode::kNegate, quotient)); + + return HloInstruction::CreateTernary( + remainder->shape(), HloOpcode::kSelect, dividend_is_negative, + neqated_quotient, quotient); + } + } else { + uint64 b_value = c->literal().GetFirstElement(); + if (IsPowerOfTwo(b_value)) { + HloInstruction* mask_amount = + computation->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(b_value - 1))); + if (!ShapeUtil::IsScalar(b->shape())) { + mask_amount = computation->AddInstruction( + HloInstruction::CreateBroadcast(b->shape(), mask_amount, {})); + } + return HloInstruction::CreateBinary(remainder->shape(), HloOpcode::kAnd, + a, mask_amount); + } + } + return nullptr; +} +} // namespace + +Status AlgebraicSimplifierVisitor::HandleRemainder(HloInstruction* remainder) { + HloInstruction *a, *b; + CHECK(Match(remainder, m::Remainder(m::Op(&a), m::Op(&b)))); + + // A 
% B => A & (B - 1) if B is a power of 2. + switch (remainder->shape().element_type()) { + case S8: + if (std::unique_ptr shift = + TryRemainderToAnd(remainder, computation_)) { + return ReplaceWithNewInstruction(remainder, std::move(shift)); + } + break; + case S16: + if (std::unique_ptr shift = + TryRemainderToAnd(remainder, computation_)) { + return ReplaceWithNewInstruction(remainder, std::move(shift)); + } + break; + case S32: + if (std::unique_ptr shift = + TryRemainderToAnd(remainder, computation_)) { + return ReplaceWithNewInstruction(remainder, std::move(shift)); + } + break; + case S64: + if (std::unique_ptr shift = + TryRemainderToAnd(remainder, computation_)) { + return ReplaceWithNewInstruction(remainder, std::move(shift)); + } + break; + case U8: + if (std::unique_ptr shift = + TryRemainderToAnd(remainder, computation_)) { + return ReplaceWithNewInstruction(remainder, std::move(shift)); + } + break; + case U16: + if (std::unique_ptr shift = + TryRemainderToAnd(remainder, computation_)) { + return ReplaceWithNewInstruction(remainder, std::move(shift)); + } + break; + case U32: + if (std::unique_ptr shift = + TryRemainderToAnd(remainder, computation_)) { + return ReplaceWithNewInstruction(remainder, std::move(shift)); + } + break; + case U64: + if (std::unique_ptr shift = + TryRemainderToAnd(remainder, computation_)) { + return ReplaceWithNewInstruction(remainder, std::move(shift)); + } + break; + default: + break; + } + + return Status::OK(); +} + Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) { auto operand = reshape->mutable_operand(0); // Reshape directly to empty constant if the shape contains zero-element // dimension. if (ShapeUtil::IsZeroElementArray(reshape->shape())) { + // If the instruction doesn't have a layout, use a default layout for + // the literal result. 
+ Shape reshaped_shape = reshape->shape(); + if (!LayoutUtil::HasLayout(reshaped_shape)) { + LayoutUtil::SetToDefaultLayout(&reshaped_shape); + } auto empty_constant = HloInstruction::CreateConstant( - Literal::CreateFromShape(reshape->shape())); + Literal::CreateFromShape(reshaped_shape)); return ReplaceWithNewInstruction(reshape, std::move(empty_constant)); } @@ -1952,6 +2451,7 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) { reshape, HloInstruction::CreateReshape(reshape->shape(), operand->mutable_operand(0))); } + if (operand->opcode() == HloOpcode::kRng && operand->user_count() == 1) { *operand->mutable_shape() = reshape->shape(); return ReplaceInstruction(reshape, operand); @@ -1983,12 +2483,10 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) { } // Make this a bitcast if possible. - if (options_.is_layout_sensitive() && - ReshapeOrCopyIsBitcast(reshape, options_.valid_bitcast_callback())) { - ReplaceWithBitcast(reshape); - return Status::OK(); + if (HloInstruction* bitcast_operand = + BitcastingOperandOfReshapeOrCopyChain(reshape, options_)) { + ReplaceWithBitcast(reshape, bitcast_operand); } - return Status::OK(); } @@ -1998,8 +2496,7 @@ Status AlgebraicSimplifierVisitor::HandleReverse(HloInstruction* reverse) { auto dim_is_one = [&](int64 i) -> bool { return reverse->shape().dimensions(i) == 1; }; - if (std::all_of(reverse->dimensions().begin(), reverse->dimensions().end(), - dim_is_one)) { + if (absl::c_all_of(reverse->dimensions(), dim_is_one)) { return ReplaceInstruction(reverse, reverse->mutable_operand(0)); } return Status::OK(); @@ -2064,7 +2561,7 @@ StatusOr AlgebraicSimplifierVisitor::TrySimplifyScalarSlice( if (slice->operand(0)->opcode() == HloOpcode::kConcatenate) { VLOG(10) << "Trying to simplify scalar slice of concat"; // Only do this for R1, there's no chance of this being useful otherwise. 
- if (ShapeUtil::Rank(slice->shape()) != 1) { + if (slice->shape().rank() != 1) { VLOG(10) << "Not folding, slice is not rank 1"; return false; } @@ -2114,7 +2611,7 @@ StatusOr AlgebraicSimplifierVisitor::TryToReorderSliceAndReshape( return false; } HloInstruction* new_slice_operand = reshape->mutable_operand(0); - int64 slice_rank = ShapeUtil::Rank(slice->shape()); + int64 slice_rank = slice->shape().rank(); std::vector sliced_dims; for (int64 i = 0; i < slice_rank; ++i) { if (slice->slice_starts(i) != 0 || @@ -2126,7 +2623,7 @@ StatusOr AlgebraicSimplifierVisitor::TryToReorderSliceAndReshape( if (sliced_dims.size() == 1 && sliced_dims[0] == 0 && slice->slice_starts(0) == 0) { const Shape& new_slice_shape = new_slice_operand->shape(); - const int64 rank = ShapeUtil::Rank(new_slice_shape); + const int64 rank = new_slice_shape.rank(); std::vector new_slice_starts(rank, 0); std::vector new_slice_stides(rank, 1); std::vector new_slice_limits(new_slice_shape.dimensions().begin(), @@ -2223,28 +2720,71 @@ Status AlgebraicSimplifierVisitor::HandleDynamicUpdateSlice( return Status::OK(); } -Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { - // TODO(b/112040122): Most of those optimizations can be done for multi-output - // reduces. - if (ShapeUtil::IsTuple(reduce->shape())) { - return Status::OK(); - } +Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) { + HloReduceInstruction* reduce = Cast(hlo); + bool multi_output_reduce = reduce->shape().IsTuple(); + + // For tuple reduce, we require all reduce shapes to be the same, up to the + // element types, so we can just use the first operand and the first result as a + // representative. + auto arg = reduce->inputs()[0]; + auto init_value = reduce->init_values()[0]; + const Shape& reduce_result_shape = + multi_output_reduce ? 
reduce->shape().tuple_shapes(0) : reduce->shape(); - auto arg = reduce->mutable_operand(0); - auto init_value = reduce->mutable_operand(1); absl::Span dimensions(reduce->dimensions()); HloComputation* function = reduce->to_apply(); if (ShapeUtil::IsZeroElementArray(arg->shape()) || - ShapeUtil::IsZeroElementArray(reduce->shape())) { - return ReplaceWithNewInstruction( - reduce, - HloInstruction::CreateBroadcast(reduce->shape(), init_value, {})); + ShapeUtil::IsZeroElementArray(reduce_result_shape)) { + if (multi_output_reduce) { + std::vector broadcast_inits; + int64 inputs = reduce->input_count(); + for (int64 i = 0; i < inputs; ++i) { + broadcast_inits.push_back(computation_->AddInstruction( + HloInstruction::CreateBroadcast(reduce->shape().tuple_shapes(i), + reduce->init_values()[i], {}))); + } + return ReplaceWithNewInstruction( + reduce, HloInstruction::CreateTuple(broadcast_inits)); + } else { + return ReplaceWithNewInstruction( + reduce, + HloInstruction::CreateBroadcast(reduce_result_shape, init_value, {})); + } + } + + // If the reduction results in the same number of elements, then the only + // possible side effect would be a reshape. Since the init_value is an + // identity of the reduction function, we can therefore replace the reduce + // with a simple reshape, ignoring the reduction function completely. 
+ if (ShapeUtil::ElementsIn(reduce_result_shape) == + ShapeUtil::ElementsIn(arg->shape())) { + if (multi_output_reduce) { + std::vector reshaped_args; + int64 inputs = reduce->input_count(); + for (int64 i = 0; i < inputs; ++i) { + reshaped_args.push_back( + computation_->AddInstruction(HloInstruction::CreateReshape( + reduce->shape().tuple_shapes(i), reduce->inputs()[i]))); + } + return ReplaceWithNewInstruction( + reduce, HloInstruction::CreateTuple(reshaped_args)); + } else { + return ReplaceWithNewInstruction( + reduce, HloInstruction::CreateReshape(reduce_result_shape, arg)); + } + } + + // TODO(b/112040122): Most of those optimizations below can be done for + // multi-output reduces. + if (multi_output_reduce) { + return Status::OK(); } // A Transpose feeding a reduce can simply permute the reduction dimensions // field if the output of the reduce is a vector or scalar. Higher ranked // result may require a transpose of the output. - if (ShapeUtil::Rank(reduce->shape()) <= 1 && + if (reduce_result_shape.rank() <= 1 && arg->opcode() == HloOpcode::kTranspose) { auto transpose_dimensions = arg->dimensions(); std::vector new_reduce_dimensions; @@ -2253,20 +2793,10 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { } return ReplaceWithNewInstruction( reduce, HloInstruction::CreateReduce( - reduce->shape(), arg->mutable_operand(0), init_value, + reduce_result_shape, arg->mutable_operand(0), init_value, new_reduce_dimensions, function)); } - // If the reduction results in the same number of elements, then the only - // possible side effect would be a reshape. Since the init_value is an - // identity of the reduction function, we can therefore replace the reduce - // with a simple reshape, ignoring the reduction function completely. 
- if (ShapeUtil::ElementsIn(reduce->shape()) == - ShapeUtil::ElementsIn(arg->shape())) { - return ReplaceWithNewInstruction( - reduce, HloInstruction::CreateReshape(reduce->shape(), arg)); - } - // If a reduce feeds a reduce with the same computation and initial value, // they can be combined into a single reduce. if (arg->opcode() == HloOpcode::kReduce && @@ -2275,9 +2805,9 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { // Create a new reduce with the combined reduction dimensions of both // reduces. std::vector arg_dims = arg->dimensions(); - std::sort(arg_dims.begin(), arg_dims.end()); + absl::c_sort(arg_dims); std::vector reduce_dims = reduce->dimensions(); - std::sort(reduce_dims.begin(), reduce_dims.end()); + absl::c_sort(reduce_dims); // Transform reduce_dims to the same rank as the operand of the operand. for (int64 arg_dim : arg_dims) { for (int64& dim : reduce_dims) { @@ -2292,9 +2822,9 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { std::merge(arg_dims.begin(), arg_dims.end(), reduce_dims.begin(), reduce_dims.end(), std::back_inserter(new_dimensions)); return ReplaceWithNewInstruction( - reduce, - HloInstruction::CreateReduce(reduce->shape(), arg->mutable_operand(0), - init_value, new_dimensions, function)); + reduce, HloInstruction::CreateReduce( + reduce_result_shape, arg->mutable_operand(0), init_value, + new_dimensions, function)); } // A reshape that collapses multiple dimensions into a dimension being @@ -2304,8 +2834,8 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { std::vector> unmodified_dims = ShapeUtil::DimensionsUnmodifiedByReshape(arg->operand(0)->shape(), arg->shape()); - std::vector arg_dim_in_output(ShapeUtil::Rank(arg->shape()), true); - std::vector arg_dim_unmodified(ShapeUtil::Rank(arg->shape()), false); + std::vector arg_dim_in_output(arg->shape().rank(), true); + std::vector arg_dim_unmodified(arg->shape().rank(), false); for (auto dim : 
dimensions) { arg_dim_in_output[dim] = false; } @@ -2323,21 +2853,21 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { } if (can_move_reshape_into_reduce) { changed_ = true; - std::unordered_set dimensions_not_to_reduce; + absl::flat_hash_set dimensions_not_to_reduce; for (auto dim_pair : unmodified_dims) { if (arg_dim_in_output[dim_pair.second]) { dimensions_not_to_reduce.insert(dim_pair.first); } } std::vector new_reduce_dimensions; - for (int64 i = 0; i < ShapeUtil::Rank(arg->operand(0)->shape()); ++i) { - if (dimensions_not_to_reduce.count(i) == 0) { + for (int64 i = 0; i < arg->operand(0)->shape().rank(); ++i) { + if (!dimensions_not_to_reduce.contains(i)) { new_reduce_dimensions.push_back(i); } } return ReplaceWithNewInstruction( reduce, HloInstruction::CreateReduce( - reduce->shape(), arg->mutable_operand(0), init_value, + reduce_result_shape, arg->mutable_operand(0), init_value, new_reduce_dimensions, function)); } } @@ -2352,11 +2882,11 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { HloInstruction* old_reduce = nullptr; for (HloInstruction* operand : arg->operands()) { HloInstruction* new_reduce = computation_->AddInstruction( - HloInstruction::CreateReduce(reduce->shape(), operand, init_value, + HloInstruction::CreateReduce(reduce_result_shape, operand, init_value, reduce->dimensions(), function)); if (old_reduce != nullptr) { new_reduce = computation_->AddInstruction(HloInstruction::CreateMap( - reduce->shape(), {old_reduce, new_reduce}, function)); + reduce_result_shape, {old_reduce, new_reduce}, function)); } old_reduce = new_reduce; } @@ -2385,6 +2915,55 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( function)); } + if (options_.enable_window_reduce_to_reduce_replacement()) { + // A reduce window can be expressed as a reduce and a reshape if all + // dimensions either have a window size of one or the entire dimension. 
If + // there is no stride, dilation, or padding, this is as easy as checking the + // size of the output shape and window dimension. + // + // The reshape is a bitcast since it adds one-sized dimensions. Often these + // ones are immediately removed as well with another reshape. The + // implementation of reduce tends to be slightly more efficient at reducing + // entire dimensions compared to reduce window. + auto effective_reduce_dims = [&] { + if (window_util::HasStride(window) || window_util::HasDilation(window) || + window_util::HasPadding(window)) { + return absl::InlinedVector{}; + } + absl::InlinedVector reduce_dims; + for (int64 i = 0; i < window.dimensions_size(); ++i) { + if (window.dimensions(i).size() == 1) { + continue; + } else if (reduce_window->shape().dimensions(i) == 1) { + reduce_dims.push_back(i); + } else { + return absl::InlinedVector{}; + } + } + return reduce_dims; + }(); + + // If a reduce window can be expressed as a reduce, do so and reshape the + // output. + if (!effective_reduce_dims.empty()) { + Shape reduce_shape = ShapeUtil::FilterDimensions( + [&](int64 dim) { + return !absl::c_linear_search(effective_reduce_dims, dim); + }, + reduce_window->shape()); + HloInstruction* reduce = + computation_->AddInstruction(HloInstruction::CreateReduce( + /*shape=*/reduce_shape, + /*operand=*/operand, + /*init_value=*/reduce_window->mutable_operand(1), + /*dimensions_to_reduce=*/effective_reduce_dims, + /*reduce_computation=*/function)); + return ReplaceWithNewInstruction( + reduce_window, + HloInstruction::CreateReshape(reduce_window->shape(), reduce)); + } + } + // This optimization folds a pad op into reduce_window. HloInstruction* pad; const HloInstruction* convert = nullptr; @@ -2520,7 +3099,7 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( // Carry out the folding of the pad into reduce_window. 
VLOG(10) << "Folding pad into reduce-window."; Window new_window = window; - const int64 rank = ShapeUtil::Rank(reduce_window->shape()); + const int64 rank = reduce_window->shape().rank(); TF_RET_CHECK(pad_config.dimensions_size() == rank); TF_RET_CHECK(window.dimensions_size() == rank); for (int64 i = 0; i < rank; ++i) { @@ -2569,110 +3148,24 @@ Status AlgebraicSimplifierVisitor::HandleSort(HloInstruction* sort) { return ReplaceWithNewInstruction( sort, HloInstruction::CreateTuple(sort->operands())); } - if (!options_.enable_permutation_sort_replacement()) { - return Status::OK(); - } - // Check if we are sorting a permutation. In that case, we know that the keys - // will be sorted to the identity permutation, and we can represent the - // changes to the 'values' parameter as a scatter. - if (sort->operand_count() == 2 && - operand->opcode() == HloOpcode::kGetTupleElement) { - const HloInstruction* other_sort = operand->operand(0); - // Check whether the 'values' parameter is the result of another sort with - // the same sort dimension. - if (other_sort->opcode() == HloOpcode::kSort && - other_sort->operand_count() >= 2 && - other_sort->dimensions(0) == dimension_to_sort && - other_sort->operand(operand->tuple_index())->opcode() == - HloOpcode::kIota) { - auto* iota = - Cast(other_sort->operand(operand->tuple_index())); - // The sort operand needs to be an integral iota, and the iota dimension - // needs to be the dimension that was sorted. - if (iota->iota_dimension() == dimension_to_sort && - ShapeUtil::ElementIsIntegral(iota->shape())) { - // We use the following construction method for a Scatter that applies - // the permutation from 'keys' to the 'values' parameter. - // - Take the "keys" parameter of the second sort and reshape it to have - // another "1" dimension at the end. 
- // - Concatenate it with iotas of the same extended shape with all - // different iota_dimensions except the dimension_to_sort in the order - // of iota_dimensions/dimension_to_sort, so e.g. with rank 3 and - // dimension_to_sort = 1, we would have concatenate of (iota with - // iota_dimension=0, keys, iota with iota_dimension = 2) - // - Use this as the indices parameter of scatter, and set updates - // of the scatter to be a reshaped 'values' parameter of sort (adding - // 'rank' many 1 dimensions at the end). - int64 rank = ShapeUtil::Rank(operand->shape()); - Shape extended_shape = operand->shape(); - extended_shape.add_dimensions(1); - extended_shape.mutable_layout()->add_minor_to_major(rank); - auto reshaped_permutation = computation_->AddInstruction( - HloInstruction::CreateReshape(extended_shape, operand)); - std::vector concat_operands; - for (int64 i = 0; i < rank; ++i) { - if (i == dimension_to_sort) { - concat_operands.push_back(reshaped_permutation); - } else { - concat_operands.push_back(computation_->AddInstruction( - HloInstruction::CreateIota(extended_shape, i))); - } - } - Shape concat_shape = operand->shape(); - concat_shape.add_dimensions(rank); - concat_shape.mutable_layout()->add_minor_to_major(rank); - auto scatter_indices = - rank > 1 ? computation_->AddInstruction( - HloInstruction::CreateConcatenate( - concat_shape, concat_operands, rank)) - : reshaped_permutation; - - // We don't care about the operand, it will be completely overridden by - // the updates. - auto scatter_operand = computation_->AddInstruction( - HloInstruction::CreateIota(sort->operand(1)->shape(), 0)); - - // Construct the updates operand of scatter. 
- Shape update_shape = sort->operand(1)->shape(); - for (int64 i = 0; i < rank; ++i) { - update_shape.add_dimensions(1); - update_shape.mutable_layout()->add_minor_to_major(rank + i); - } - auto scatter_updates = - computation_->AddInstruction(HloInstruction::CreateReshape( - update_shape, sort->mutable_operand(1))); - - // Construct the updates computation, which simply replaces the operand - // values with the update values. - HloComputation::Builder b("update_replace_computation"); - Shape scalar_shape = ShapeUtil::MakeShape(S32, {}); - b.AddInstruction( - HloInstruction::CreateParameter(0, scalar_shape, "scalar_lhs")); - auto scalar_rhs = b.AddInstruction( - HloInstruction::CreateParameter(1, scalar_shape, "scalar_rhs")); - auto update_replace_computation = - computation_->parent()->AddEmbeddedComputation(b.Build(scalar_rhs)); - - ScatterDimensionNumbers dim_numbers; - dim_numbers.set_index_vector_dim(rank); - for (int64 i = 0; i < rank; ++i) { - dim_numbers.add_update_window_dims(rank + i); - dim_numbers.add_scatter_dims_to_operand_dims(i); - } - auto scatter = - computation_->AddInstruction(HloInstruction::CreateScatter( - sort->operand(1)->shape(), scatter_operand, scatter_indices, - scatter_updates, update_replace_computation, dim_numbers)); - return ReplaceWithNewInstruction( - sort, HloInstruction::CreateTuple( - {computation_->AddInstruction(HloInstruction::CreateIota( - operand->shape(), dimension_to_sort)), - scatter})); - } + return Status::OK(); +} + +namespace { +bool OnlyPermutesMoreThanOneDegenerateDim(const Shape& shape, + absl::Span perm) { + std::vector new_permutation; + int64 degenerate_count = 0; + for (int64 i = 0; i < perm.size(); ++i) { + if (shape.dimensions(i) != 1) { + new_permutation.push_back(perm[i]); + } else { + ++degenerate_count; } } - return Status::OK(); + return degenerate_count > 1 && absl::c_is_sorted(new_permutation); } +} // namespace Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) { auto 
operand = transpose->mutable_operand(0); @@ -2690,6 +3183,15 @@ Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) { transpose->dimensions()))); } + // Replace transpose with a reshape if more than one degenerate dimension is + // permuted. + if (OnlyPermutesMoreThanOneDegenerateDim(transpose->shape(), + transpose->dimensions())) { + return ReplaceWithNewInstruction( + transpose, HloInstruction::CreateReshape( + transpose->shape(), transpose->mutable_operand(0))); + } + if (operand->opcode() == HloOpcode::kRng && operand->user_count() == 1) { *operand->mutable_shape() = transpose->shape(); return ReplaceInstruction(transpose, operand); @@ -2937,15 +3439,6 @@ StatusOr AlgebraicSimplifierVisitor::SimplifyConvToDot( const Shape dot_output_shape = ShapeUtil::MakeShapeWithDescendingLayout( convolution_shape.element_type(), {conv_width, output_channels}); - // We cannot insert bitcasts if the layouts will not be compatible. - // TODO(b/33178038): Consider inserting a transpose if a bitcast would be - // invalid. 
- if (!options_.valid_bitcast_callback()(input_shape, new_input_shape) || - !options_.valid_bitcast_callback()(filter_shape, new_filter_shape) || - !options_.valid_bitcast_callback()(dot_output_shape, convolution_shape)) { - return false; - } - auto new_lhs = add_bitcast(new_input_shape, lhs); auto new_rhs = add_bitcast(new_filter_shape, rhs); DotDimensionNumbers dot_dimension_numbers; diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.h b/tensorflow/compiler/xla/service/algebraic_simplifier.h index d2775b9fafa7e4c625f5d181114e80e7369f9c78..df5a8c2ec141458a95fafb76b1e99e4b04a61b28 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.h +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.h @@ -25,21 +25,25 @@ namespace xla { class AlgebraicSimplifierOptions { public: - // Given shapes 'from_shape' and 'to_shape', determines if it is valid to - // bitcast from 'from_shape' to 'to_shape' after considering platform - // dependent effects on layout like alignment restrictions. Precondition: the - // two shapes have layouts, the same number of elements and - // ShapeUtil::ReshapeIsBitcast returns true. - using ValidBitcastCallback = + AlgebraicSimplifierOptions() {} + // Platform dependent callback to determine if a reshape `from_shape` to + // `to_shape` is a bitcast. + using ReshapeIsBitcastCallback = std::function; - explicit AlgebraicSimplifierOptions( - ValidBitcastCallback valid_bitcast_callback) - : valid_bitcast_callback_(std::move(valid_bitcast_callback)) {} - // If valid_bitcast_callback returns true, then the pass will replace reshapes - // and transposes with bitcasts. - const ValidBitcastCallback& valid_bitcast_callback() const { - return valid_bitcast_callback_; + ReshapeIsBitcastCallback reshape_is_bitcast_callback) + : reshape_is_bitcast_callback_(std::move(reshape_is_bitcast_callback)) {} + + // Use the platform specific callback if set. It is not sensible to return + // true here if the options are not layout sensitive. 
+ bool ReshapeIsBitcast(const Shape& from_shape, const Shape& to_shape) const { + if (!is_layout_sensitive_) { + return false; + } + if (!reshape_is_bitcast_callback_) { + return ShapeUtil::ReshapeIsBitcast(from_shape, to_shape); + } + return reshape_is_bitcast_callback_(from_shape, to_shape); } // If is_layout_sensitive is true, then the simplifier preserves layout during @@ -47,12 +51,14 @@ class AlgebraicSimplifierOptions { void set_is_layout_sensitive(bool is_layout_sensitive) { is_layout_sensitive_ = is_layout_sensitive; } + bool is_layout_sensitive() const { return is_layout_sensitive_; } // Enable dot simplification on platforms where it is profitable. void set_enable_dot_strength_reduction(bool enable_dot_strength_reduction) { enable_dot_strength_reduction_ = enable_dot_strength_reduction; } + bool enable_dot_strength_reduction() const { return enable_dot_strength_reduction_; } @@ -65,22 +71,24 @@ class AlgebraicSimplifierOptions { return enable_conv_simplification_; } - // If enable_permutation_sort_replacement is true, a sort op that is known to - // sort a permutation will be replaced with a scatter op. - void set_enable_permutation_sort_replacement( - bool enable_permutation_sort_replacement) { - enable_permutation_sort_replacement_ = enable_permutation_sort_replacement; + // If enable_window_reduce_to_reduce_replacement is true, the kReduceWindow + // instruction can be optimized by replacement with simpler operations. 
+ void set_enable_window_reduce_to_reduce_replacement( + bool enable_window_reduce_to_reduce_replacement) { + enable_window_reduce_to_reduce_replacement_ = + enable_window_reduce_to_reduce_replacement; } - bool enable_permutation_sort_replacement() const { - return enable_permutation_sort_replacement_; + + bool enable_window_reduce_to_reduce_replacement() const { + return enable_window_reduce_to_reduce_replacement_; } private: - ValidBitcastCallback valid_bitcast_callback_; + ReshapeIsBitcastCallback reshape_is_bitcast_callback_; bool is_layout_sensitive_{false}; bool enable_dot_strength_reduction_{true}; bool enable_conv_simplification_{true}; - bool enable_permutation_sort_replacement_{false}; + bool enable_window_reduce_to_reduce_replacement_{true}; }; // A pass which performs algebraic simplifications. diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_proof_distributive_property.py b/tensorflow/compiler/xla/service/algebraic_simplifier_proof_distributive_property.py new file mode 100644 index 0000000000000000000000000000000000000000..5da13da041b4ded813876af7ca379025187545ab --- /dev/null +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_proof_distributive_property.py @@ -0,0 +1,82 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Proof that transforming (A*C)+(B*C) <=> (A+B)*C is "safe" if C=2^k. + +Specifically, for all floating-point values A, B, and C, if + + - C is equal to +/- 2^k for some (possibly negative) integer k, and + - A, B, C, A*C, B*C, and A+B are not subnormal, zero, or inf, + +then there exists a rounding mode rm in [RTZ, RNE] such that + + (A*C) + (B*C) == (A+B) * C (computed with rounding mode rm). + +Informally, this means that the equivalence holds for powers of 2 C, modulo +flushing to zero or inf, and modulo rounding of intermediate results. + +Requires z3 python bindings; try `pip install z3-solver`. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import z3 + +# We do float16 because it lets the solver run much faster. These results +# should generalize to fp32 and fp64, and you can verify this by changing the +# value of FLOAT_TY (and then waiting a while). +FLOAT_TY = z3.Float16 + +a = z3.FP("a", FLOAT_TY()) +b = z3.FP("b", FLOAT_TY()) +c = z3.FP("c", FLOAT_TY()) + +s = z3.Solver() + +# C must be a power of 2, i.e. significand bits must all be 0. +s.add(z3.Extract(FLOAT_TY().sbits() - 1, 0, z3.fpToIEEEBV(c)) == 0) + +for rm in [z3.RTZ(), z3.RNE()]: + z3.set_default_rounding_mode(rm) + before = a * c + b * c + after = (a + b) * c + + # Check that before == after, allowing that 0 == -0. 
+ s.add( + z3.Not( + z3.Or( + before == after, # + z3.And(z3.fpIsZero(before), z3.fpIsZero(after))))) + + for x in [ + (a * c), + (b * c), + (a + b), + ]: + s.add(z3.Not(z3.fpIsSubnormal(x))) + s.add(z3.Not(z3.fpIsZero(x))) + s.add(z3.Not(z3.fpIsInf(x))) + +if s.check() == z3.sat: + m = s.model() + print("Counterexample found!") + print(m) + print("a*c: ", z3.simplify(m[a] * m[c])) + print("b*c: ", z3.simplify(m[b] * m[c])) + print("a+b: ", z3.simplify(m[a] + m[b])) + print("a*c + b*c: ", z3.simplify(m[a] * m[c] + m[b] * m[c])) + print("(a+b) * c: ", z3.simplify((m[a] + m[b]) * m[c])) +else: + print("Proved!") diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 48f689c96a98065498818aa081d4a5a911aea5a6..f55a1886b8f86af4893c8a7fc18ed935d223eca0 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -29,7 +29,10 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" @@ -42,20 +45,12 @@ namespace xla { namespace { using ::testing::ElementsAre; - +namespace m = match; namespace op = xla::testing::opcode_matchers; -AlgebraicSimplifierOptions::ValidBitcastCallback bitcasting_callback() { - return [](const Shape&, const Shape&) { return true; }; -} - -AlgebraicSimplifierOptions::ValidBitcastCallback non_bitcasting_callback() { - return [](const Shape&, const Shape&) { return false; }; -} - class AlgebraicSimplifierTest : public HloTestBase { protected: - AlgebraicSimplifierOptions default_options_{non_bitcasting_callback()}; + AlgebraicSimplifierOptions default_options_; }; // Test that A + 0 is simplified to A @@ -79,6 +74,208 @@ TEST_F(AlgebraicSimplifierTest, AddZero) { EXPECT_EQ(root, param0); } +TEST_F(AlgebraicSimplifierTest, FactorIntegerAddition) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = s32[8] parameter(0) + p1 = s32[8] parameter(1) + p2 = s32[8] parameter(2) + x = s32[8] multiply(p0, p2) + y = s32[8] multiply(p1, p2) + ROOT sum = s32[8] add(x, y) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch(m::MultiplyAnyOrder( + m::AddAnyOrder(m::Parameter(0), m::Parameter(1)), m::Parameter(2)))); +} + +// A*C + B*C => (A+B)*C if C is a 
floating-point power of 2. +TEST_F(AlgebraicSimplifierTest, FactorFpAddition) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + c = f32[] constant(0.125) + x = f32[] multiply(p0, c) + y = f32[] multiply(p1, c) + ROOT sum = f32[] add(x, y) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::MultiplyAnyOrder( + m::AddAnyOrder(m::Parameter(0), m::Parameter(1)), + m::ConstantScalar(0.125)))); +} + +// A*C + B*C => (A+B)*C if C is a broadcast of a floating-point power of 2. +TEST_F(AlgebraicSimplifierTest, FactorFpAdditionWithBroadcast) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[4] parameter(0) + p1 = f32[4] parameter(1) + c = f32[] constant(0.125) + b = f32[4] broadcast(c), dimensions={} + x = f32[4] multiply(p0, b) + y = f32[4] multiply(p1, b) + ROOT sum = f32[4] add(x, y) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::MultiplyAnyOrder( + m::AddAnyOrder(m::Parameter(0), m::Parameter(1)), + m::Broadcast(m::ConstantScalar(0.125))))); +} + +// A*C + B*C => (A+B)*C simplification should not happen if C is not a +// floating-point power of 2. 
+TEST_F(AlgebraicSimplifierTest, FactorFpAdditionNotPowerOf2) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + c = f32[] constant(0.3) + x = f32[] multiply(p0, c) + y = f32[] multiply(p1, c) + ROOT sum = f32[] add(x, y) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + EXPECT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); +} + +// A*C + B*C => (A+B)*C simplification should not happen if A, B, and C are +// complex numbers. +TEST_F(AlgebraicSimplifierTest, FactorFpAdditionComplex) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = c64[8] parameter(0) + p1 = c64[8] parameter(1) + p2 = c64[8] parameter(2) + x = c64[8] multiply(p0, p2) + y = c64[8] multiply(p1, p2) + ROOT sum = c64[8] add(x, y) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + EXPECT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); +} + +// A*C + B*C => (A+B)*C simplification is OK if A, B, and C are bfloat16. 
+TEST_F(AlgebraicSimplifierTest, FactorFpAdditionBfloat16) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = bf16[4] parameter(0) + p1 = bf16[4] parameter(1) + c = bf16[] constant(0.125) + b = bf16[4] broadcast(c), dimensions={} + x = bf16[4] multiply(p0, b) + y = bf16[4] multiply(p1, b) + ROOT sum = bf16[4] add(x, y) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::MultiplyAnyOrder( + m::AddAnyOrder(m::Parameter(0), m::Parameter(1)), + m::Broadcast(m::ConstantScalar(0.125))))); +} + +TEST_F(AlgebraicSimplifierTest, UnsignedDivideByPowerOf2) { + const char* kModuleStr = R"( + HloModule m + test { + p = u32[4] parameter(0) + c = u32[] constant(8) + b = u32[4] broadcast(c), dimensions={} + ROOT d = u32[4] divide(p, b) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::ShiftRightLogical( + m::Parameter(0), m::Broadcast(m::ConstantScalar(3))))); +} + +TEST_F(AlgebraicSimplifierTest, SignedDivideByPowerOf2) { + const char* kModuleStr = R"( + HloModule m + test { + p = s32[4] parameter(0) + c = s32[] constant(8) + b = s32[4] broadcast(c), dimensions={} + ROOT d = s32[4] divide(p, b) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + auto match_dividend_is_negative = + m::Lt(m::Parameter(0), m::Broadcast(m::ConstantScalar(0))); + auto match_abs = m::Select(match_dividend_is_negative, + m::Negate(m::Parameter(0)), m::Parameter(0)); + auto match_shift = + m::ShiftRightLogical(match_abs, m::Broadcast(m::ConstantScalar(3))); + 
EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Select(match_dividend_is_negative, + m::Negate(match_shift), match_shift))); +} + +TEST_F(AlgebraicSimplifierTest, UnsignedRemainderByPowerOf2) { + const char* kModuleStr = R"( + HloModule m + test { + p = u32[4] parameter(0) + c = u32[] constant(8) + b = u32[4] broadcast(c), dimensions={} + ROOT r = u32[4] remainder(p, b) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::AndAnyOrder(m::Parameter(0), + m::Broadcast(m::ConstantScalar(7))))); +} + +TEST_F(AlgebraicSimplifierTest, SignedRemainderByPowerOf2) { + const char* kModuleStr = R"( + HloModule m + test { + p = s32[4] parameter(0) + c = s32[] constant(8) + b = s32[4] broadcast(c), dimensions={} + ROOT r = s32[4] remainder(p, b) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + auto match_dividend_is_negative = + m::Lt(m::Parameter(0), m::Broadcast(m::ConstantScalar(0))); + auto match_abs = m::Select(match_dividend_is_negative, + m::Negate(m::Parameter(0)), m::Parameter(0)); + auto match_and = + m::AndAnyOrder(match_abs, m::Broadcast(m::ConstantScalar(7))); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Select(match_dividend_is_negative, + m::Negate(match_and), match_and))); +} + // Test that A * 0 is simplified to 0 TEST_F(AlgebraicSimplifierTest, MulZero) { auto m = CreateNewVerifiedModule(); @@ -197,7 +394,7 @@ TEST_F(AlgebraicSimplifierTest, TwoReducesToOne) { AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); HloInstruction* root = m->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Reduce(param, zero)); + EXPECT_THAT(root, 
GmockMatch(m::Reduce(m::Parameter(0), m::Op().Is(zero)))); EXPECT_EQ(root->dimensions(), std::vector({0, 2, 3})); } @@ -219,7 +416,7 @@ TEST_F(AlgebraicSimplifierTest, AddConstOnLHS) { AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); - EXPECT_THAT(root, op::Add(param0, op::Constant())); + EXPECT_THAT(root, GmockMatch(m::Add(m::Parameter(0), m::Constant()))); } // Test that [(A + C1) + C2] => [A + (C1 + C2)] for constants C1 and C2. @@ -245,7 +442,9 @@ TEST_F(AlgebraicSimplifierTest, AddReassociateMergeConstants) { AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); - EXPECT_THAT(root, op::Add(param0, op::Add(constant1, constant2))); + EXPECT_THAT(root, GmockMatch(m::Add( + m::Op().Is(param0), + m::Add(m::Op().Is(constant1), m::Op().Is(constant2))))); } TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR0Operand) { @@ -303,7 +502,8 @@ TEST_F(AlgebraicSimplifierTest, InlineTrivialMap) { AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); - EXPECT_THAT(root, op::Add(param0, op::Broadcast(zero))); + EXPECT_THAT(root, GmockMatch(m::Add(m::Parameter(0), + m::Broadcast(m::Op().Is(zero))))); } TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) { @@ -336,11 +536,11 @@ TEST_F(AlgebraicSimplifierTest, ConstantToBroadcast) { auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); - EXPECT_THAT(root, op::Constant()); + EXPECT_THAT(root, GmockMatch(m::Constant())); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); - EXPECT_THAT(root, op::Broadcast(op::Constant())); + EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Constant()))); EXPECT_EQ(3.14f, 
root->operand(0)->literal().GetFirstElement()); } @@ -352,11 +552,11 @@ TEST_F(AlgebraicSimplifierTest, ConstantNotToBroadcast) { auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); - EXPECT_THAT(root, op::Constant()); + EXPECT_THAT(root, GmockMatch(m::Constant())); AlgebraicSimplifier simplifier(default_options_); ASSERT_FALSE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); - EXPECT_THAT(root, op::Constant()); + EXPECT_THAT(root, GmockMatch(m::Constant())); } TEST_F(AlgebraicSimplifierTest, IotaToBroadcast) { @@ -367,11 +567,11 @@ TEST_F(AlgebraicSimplifierTest, IotaToBroadcast) { auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); - EXPECT_THAT(root, op::Constant()); + EXPECT_THAT(root, GmockMatch(m::Constant())); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); - EXPECT_THAT(root, op::Iota()); + EXPECT_THAT(root, GmockMatch(m::Iota())); } // Test that A - 0 is simplified to A @@ -413,7 +613,8 @@ TEST_F(AlgebraicSimplifierTest, SubConstCanonicalization) { AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); - EXPECT_THAT(root, op::Add(param0, op::Negate(constant))); + EXPECT_THAT(root, GmockMatch(m::Add(m::Parameter(0), + m::Negate(m::Op().Is(constant))))); } // Test that (A/B)/C is simplified to A/(B*C). 
@@ -435,13 +636,16 @@ TEST_F(AlgebraicSimplifierTest, LhsDivOfDiv) { auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Divide(op::Divide(param0, param1), param2)); + GmockMatch(m::Divide(m::Divide(m::Parameter(0), m::Parameter(1)), + m::Parameter(2)))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), - op::Divide(param0, op::Multiply(param1, param2))); + EXPECT_THAT( + computation->root_instruction(), + GmockMatch(m::Divide(m::Parameter(0), + m::Multiply(m::Parameter(1), m::Parameter(2))))); } // Test that A/(B/C) is simplified to (A*C)/B. @@ -462,14 +666,18 @@ TEST_F(AlgebraicSimplifierTest, RhsDivOfDiv) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), - op::Divide(param0, op::Divide(param1, param2))); + EXPECT_THAT( + computation->root_instruction(), + GmockMatch(m::Divide(m::Parameter(0), + m::Divide(m::Parameter(1), m::Parameter(2))))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), - op::Divide(op::Multiply(param0, param2), param1)); + EXPECT_THAT( + computation->root_instruction(), + GmockMatch(m::Divide(m::Multiply(m::Parameter(0), m::Parameter(2)), + m::Parameter(1)))); } // Test that (A/B)/(C/D) is simplified to (A*D)/(B*C). 
@@ -496,14 +704,16 @@ TEST_F(AlgebraicSimplifierTest, DivOfDivAndDiv) { EXPECT_THAT( computation->root_instruction(), - op::Divide(op::Divide(param0, param1), op::Divide(param2, param3))); + GmockMatch(m::Divide(m::Divide(m::Parameter(0), m::Parameter(1)), + m::Divide(m::Parameter(2), m::Parameter(3))))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT( computation->root_instruction(), - op::Divide(op::Multiply(param0, param3), op::Multiply(param1, param2))); + GmockMatch(m::Divide(m::Multiply(m::Parameter(0), m::Parameter(3)), + m::Multiply(m::Parameter(1), m::Parameter(2))))); } // Test that A/exp(B) is simplified to A*exp(-B). @@ -523,13 +733,14 @@ TEST_F(AlgebraicSimplifierTest, DivOfExp) { auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Divide(param0, op::Exp(param1))); + GmockMatch(m::Divide(m::Parameter(0), m::Exp(m::Parameter(1))))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), - op::Multiply(param0, op::Exp(op::Negate(param1)))); + GmockMatch(m::Multiply(m::Parameter(0), + m::Exp(m::Negate(m::Parameter(1)))))); } // Test that A/pow(B,C) is simplified to A*pow(B,-C). 
@@ -550,14 +761,18 @@ TEST_F(AlgebraicSimplifierTest, DivOfPower) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), - op::Divide(param0, op::Power(param1, param2))); + EXPECT_THAT( + computation->root_instruction(), + GmockMatch(m::Divide(m::Parameter(0), + m::Power(m::Parameter(1), m::Parameter(2))))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), - op::Multiply(param0, op::Power(param1, op::Negate(param2)))); + GmockMatch(m::Multiply( + m::Parameter(0), + m::Power(m::Parameter(1), m::Negate(m::Parameter(2)))))); } // Test that broadcasting is done on the right step when simplifying A/pow(B,C) @@ -579,14 +794,18 @@ TEST_F(AlgebraicSimplifierTest, DivOfBroadcastingPower) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), - op::Divide(param0, op::Power(param1, param2))); + EXPECT_THAT( + computation->root_instruction(), + GmockMatch(m::Divide(m::Parameter(0), + m::Power(m::Parameter(1), m::Parameter(2))))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); ASSERT_THAT(computation->root_instruction(), - op::Multiply(param0, op::Power(param1, op::Negate(param2)))); + GmockMatch(m::Multiply( + m::Parameter(0), + m::Power(m::Parameter(1), m::Negate(m::Parameter(2)))))); } // A / Const => A * InvertedConst @@ -608,7 +827,7 @@ TEST_F(AlgebraicSimplifierTest, DivideByConstant) { ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), - op::Multiply(param0, op::Constant())); + GmockMatch(m::Multiply(m::Parameter(0), m::Constant()))); } // pow(pow(A, X), Y) => pow(A, X*Y) @@ -630,8 +849,10 @@ TEST_F(AlgebraicSimplifierTest, PowerOfPower) { auto computation = m->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(default_options_); 
ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), - op::Power(base, op::Multiply(exp1, exp2))); + EXPECT_THAT( + computation->root_instruction(), + GmockMatch(m::Power(m::Op().Is(base), + m::Multiply(m::Op().Is(exp1), m::Op().Is(exp2))))); } // Don't simplify pow(pow(A, X), Y) => pow(A, X*Y) if X and Y are complex @@ -794,7 +1015,7 @@ TEST_F(AlgebraicSimplifierTest, SelectMakeTuple) { AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); - EXPECT_THAT(root, op::Add(param1, param2)); + EXPECT_THAT(root, GmockMatch(m::Add(m::Parameter(1), m::Parameter(2)))); } // Test that exp(A)/exp(B) is simplified to exp(A-B) @@ -815,14 +1036,16 @@ TEST_F(AlgebraicSimplifierTest, ExpDiv) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), - op::Divide(op::Exp(param0), op::Exp(param1))); + EXPECT_THAT( + computation->root_instruction(), + GmockMatch(m::Divide(m::Exp(m::Parameter(0)), m::Exp(m::Parameter(1))))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), - op::Exp(op::Subtract(param0, param1))); + EXPECT_THAT( + computation->root_instruction(), + GmockMatch(m::Exp(m::Subtract(m::Parameter(0), m::Parameter(1))))); } // Test that exp(A)*exp(B) is simplified to exp(A+B) @@ -844,13 +1067,14 @@ TEST_F(AlgebraicSimplifierTest, ExpMul) { auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Multiply(op::Exp(param0), op::Exp(param1))); + GmockMatch(m::Multiply(m::Exp(m::Parameter(0)), + m::Exp(m::Parameter(1))))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), - op::Exp(op::Add(param0, param1))); + GmockMatch(m::Exp(m::Add(m::Parameter(0), 
m::Parameter(1))))); } // Test that pow(exp(A), B) is simplified to exp(A*B) @@ -870,13 +1094,14 @@ TEST_F(AlgebraicSimplifierTest, PowExp) { auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Power(op::Exp(param0), param1)); + GmockMatch(m::Power(m::Exp(m::Parameter(0)), m::Parameter(1)))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), - op::Exp(op::Multiply(param0, param1))); + EXPECT_THAT( + computation->root_instruction(), + GmockMatch(m::Exp(m::Multiply(m::Parameter(0), m::Parameter(1))))); } // Test that ln(pow(A, B)) is simplified to ln(A)*B @@ -896,13 +1121,14 @@ TEST_F(AlgebraicSimplifierTest, LnPow) { auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Log(op::Power(param0, param1))); + GmockMatch(m::Log(m::Power(m::Parameter(0), m::Parameter(1))))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), - op::Multiply(op::Log(param0), param1)); + EXPECT_THAT( + computation->root_instruction(), + GmockMatch(m::Multiply(m::Log(m::Parameter(0)), m::Parameter(1)))); } // Test that ln(exp(A)) is simplified to A @@ -919,7 +1145,8 @@ TEST_F(AlgebraicSimplifierTest, LnExp) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Log(op::Exp(param0))); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Log(m::Exp(m::Parameter(0))))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); @@ -948,12 +1175,14 @@ TEST_F(AlgebraicSimplifierTest, LnExpDiv) { auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Log(op::Divide(op::Exp(param0), op::Exp(param1)))); + 
GmockMatch(m::Log(m::Divide(m::Exp(m::Parameter(0)), + m::Exp(m::Parameter(1)))))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Subtract(param0, param1)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Subtract(m::Parameter(0), m::Parameter(1)))); } // Test that pow(A, 0) where A is a scalar is simplified to the scalar @@ -971,13 +1200,14 @@ TEST_F(AlgebraicSimplifierTest, Pow0Scalar) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Power(param0, zero)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Power(m::Parameter(0), m::Op().Is(zero)))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); HloInstruction* root = computation->root_instruction(); - EXPECT_THAT(root, op::Constant()); + EXPECT_THAT(root, GmockMatch(m::Constant())); EXPECT_EQ(root->literal().GetFirstElement(), 1); } @@ -995,13 +1225,14 @@ TEST_F(AlgebraicSimplifierTest, Pow0Vector) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Power(param0, zero)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Power(m::Parameter(0), m::Op().Is(zero)))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); HloInstruction* root = computation->root_instruction(); - EXPECT_THAT(root, op::Broadcast()); + EXPECT_THAT(root, GmockMatch(m::Broadcast())); EXPECT_TRUE(ShapeUtil::Equal(root->shape(), r1f32)) << ShapeUtil::HumanString(root->shape()); EXPECT_EQ(root->dimensions().size(), 0); @@ -1023,7 +1254,8 @@ TEST_F(AlgebraicSimplifierTest, Pow1) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Power(param0, one)); + EXPECT_THAT(computation->root_instruction(), + 
GmockMatch(m::Power(m::Parameter(0), m::Op().Is(one)))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); @@ -1045,12 +1277,14 @@ TEST_F(AlgebraicSimplifierTest, Pow2) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Power(param0, two)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Power(m::Parameter(0), m::Op().Is(two)))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Multiply(param0, param0)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(0)))); } // Test that pow(A, -1) is simplified to 1/A. @@ -1067,13 +1301,14 @@ TEST_F(AlgebraicSimplifierTest, PowNegative1) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Power(param0, negative_one)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Power(m::Parameter(0), m::Op().Is(negative_one)))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); HloInstruction* root = computation->root_instruction(); - EXPECT_THAT(root, op::Divide(op::Broadcast(), param0)); + EXPECT_THAT(root, GmockMatch(m::Divide(m::Broadcast(), m::Parameter(0)))); EXPECT_EQ(root->operand(0)->opcode(), HloOpcode::kBroadcast); EXPECT_EQ(root->operand(0)->operand(0)->literal().GetFirstElement(), 1); @@ -1112,14 +1347,59 @@ TEST_F(AlgebraicSimplifierTest, ZeroSizedConvolution) { // Create add computation. 
builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {3, 3, 3}), lhs, rhs, /*feature_group_count=*/1, - window, dnums, DefaultPrecisionConfig(2))); + /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); m->AddEntryComputation(builder.Build()); HloPassFix simplifier(default_options_); EXPECT_THAT(m->entry_computation()->root_instruction(), - op::Convolution(lhs, rhs)); + GmockMatch(m::Convolution(m::Op().Is(lhs), m::Op().Is(rhs)))); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(m->entry_computation()->root_instruction(), - op::Broadcast(op::Constant())); + GmockMatch(m::Broadcast(m::Constant()))); +} + +TEST_F(AlgebraicSimplifierTest, ReduceWindowIsReduceAndReshape) { + auto m = CreateNewVerifiedModule(); + auto builder = HloComputation::Builder(TestName()); + HloInstruction* param = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {1, 2, 3, 4}), "param")); + Window window; + for (int64 i = 0; i < 4; ++i) { + WindowDimension* dim = window.add_dimensions(); + // Makes 1x2x3x1 window. + dim->set_size((i % 3) + 1); + dim->set_stride(1); + dim->set_padding_low(0); + dim->set_padding_high(0); + dim->set_window_dilation(1); + dim->set_base_dilation(1); + } + // Create add computation. 
+ HloComputation* add_computation = nullptr; + { + HloComputation::Builder builder(TestName() + ".add"); + const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); + HloInstruction* p0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "p0")); + HloInstruction* p1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "p1")); + builder.AddInstruction( + HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1)); + add_computation = m->AddEmbeddedComputation(builder.Build()); + } + builder.AddInstruction(HloInstruction::CreateReduceWindow( + ShapeUtil::MakeShape(F32, {1, 1, 1, 4}), param, + builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))), + window, add_computation)); + m->AddEntryComputation(builder.Build()); + HloPassFix simplifier(default_options_); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::ReduceWindow(m::Parameter(0), m::Constant()))); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch(m::Reshape(m::Reduce(m::Parameter(0), m::Constant())))); } TEST_F(AlgebraicSimplifierTest, ZeroSizedReduceWindow) { @@ -1158,10 +1438,10 @@ TEST_F(AlgebraicSimplifierTest, ZeroSizedReduceWindow) { m->AddEntryComputation(builder.Build()); HloPassFix simplifier(default_options_); EXPECT_THAT(m->entry_computation()->root_instruction(), - op::ReduceWindow(param, op::Constant())); + GmockMatch(m::ReduceWindow(m::Parameter(0), m::Constant()))); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(m->entry_computation()->root_instruction(), - op::Broadcast(op::Constant())); + GmockMatch(m::Broadcast(m::Constant()))); } TEST_F(AlgebraicSimplifierTest, ZeroSizedPad) { @@ -1184,11 +1464,11 @@ TEST_F(AlgebraicSimplifierTest, ZeroSizedPad) { padding)); m->AddEntryComputation(builder.Build()); EXPECT_THAT(m->entry_computation()->root_instruction(), - op::Pad(param, 
op::Constant())); + GmockMatch(m::Pad(m::Parameter(0), m::Constant()))); HloPassFix simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(m->entry_computation()->root_instruction(), - op::Broadcast(op::Constant())); + GmockMatch(m::Broadcast(m::Constant()))); } TEST_F(AlgebraicSimplifierTest, ReshapeBroadcast) { @@ -1209,7 +1489,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeBroadcast) { m->AddEntryComputation(std::move(computation)); EXPECT_THAT(m->entry_computation()->root_instruction(), - op::Reshape(op::Broadcast(op::Reshape(op)))); + GmockMatch(m::Reshape(m::Broadcast(m::Reshape(m::Op().Is(op)))))); HloPassFix simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); @@ -1228,7 +1508,8 @@ TEST_F(AlgebraicSimplifierTest, ConvertBetweenSameType) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Convert(input)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Convert(m::Op().Is(input)))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); @@ -1248,7 +1529,8 @@ TEST_F(AlgebraicSimplifierTest, RemoveCopy) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Copy(param0)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Copy(m::Parameter(0)))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); @@ -1256,34 +1538,91 @@ TEST_F(AlgebraicSimplifierTest, RemoveCopy) { EXPECT_THAT(computation->root_instruction(), param0); } -TEST_F(AlgebraicSimplifierTest, CopyEqualsBitcast) { +TEST_F(AlgebraicSimplifierTest, CopyOfReshapeOfCopyEqualsBitcast) { auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); HloInstruction* param = builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(F32, {1, 14, 14, 64}), "param")); - 
*param->mutable_shape()->mutable_layout() = - LayoutUtil::MakeLayout({0, 1, 2, 3}); + 0, ShapeUtil::MakeShapeWithLayout(F32, {1, 14, 14, 64}, {3, 2, 1, 0}), + "param")); HloInstruction* copy = builder.AddInstruction(HloInstruction::CreateUnary( - ShapeUtil::MakeShape(F32, {1, 14, 14, 64}), HloOpcode::kCopy, param)); - *copy->mutable_shape()->mutable_layout() = - LayoutUtil::MakeLayout({1, 2, 0, 3}); + ShapeUtil::MakeShapeWithLayout(F32, {1, 14, 14, 64}, {0, 1, 2, 3}), + HloOpcode::kCopy, param)); + HloInstruction* reshape = + builder.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShapeWithLayout(F32, {14 * 14, 64}, {0, 1}), copy)); + builder.AddInstruction(HloInstruction::CreateUnary( + ShapeUtil::MakeShapeWithLayout(F32, {14 * 14, 64}, {1, 0}), + HloOpcode::kCopy, reshape)); auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Copy(param)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Copy(m::Reshape(m::Copy(m::Parameter(0)))))); - AlgebraicSimplifierOptions options(non_bitcasting_callback()); + AlgebraicSimplifierOptions options; + options.set_is_layout_sensitive(true); + AlgebraicSimplifier simplifier(options); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + // Verify that the copy of reshape of copy is replaced. 
+ EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Bitcast(m::Parameter(0)))); +} + +TEST_F(AlgebraicSimplifierTest, ReshapeOfCopyEqualsBitcast) { + auto m = CreateNewVerifiedModule(); + HloComputation::Builder builder(TestName()); + HloInstruction* param = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShapeWithLayout(F32, {1, 14, 14, 64}, {3, 2, 1, 0}), + "param")); + HloInstruction* copy = builder.AddInstruction(HloInstruction::CreateUnary( + ShapeUtil::MakeShapeWithLayout(F32, {1, 14, 14, 64}, {0, 1, 2, 3}), + HloOpcode::kCopy, param)); + builder.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShapeWithLayout(F32, {14 * 14, 64}, {1, 0}), copy)); + + auto computation = m->AddEntryComputation(builder.Build()); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Reshape(m::Copy(m::Parameter(0))))); + + AlgebraicSimplifierOptions options; + options.set_is_layout_sensitive(true); + AlgebraicSimplifier simplifier(options); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + // Verify that the copy of reshape of copy is replaced. 
+ EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Bitcast(m::Parameter(0)))); +} + +TEST_F(AlgebraicSimplifierTest, CopyEqualsBitcast) { + auto m = CreateNewVerifiedModule(); + HloComputation::Builder builder(TestName()); + HloInstruction* param = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShapeWithLayout(F32, {1, 14, 14, 64}, {0, 1, 2, 3}), + "param")); + builder.AddInstruction(HloInstruction::CreateUnary( + ShapeUtil::MakeShapeWithLayout(F32, {1, 14, 14, 64}, {1, 2, 0, 3}), + HloOpcode::kCopy, param)); + auto computation = m->AddEntryComputation(builder.Build()); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Copy(m::Parameter(0)))); + + AlgebraicSimplifierOptions options( + [](const Shape&, const Shape&) { return false; }); options.set_is_layout_sensitive(true); AlgebraicSimplifier simplifier1(options); ASSERT_FALSE(simplifier1.Run(m.get()).ValueOrDie()); // Verify that the copy is not replaced. - EXPECT_THAT(computation->root_instruction(), op::Copy(param)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Copy(m::Parameter(0)))); - AlgebraicSimplifierOptions options2(bitcasting_callback()); + AlgebraicSimplifierOptions options2; options2.set_is_layout_sensitive(true); AlgebraicSimplifier simplifier2(options2); - ASSERT_TRUE(simplifier2.Run(m.get()).ValueOrDie()); + EXPECT_TRUE(simplifier2.Run(m.get()).ValueOrDie()); // Verify that the copy is replaced. - EXPECT_THAT(computation->root_instruction(), op::Bitcast(param)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Bitcast(m::Parameter(0)))); } // Test that unary concatenates are removed. 
@@ -1298,7 +1637,8 @@ TEST_F(AlgebraicSimplifierTest, RemoveUnaryConcatenate) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Concatenate(param0)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Concatenate(m::Parameter(0)))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); @@ -1327,15 +1667,17 @@ TEST_F(AlgebraicSimplifierTest, RemoveEmptyConcatenateOperands) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT( - computation->root_instruction(), - op::Concatenate(empty_literal, param0, param0, empty_slice, param1)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Concatenate( + m::Op().Is(empty_literal), m::Parameter(0), m::Parameter(0), + m::Op().Is(empty_slice), m::Parameter(1)))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), - op::Concatenate(param0, param0, param1)); + GmockMatch(m::Concatenate(m::Parameter(0), m::Parameter(0), + m::Parameter(1)))); } // Test that reduce of concat is simplified. @@ -1383,8 +1725,9 @@ TEST_F(AlgebraicSimplifierTest, SimplifyReduceOfConcat) { EXPECT_THAT( computation->root_instruction(), - op::Map(op::Map(op::Reduce(param0, zero), op::Reduce(param1, zero)), - op::Reduce(param2, zero))); + GmockMatch(m::Map(m::Map(m::Reduce(m::Parameter(0), m::Op().Is(zero)), + m::Reduce(m::Parameter(1), m::Op().Is(zero))), + m::Reduce(m::Parameter(2), m::Op().Is(zero))))); } // Test a concatenate with only empty operands is removed. 
@@ -1407,7 +1750,8 @@ TEST_F(AlgebraicSimplifierTest, OnlyEmptyConcatenateOperands) { auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Concatenate(empty_literal, empty_slice)); + GmockMatch(m::Concatenate(m::Op().Is(empty_literal), + m::Op().Is(empty_slice)))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); @@ -1434,7 +1778,8 @@ TEST_F(AlgebraicSimplifierTest, ConcatenateOfBroadcastBecomesPad) { AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Pad(param0, param1)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Pad(m::Parameter(0), m::Parameter(1)))); } TEST_F(AlgebraicSimplifierTest, SimplifyConcatenateOfSlices) { @@ -1495,10 +1840,10 @@ TEST_F(AlgebraicSimplifierTest, SimplifyConcatenateOfSlices) { AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + auto s = m::Slice(m::Parameter(0)); EXPECT_THAT( computation->root_instruction(), - op::Concatenate(op::Slice(param0), op::Slice(param0), op::Slice(param0), - op::Slice(param0), op::Slice(param0), op::Slice(param1))); + GmockMatch(m::Concatenate(s, s, s, s, s, m::Slice(m::Parameter(1))))); // The operand 3 should be a merge of 'slice3', 'slice4' and 'slice5', so its // shape should have dimensions {50, 30}. 
EXPECT_TRUE( @@ -1524,15 +1869,17 @@ TEST_F(AlgebraicSimplifierTest, CopyWithDifferentLayout) { *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1}); *copy->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({1, 0}); - EXPECT_THAT(computation->root_instruction(), op::Copy(param0)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Copy(m::Parameter(0)))); - AlgebraicSimplifierOptions options(non_bitcasting_callback()); + AlgebraicSimplifierOptions options; options.set_is_layout_sensitive(true); AlgebraicSimplifier simplifier(options); EXPECT_FALSE(simplifier.Run(m.get()).ValueOrDie()); // Copy has not been removed. - EXPECT_THAT(computation->root_instruction(), op::Copy(param0)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Copy(m::Parameter(0)))); } // Test that a simplification which preserves layouts is performed if layout @@ -1552,9 +1899,10 @@ TEST_F(AlgebraicSimplifierTest, CopyWithSameLayout) { *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1}); *copy->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1}); - EXPECT_THAT(computation->root_instruction(), op::Copy(param0)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Copy(m::Parameter(0)))); - AlgebraicSimplifierOptions options(non_bitcasting_callback()); + AlgebraicSimplifierOptions options; options.set_is_layout_sensitive(true); AlgebraicSimplifier simplifier(options); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); @@ -1581,15 +1929,18 @@ TEST_F(AlgebraicSimplifierTest, NoBitcastAdded) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Reshape(param0)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Reshape(m::Parameter(0)))); - AlgebraicSimplifierOptions options(non_bitcasting_callback()); + AlgebraicSimplifierOptions options( + [](const Shape&, const Shape&) { return false; }); 
options.set_is_layout_sensitive(true); AlgebraicSimplifier simplifier(options); EXPECT_FALSE(simplifier.Run(m.get()).ValueOrDie()); // Reshape is not replaced with a bitcast. - EXPECT_THAT(computation->root_instruction(), op::Reshape(param0)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Reshape(m::Parameter(0)))); } // Test transforming reshapes and transposes of rng. @@ -1613,13 +1964,12 @@ TEST_F(AlgebraicSimplifierTest, ReshapeOfTransposeOfRngToRng) { auto computation = m->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier( - (AlgebraicSimplifierOptions(bitcasting_callback()))); + AlgebraicSimplifier simplifier(AlgebraicSimplifierOptions{}); EXPECT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - // Verify that that reshape(transpose(rng)) is replace by a single rng of the + // Verify that reshape(transpose(rng)) is replace by a single rng of the // same shape as the reshape. - EXPECT_THAT(computation->root_instruction(), op::Rng()); + EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Rng())); EXPECT_TRUE(ShapeUtil::Equal(computation->root_instruction()->shape(), reshape_shape)); } @@ -1661,10 +2011,11 @@ TEST_F(AlgebraicSimplifierTest, ReshapeReplacedWithBitcast) { auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Tuple(transformable_reshape, dimensions_wrong_reshape, - layout_wrong_reshape)); + GmockMatch(m::Tuple(m::Op().Is(transformable_reshape), + m::Op().Is(dimensions_wrong_reshape), + m::Op().Is(layout_wrong_reshape)))); - AlgebraicSimplifierOptions options(bitcasting_callback()); + AlgebraicSimplifierOptions options; options.set_is_layout_sensitive(true); AlgebraicSimplifier simplifier(options); simplifier.Run(m.get()).ValueOrDie(); @@ -1672,7 +2023,8 @@ TEST_F(AlgebraicSimplifierTest, ReshapeReplacedWithBitcast) { // Verify that only the first reshape is replaced. 
EXPECT_THAT( computation->root_instruction(), - op::Tuple(op::Bitcast(), dimensions_wrong_reshape, layout_wrong_reshape)); + GmockMatch(m::Tuple(m::Bitcast(), m::Op().Is(dimensions_wrong_reshape), + m::Op().Is(layout_wrong_reshape)))); } // Regression test for a bug where if we failed to sink a reshape, we'd set the @@ -1693,8 +2045,7 @@ TEST_F(AlgebraicSimplifierTest, FailureToSinkReshapeDoesntAffectChangedBit) { builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {4}), add)); - AlgebraicSimplifier simplifier( - (AlgebraicSimplifierOptions(bitcasting_callback()))); + AlgebraicSimplifier simplifier(AlgebraicSimplifierOptions{}); m->AddEntryComputation(builder.Build()); EXPECT_TRUE(simplifier.Run(m.get()).ValueOrDie()); } @@ -1718,8 +2069,7 @@ TEST_F(AlgebraicSimplifierTest, FailureToSinkBroadcastDoesntAffectChangedBit) { HloInstruction::CreateBroadcast(ShapeUtil::MakeShape(F32, {2, 2, 2}), add, /*broadcast_dimensions=*/{0, 1})); - AlgebraicSimplifier simplifier( - (AlgebraicSimplifierOptions(bitcasting_callback()))); + AlgebraicSimplifier simplifier(AlgebraicSimplifierOptions{}); m->AddEntryComputation(builder.Build()); EXPECT_TRUE(simplifier.Run(m.get()).ValueOrDie()); } @@ -1741,15 +2091,17 @@ TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast1) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Transpose(param)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Transpose(m::Parameter(0)))); - AlgebraicSimplifierOptions options(bitcasting_callback()); + AlgebraicSimplifierOptions options; options.set_is_layout_sensitive(true); AlgebraicSimplifier simplifier(options); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); // Verify that the reshape is replaced. 
- EXPECT_THAT(computation->root_instruction(), op::Bitcast(param)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Bitcast(m::Parameter(0)))); } TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast2) { @@ -1769,15 +2121,17 @@ TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast2) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Transpose(param)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Transpose(m::Parameter(0)))); - AlgebraicSimplifierOptions options(bitcasting_callback()); + AlgebraicSimplifierOptions options; options.set_is_layout_sensitive(true); AlgebraicSimplifier simplifier(options); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); // Verify that the reshape is replaced. - EXPECT_THAT(computation->root_instruction(), op::Bitcast(param)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Bitcast(m::Parameter(0)))); } TEST_F(AlgebraicSimplifierTest, ReshapesMerged) { @@ -1797,12 +2151,13 @@ TEST_F(AlgebraicSimplifierTest, ReshapesMerged) { auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Reshape(op::Reshape(param0))); + GmockMatch(m::Reshape(m::Reshape(m::Parameter(0))))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Reshape(param0)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Reshape(m::Parameter(0)))); } TEST_F(AlgebraicSimplifierTest, CopiesMerged) { @@ -1823,14 +2178,16 @@ TEST_F(AlgebraicSimplifierTest, CopiesMerged) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Copy(op::Copy(param0))); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Copy(m::Copy(m::Parameter(0))))); - AlgebraicSimplifierOptions options(non_bitcasting_callback()); + AlgebraicSimplifierOptions options; 
options.set_is_layout_sensitive(true); AlgebraicSimplifier simplifier(options); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Copy(param0)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Copy(m::Parameter(0)))); } TEST_F(AlgebraicSimplifierTest, TransposesMerged) { @@ -1849,16 +2206,38 @@ TEST_F(AlgebraicSimplifierTest, TransposesMerged) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Transpose(transpose1)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Transpose(m::Op().Is(transpose1)))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Transpose(param0)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Transpose(m::Parameter(0)))); EXPECT_EQ(std::vector({2, 1, 0}), computation->root_instruction()->dimensions()); } +TEST_F(AlgebraicSimplifierTest, TransposeIsReshape) { + const char* hlo_string = R"( + HloModule module + + ENTRY test { + param = f32[10] parameter(0) + reshaped = f32[1,1,10] reshape(f32[10] param) + transposed = f32[10,1,1] transpose(f32[1,1,10] reshaped), dimensions={2,1,0} + ROOT reshaped_again = f32[10] reshape(f32[10,1,1] transposed) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + HloPassFix simplifier(default_options_); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::Parameter())); +} + // Test merging reshape and broadcast. 
TEST_F(AlgebraicSimplifierTest, ReshapeAndBroadcastMerged) { auto m = CreateNewVerifiedModule(); @@ -1873,12 +2252,13 @@ TEST_F(AlgebraicSimplifierTest, ReshapeAndBroadcastMerged) { auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Broadcast(op::Reshape(param0))); + GmockMatch(m::Broadcast(m::Reshape(m::Parameter(0))))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Broadcast(param0)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Broadcast(m::Parameter(0)))); } // Test merging broadcast and reshape. @@ -1895,12 +2275,13 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshapeMerged) { auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Reshape(op::Broadcast(param0))); + GmockMatch(m::Reshape(m::Broadcast(m::Parameter(0))))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Broadcast(param0)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Broadcast(m::Parameter(0)))); } TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x1_3) { @@ -1916,13 +2297,13 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x1_3) { auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Reshape(op::Broadcast(param))); + GmockMatch(m::Reshape(m::Broadcast(m::Parameter(0))))); AlgebraicSimplifier simplifier(default_options_); EXPECT_FALSE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), - op::Reshape(op::Broadcast(param))); + GmockMatch(m::Reshape(m::Broadcast(m::Parameter(0))))); } TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4_6x1x1x4) { @@ -1938,12 +2319,13 @@ TEST_F(AlgebraicSimplifierTest, 
BroadcastAndReshape_4_3x2x4_6x1x1x4) { HloComputation* computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Reshape(op::Broadcast(param))); + GmockMatch(m::Reshape(m::Broadcast(m::Parameter(0))))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Broadcast(param)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Broadcast(m::Parameter(0)))); EXPECT_THAT(computation->root_instruction()->dimensions(), ::testing::ElementsAre(3)); } @@ -1961,12 +2343,13 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x2x1_6x1x1x1) { HloComputation* computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Reshape(op::Broadcast(param))); + GmockMatch(m::Reshape(m::Broadcast(m::Parameter(0))))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Broadcast(param)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Broadcast(m::Parameter(0)))); const std::vector broadcast_dims = computation->root_instruction()->dimensions(); EXPECT_EQ(1, broadcast_dims.size()); @@ -1986,13 +2369,13 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4x2_6x8) { HloComputation* computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Reshape(op::Broadcast(param))); + GmockMatch(m::Reshape(m::Broadcast(m::Parameter(0))))); AlgebraicSimplifier simplifier(default_options_); EXPECT_FALSE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), - op::Reshape(op::Broadcast(param))); + GmockMatch(m::Reshape(m::Broadcast(m::Parameter(0))))); } TEST_F(AlgebraicSimplifierTest, IotaAndReshapeMerged) { @@ -2005,12 +2388,13 @@ TEST_F(AlgebraicSimplifierTest, IotaAndReshapeMerged) { auto computation = 
m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Reshape(m::Iota()))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Iota()); + EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Iota())); EXPECT_TRUE( ShapeUtil::Equal(computation->root_instruction()->shape(), result_shape)); } @@ -2024,13 +2408,13 @@ TEST_F(AlgebraicSimplifierTest, IotaEffectiveScalar) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Iota()); + EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Iota())); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); auto root = computation->root_instruction(); - EXPECT_THAT(root, op::Broadcast(op::Constant())); + EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Constant()))); EXPECT_EQ(0.0f, root->operand(0)->literal().GetFirstElement()); EXPECT_TRUE( ShapeUtil::Equal(computation->root_instruction()->shape(), result_shape)); @@ -2046,12 +2430,14 @@ TEST_F(AlgebraicSimplifierTest, IotaAndReshape_1_3x2_6) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Reshape(m::Iota()))); AlgebraicSimplifier simplifier(default_options_); EXPECT_FALSE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Reshape(m::Iota()))); } TEST_F(AlgebraicSimplifierTest, IotaAndReshape_4_3x2x4_6x1x1x4) { @@ -2064,12 +2450,13 @@ TEST_F(AlgebraicSimplifierTest, IotaAndReshape_4_3x2x4_6x1x1x4) { HloComputation* computation = m->AddEntryComputation(builder.Build()); - 
EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Reshape(m::Iota()))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Iota()); + EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Iota())); EXPECT_EQ(Cast(computation->root_instruction()) ->iota_dimension(), 3); @@ -2085,12 +2472,13 @@ TEST_F(AlgebraicSimplifierTest, IotaAndReshape_1_3x2x2_6x1x1x2) { HloComputation* computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Reshape(m::Iota()))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Iota()); + EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Iota())); const int64 iota_dim = Cast(computation->root_instruction()) ->iota_dimension(); @@ -2107,12 +2495,14 @@ TEST_F(AlgebraicSimplifierTest, IotaAndReshape_4_3x2x4x2_6x8) { HloComputation* computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Reshape(m::Iota()))); AlgebraicSimplifier simplifier(default_options_); EXPECT_FALSE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Reshape(m::Iota()))); } TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) { @@ -2135,7 +2525,8 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) { auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Pad(param, zero)); + 
EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Pad(m::Parameter(0), m::Op().Is(zero)))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); @@ -2179,12 +2570,14 @@ TEST_F(AlgebraicSimplifierTest, NegativePadding) { return false; }; - EXPECT_THAT(computation->root_instruction(), op::Pad(param, zero)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Pad(m::Parameter(0), m::Op().Is(zero)))); EXPECT_TRUE(has_negative_padding(pad)); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Slice(op::Pad(param, zero))); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Slice(m::Pad(m::Parameter(0), m::Op().Is(zero))))); EXPECT_FALSE( has_negative_padding(computation->root_instruction()->operand(0))); } @@ -2213,12 +2606,14 @@ TEST_F(AlgebraicSimplifierTest, TrivialInteriorPadding) { AlgebraicSimplifier simplifier(default_options_); - ASSERT_THAT(computation->root_instruction(), op::Pad(param, zero)); + ASSERT_THAT(computation->root_instruction(), + GmockMatch(m::Pad(m::Parameter(0), m::Op().Is(zero)))); ASSERT_TRUE(HasInteriorPadding(pad->padding_config())); EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Pad(param, zero)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Pad(m::Parameter(0), m::Op().Is(zero)))); EXPECT_FALSE( HasInteriorPadding(computation->root_instruction()->padding_config())); } @@ -2234,7 +2629,8 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopReshape) { auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Reshape(param)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Reshape(m::Parameter(0)))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); @@ 
-2256,7 +2652,8 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopSlice) { auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Slice(param)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Slice(m::Parameter(0)))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); @@ -2284,12 +2681,14 @@ TEST_F(AlgebraicSimplifierTest, SliceOfSliceToSlice) { auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Slice(op::Slice(param))); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Slice(m::Slice(m::Parameter(0))))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Slice(param)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Slice(m::Parameter(0)))); EXPECT_EQ(computation->root_instruction()->slice_starts(0), 3); EXPECT_EQ(computation->root_instruction()->slice_starts(1), 5); EXPECT_EQ(computation->root_instruction()->slice_limits(0), dim0 - 2); @@ -2315,12 +2714,14 @@ TEST_F(AlgebraicSimplifierTest, SliceOfReshapeToReshapeOfSlice) { auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Slice(op::Reshape(param))); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Slice(m::Reshape(m::Parameter(0))))); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Slice(param))); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Reshape(m::Slice(m::Parameter(0))))); } TEST_F(AlgebraicSimplifierTest, 
SliceOfReshapeUnchanged) { @@ -2339,7 +2740,8 @@ TEST_F(AlgebraicSimplifierTest, SliceOfReshapeUnchanged) { auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Slice(op::Reshape(param))); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Slice(m::Reshape(m::Parameter(0))))); AlgebraicSimplifier simplifier(default_options_); ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie()); @@ -2359,79 +2761,6 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopSort) { EXPECT_THAT(computation->root_instruction(), keys); } -TEST_F(AlgebraicSimplifierTest, ReplacePermutationSortWithScatter) { - const char* hlo_string = R"( - HloModule permutation_sort - - ENTRY sort_computation { - keys = f32[64,8732]{1,0} parameter(0) - values = s32[64,8732]{1,0} iota(), iota_dimension=1 - sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values), dimensions={1} - gte = s32[64,8732]{1,0} get-tuple-element(sort), index=1 - ROOT sort2 = (s32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(gte, values), dimensions={1} - } - )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - - AlgebraicSimplifierOptions options(non_bitcasting_callback()); - options.set_enable_permutation_sort_replacement(true); - AlgebraicSimplifier simplifier(options); - EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); - auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, - op::Tuple(op::Iota(), - op::Scatter(op::Iota(), - op::Concatenate(op::Iota(), op::Reshape()), - op::Reshape()))); -} - -TEST_F(AlgebraicSimplifierTest, DontReplacePermutationSortIfNonIntegral) { - // Same as ReplacePermutationSortWithScatter except that the iota has F32 - // type. 
- const char* hlo_string = R"( - HloModule permutation_sort - - ENTRY sort_computation { - keys = f32[64,8732]{1,0} parameter(0) - values = f32[64,8732]{1,0} iota(), iota_dimension=1 - sort = (f32[64,8732]{1,0}, f32[64,8732]{1,0}) sort(keys, values), dimensions={1} - gte = f32[64,8732]{1,0} get-tuple-element(sort), index=1 - ROOT sort2 = (f32[64,8732]{1,0}, f32[64,8732]{1,0}) sort(gte, values), dimensions={1} - } - )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - - AlgebraicSimplifierOptions options(non_bitcasting_callback()); - options.set_enable_permutation_sort_replacement(true); - AlgebraicSimplifier simplifier(options); - EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); -} - -TEST_F(AlgebraicSimplifierTest, DontReplacePermutationSortWrongDimensions) { - // Same as ReplacePermutationSortWithScatter except that the sort dimensions - // don't match. - const char* hlo_string = R"( - HloModule permutation_sort - - ENTRY sort_computation { - keys = f32[64,8732]{1,0} parameter(0) - values = s32[64,8732]{1,0} iota(), iota_dimension=1 - sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values), dimensions={1} - gte = s32[64,8732]{1,0} get-tuple-element(sort), index=1 - ROOT sort2 = (s32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(gte, values), dimensions={0} - } - )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - - AlgebraicSimplifierOptions options(non_bitcasting_callback()); - options.set_enable_permutation_sort_replacement(true); - AlgebraicSimplifier simplifier(options); - EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); -} - TEST_F(AlgebraicSimplifierTest, ReplaceEffectiveScalarKeyValueSortWithTuple) { auto builder = HloComputation::Builder(TestName()); @@ -2451,7 +2780,8 @@ TEST_F(AlgebraicSimplifierTest, ReplaceEffectiveScalarKeyValueSortWithTuple) { AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); 
EXPECT_THAT(computation->root_instruction(), - op::Tuple(keys, values0, values1)); + GmockMatch(m::Tuple(m::Op().Is(keys), m::Op().Is(values0), + m::Op().Is(values1)))); } // Test that A && True is simplified to A @@ -2667,7 +2997,7 @@ class ConvInputPaddingTest : public AlgebraicSimplifierTest, public ::testing::WithParamInterface {}; -INSTANTIATE_TEST_CASE_P( +INSTANTIATE_TEST_SUITE_P( ConvInputPaddingTestCases, ConvInputPaddingTest, ::testing::ValuesIn(std::vector{ // Merge this edge padding into the conv. @@ -2738,11 +3068,11 @@ TEST_P(ConvInputPaddingTest, DoTest) { .ValueOrDie(); builder.AddInstruction(HloInstruction::CreateConvolve( ShapeInference::InferConvolveShape(lhs_pad->shape(), filter->shape(), - /*feature_group_count=*/1, window, - dnums) + /*feature_group_count=*/1, + /*batch_group_count=*/1, window, dnums) .ValueOrDie(), - lhs_pad, filter, /*feature_group_count=*/1, window, dnums, - DefaultPrecisionConfig(2))); + lhs_pad, filter, /*feature_group_count=*/1, /*batch_group_count=*/1, + window, dnums, DefaultPrecisionConfig(2))); auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); @@ -2753,7 +3083,8 @@ TEST_P(ConvInputPaddingTest, DoTest) { ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); auto* conv = module->entry_computation()->root_instruction(); SCOPED_TRACE(module->ToString()); - ASSERT_THAT(conv, op::Convolution(op::Parameter(), op::Parameter())); + ASSERT_THAT(conv, + GmockMatch(m::Convolution(m::Parameter(), m::Parameter()))); EXPECT_EQ(window_util::ToString(conv->window()), absl::StrCat("size=3x3 ", testcase.expected_conv_window)); } @@ -2774,7 +3105,7 @@ class ConvFilterPaddingTest : public AlgebraicSimplifierTest, public ::testing::WithParamInterface {}; -INSTANTIATE_TEST_CASE_P( +INSTANTIATE_TEST_SUITE_P( ConvFilterPaddingTestCases, ConvFilterPaddingTest, ::testing::ValuesIn(std::vector{ // Can only merge interior padding on the filter's spatial dimensions; @@ -2854,11 +3185,11 @@ 
TEST_P(ConvFilterPaddingTest, DoIt) { builder.AddInstruction(HloInstruction::CreateConvolve( ShapeInference::InferConvolveShape(input->shape(), rhs_pad->shape(), - /*feature_group_count=*/1, window, - dnums) + /*feature_group_count=*/1, + /*batch_group_count=*/1, window, dnums) .ValueOrDie(), - input, rhs_pad, /*feature_group_count=*/1, window, dnums, - precision_config)); + input, rhs_pad, /*feature_group_count=*/1, /*batch_group_count=*/1, + window, dnums, precision_config)); auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); @@ -2870,7 +3201,8 @@ TEST_P(ConvFilterPaddingTest, DoIt) { ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); auto* conv = module->entry_computation()->root_instruction(); SCOPED_TRACE(module->ToString()); - ASSERT_THAT(conv, op::Convolution(op::Parameter(), op::Parameter())); + ASSERT_THAT(conv, + GmockMatch(m::Convolution(m::Parameter(), m::Parameter()))); EXPECT_EQ(window_util::ToString(conv->window()), absl::StrFormat("size=%dx%d %s", conv->operand(1)->shape().dimensions(2), @@ -3005,13 +3337,14 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { b.AddInstruction(HloInstruction::CreateConvolve( out_shape, input, filter, - /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); + /*feature_group_count=*/1, /*batch_group_count=*/1, window, dnums, + DefaultPrecisionConfig(2))); // TODO(b/80488902): verify this module. auto module = CreateNewUnverifiedModule(); auto* computation = module->AddEntryComputation(b.Build()); - AlgebraicSimplifierOptions simplifier_options(bitcasting_callback()); + AlgebraicSimplifierOptions simplifier_options; simplifier_options.set_is_layout_sensitive(true); AlgebraicSimplifier simplifier(simplifier_options); if (!simplifier.Run(module.get()).ValueOrDie()) { @@ -3142,10 +3475,9 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) { // Running simplification again should not result in any further changes. 
ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie()); - - root = computation->root_instruction(); - EXPECT_THAT(root, op::Broadcast(scalar_param)); - EXPECT_TRUE(ShapeUtil::Equal(root->shape(), slice_shape)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Broadcast(m::Op().Is(scalar_param)) + .WithShapeEqualTo(&slice_shape))); } // Test that reshape(transpose(broadcast(/*scalar value*/))) simplifies to a @@ -3176,10 +3508,9 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) { AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); - - root = computation->root_instruction(); - EXPECT_THAT(root, op::Broadcast(forty_two)); - EXPECT_TRUE(ShapeUtil::Equal(root->shape(), reshape_shape)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Broadcast(m::Op().Is(forty_two)) + .WithShapeEqualTo(&reshape_shape))); } // Test that ReduceWindow(Pad(op, x), y) can simplify to ReduceWindow(op, x). @@ -3219,7 +3550,7 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { // Create the reduce-window. Window window; - for (int64 i = 0; i < ShapeUtil::Rank(pad->shape()); ++i) { + for (int64 i = 0; i < pad->shape().rank(); ++i) { auto* dim = window.add_dimensions(); dim->set_size(1); dim->set_padding_low(10); @@ -3248,7 +3579,8 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { // Verify the result root = computation->root_instruction(); - EXPECT_THAT(root, op::ReduceWindow(operand, op::Constant())); + EXPECT_THAT(root, + GmockMatch(m::ReduceWindow(m::Op().Is(operand), m::Constant()))); EXPECT_TRUE(ShapeUtil::Equal(root->shape(), reduce_window_shape)) << ShapeUtil::HumanString(root->shape()) << " vs " << ShapeUtil::HumanString(reduce_window_shape); @@ -3304,7 +3636,7 @@ TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) { // Create the reduce-window. 
Window window; - for (int64 i = 0; i < ShapeUtil::Rank(pad->shape()); ++i) { + for (int64 i = 0; i < pad->shape().rank(); ++i) { auto* dim = window.add_dimensions(); dim->set_size(1); dim->set_padding_low(10); @@ -3333,7 +3665,8 @@ TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) { // Verify the result root = computation->root_instruction(); - EXPECT_THAT(root, op::ReduceWindow(op::Convert(parameter), op::Constant())); + EXPECT_THAT(root, GmockMatch(m::ReduceWindow(m::Convert(m::Parameter(0)), + m::Constant()))); EXPECT_TRUE(ShapeUtil::Equal(root->shape(), reduce_window_shape)) << ShapeUtil::HumanString(root->shape()) << " vs " << ShapeUtil::HumanString(reduce_window_shape); @@ -3414,7 +3747,7 @@ TEST_F(AlgebraicSimplifierTest, ConstantTupleBecomesTupleOfConstants) { AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), - op::Tuple(op::Constant(), op::Constant())); + GmockMatch(m::Tuple(m::Constant(), m::Constant()))); } // A dynamic-slice is trivial if its start indices are all zeroes and the size @@ -3425,18 +3758,22 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicSlice) { HloComputation::Builder builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {10, 100, 1000}); + std::vector params; + for (int i = 0; i < 3; ++i) { + params.push_back(builder.AddInstruction(HloInstruction::CreateParameter( + i + 1, ShapeUtil::MakeShape(U32, {}), "slice_indices"))); + } builder.AddInstruction(HloInstruction::CreateDynamicSlice( shape, builder.AddInstruction( HloInstruction::CreateParameter(0, shape, "slice_from")), - builder.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(U32, {3}), "slice_indices")), + params, /*slice_sizes=*/{10, 100, 1000})); auto computation = m->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - 
EXPECT_THAT(computation->root_instruction(), op::Parameter()); + EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Parameter())); } // A dynamic-update-slice is trivial if its start indices are all zeroes and the @@ -3449,28 +3786,35 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicUpdateSlice) { Shape full_shape = ShapeUtil::MakeShape(F32, {10, 100, 1000}); Shape slice_shape = ShapeUtil::MakeShape(F32, {10, 1, 1000}); + std::vector slice_indices, update_indices; + for (int i = 0; i < 3; ++i) { + slice_indices.push_back( + builder.AddInstruction(HloInstruction::CreateParameter( + i + 1, ShapeUtil::MakeShape(U32, {}), "slice_indices"))); + update_indices.push_back( + builder.AddInstruction(HloInstruction::CreateParameter( + i + 5, ShapeUtil::MakeShape(U32, {}), "update_indices"))); + } HloInstruction* slice = builder.AddInstruction(HloInstruction::CreateDynamicSlice( slice_shape, builder.AddInstruction( HloInstruction::CreateParameter(0, full_shape, "slice_from")), - builder.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(U32, {3}), "slice_indices")), + slice_indices, /*slice_sizes=*/{10, 1, 1000})); builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( slice_shape, builder.AddInstruction( - HloInstruction::CreateParameter(2, slice_shape, "to_update")), - slice, - builder.AddInstruction(HloInstruction::CreateParameter( - 3, ShapeUtil::MakeShape(U32, {3}), "update_indices")))); + HloInstruction::CreateParameter(4, slice_shape, "to_update")), + slice, update_indices)); auto computation = m->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), - op::DynamicSlice(op::Parameter(), op::Parameter())); + GmockMatch(m::DynamicSlice(m::Parameter(), m::Parameter(), + m::Parameter(), m::Parameter()))); } // Test that two consecutive broadcasts can be merged to one. 
@@ -3492,7 +3836,7 @@ TEST_F(AlgebraicSimplifierTest, MergeBroadcasts) { AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); - EXPECT_THAT(root, op::Broadcast(op::Constant())); + EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Constant()))); EXPECT_THAT(root->dimensions(), ElementsAre(2)); } @@ -3518,7 +3862,7 @@ TEST_F(AlgebraicSimplifierTest, MergeBroadcasts2) { AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); - EXPECT_THAT(root, op::Broadcast(op::Parameter(0))); + EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Parameter(0)))); EXPECT_THAT(root->dimensions(), ElementsAre(1, 3)); } @@ -3538,7 +3882,7 @@ TEST_F(AlgebraicSimplifierTest, MergeBroadcastAndIota) { AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); - EXPECT_THAT(root, op::Iota()); + EXPECT_THAT(root, GmockMatch(m::Iota())); EXPECT_EQ(Cast(root)->iota_dimension(), 2); } @@ -3559,7 +3903,7 @@ TEST_F(AlgebraicSimplifierTest, MergeBroadcastAndIota2) { AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); - EXPECT_THAT(root, op::Iota()); + EXPECT_THAT(root, GmockMatch(m::Iota())); EXPECT_EQ(Cast(root)->iota_dimension(), 2); } @@ -3577,11 +3921,11 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadLow) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); - AlgebraicSimplifierOptions options(bitcasting_callback()); + AlgebraicSimplifierOptions options; AlgebraicSimplifier simplifier(options); EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Reshape(op::Constant())); + EXPECT_THAT(root, GmockMatch(m::Reshape(m::Constant()))); } 
TEST_F(AlgebraicSimplifierTest, SliceOfPadHigh) { @@ -3598,11 +3942,11 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadHigh) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); - AlgebraicSimplifierOptions options(bitcasting_callback()); + AlgebraicSimplifierOptions options; AlgebraicSimplifier simplifier(options); EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Reshape(op::Constant())); + EXPECT_THAT(root, GmockMatch(m::Reshape(m::Constant()))); } TEST_F(AlgebraicSimplifierTest, SliceOfPadMidNonScalar) { @@ -3619,7 +3963,7 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadMidNonScalar) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); - AlgebraicSimplifierOptions options(bitcasting_callback()); + AlgebraicSimplifierOptions options; AlgebraicSimplifier simplifier(options); EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); } @@ -3638,11 +3982,11 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadMidScalar) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); - AlgebraicSimplifierOptions options(bitcasting_callback()); + AlgebraicSimplifierOptions options; AlgebraicSimplifier simplifier(options); EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Parameter()); + EXPECT_THAT(root, GmockMatch(m::Parameter())); } TEST_F(AlgebraicSimplifierTest, SliceOfConcatScalarInput) { @@ -3660,11 +4004,11 @@ TEST_F(AlgebraicSimplifierTest, SliceOfConcatScalarInput) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); - AlgebraicSimplifierOptions options(bitcasting_callback()); + AlgebraicSimplifierOptions options; AlgebraicSimplifier simplifier(options); EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, 
op::Parameter(1)); + EXPECT_THAT(root, GmockMatch(m::Parameter(1))); } TEST_F(AlgebraicSimplifierTest, SliceOfConcatNonScalarInput) { @@ -3682,11 +4026,11 @@ TEST_F(AlgebraicSimplifierTest, SliceOfConcatNonScalarInput) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); - AlgebraicSimplifierOptions options(bitcasting_callback()); + AlgebraicSimplifierOptions options; AlgebraicSimplifier simplifier(options); EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Slice(op::Parameter(2))); + EXPECT_THAT(root, GmockMatch(m::Slice(m::Parameter(2)))); EXPECT_EQ(root->slice_starts(0), 1); EXPECT_EQ(root->slice_limits(0), 2); } @@ -3704,11 +4048,11 @@ TEST_F(AlgebraicSimplifierTest, NegateNegate) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); - AlgebraicSimplifierOptions options(bitcasting_callback()); + AlgebraicSimplifierOptions options; AlgebraicSimplifier simplifier(options); EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Parameter(0)); + EXPECT_THAT(root, GmockMatch(m::Parameter(0))); } TEST_F(AlgebraicSimplifierTest, NotNot) { @@ -3724,11 +4068,11 @@ TEST_F(AlgebraicSimplifierTest, NotNot) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); - AlgebraicSimplifierOptions options(bitcasting_callback()); + AlgebraicSimplifierOptions options; AlgebraicSimplifier simplifier(options); EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Parameter(0)); + EXPECT_THAT(root, GmockMatch(m::Parameter(0))); } struct PadReduceWindowEffectiveBroadcastCase { @@ -3832,10 +4176,10 @@ TEST_P(PadReduceWindowEffectiveBroadcastTest, DoIt) { ShapeUtil::Equal(computation->root_instruction()->shape(), output_shape)); if 
(param.should_become_broadcast) { - EXPECT_THAT(computation->root_instruction(), op::Broadcast(::testing::_)); + EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Broadcast())); } else { EXPECT_THAT(computation->root_instruction(), - op::ReduceWindow(::testing::_, zero)); + GmockMatch(m::ReduceWindow(m::Op(), m::Op().Is(zero)))); } } @@ -3851,9 +4195,6 @@ PadReduceWindowEffectiveBroadcastCases() { {/*input_spatials=*/{2, 2}, /*symmetric_pad_amount=*/{6, 6}, /*reduce_window_spatials=*/{7, 7}, /*prepend_a=*/true, /*should_become_broadcast=*/false}, // - {/*input_spatials=*/{1, 1}, /*symmetric_pad_amount=*/{2, 2}, - /*reduce_window_spatials=*/{5, 5}, /*prepend_a=*/true, - /*should_become_broadcast=*/true}, // {/*input_spatials=*/{1, 1}, /*symmetric_pad_amount=*/{2, 2}, /*reduce_window_spatials=*/{1, 1}, /*prepend_a=*/true, /*should_become_broadcast=*/false}, // @@ -3864,11 +4205,62 @@ PadReduceWindowEffectiveBroadcastCases() { return *cases; } -INSTANTIATE_TEST_CASE_P( +INSTANTIATE_TEST_SUITE_P( PadReduceWindowEffectiveBroadcastInstantiation, PadReduceWindowEffectiveBroadcastTest, ::testing::ValuesIn(PadReduceWindowEffectiveBroadcastCases())); +class BatchDotStrengthReductionTest + : public AlgebraicSimplifierTest, + public ::testing::WithParamInterface< + ::testing::tuple> {}; +TEST_P(BatchDotStrengthReductionTest, BatchDotStrengthReduction) { + auto module = CreateNewVerifiedModule(); + int m, k, n; + PrimitiveType element_type; + std::tie(m, k, n, element_type) = GetParam(); + + Shape dot_shape = ShapeUtil::MakeShape(element_type, {1, 3, 5, m, n}); + Shape lhs_shape = ShapeUtil::MakeShape(element_type, {1, 3, 5, m, k}); + Shape rhs_shape = ShapeUtil::MakeShape(element_type, {1, 3, 5, k, n}); + HloComputation::Builder builder(TestName()); + + auto lhs = builder.AddInstruction( + HloInstruction::CreateParameter(0, lhs_shape, "lhs")); + auto rhs = builder.AddInstruction( + HloInstruction::CreateParameter(1, rhs_shape, "rhs")); + DotDimensionNumbers dot_dnums; 
+ dot_dnums.add_lhs_batch_dimensions(0); + dot_dnums.add_lhs_batch_dimensions(1); + dot_dnums.add_lhs_batch_dimensions(2); + dot_dnums.add_rhs_batch_dimensions(0); + dot_dnums.add_rhs_batch_dimensions(1); + dot_dnums.add_rhs_batch_dimensions(2); + dot_dnums.add_lhs_contracting_dimensions(4); + dot_dnums.add_rhs_contracting_dimensions(3); + builder.AddInstruction(HloInstruction::CreateDot( + dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2))); + auto computation = module->AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(default_options_); + TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(module.get())); + const bool dot_should_be_transformed = m == 1 || k == 1 || n == 1; + const bool computation_should_be_modified = dot_should_be_transformed; + EXPECT_EQ(changed, computation_should_be_modified); + bool has_no_dot = true; + for (const auto& hlo : computation->instructions()) { + if (hlo->opcode() == HloOpcode::kDot) { + has_no_dot = false; + break; + } + } + EXPECT_EQ(has_no_dot, dot_should_be_transformed); +} + +INSTANTIATE_TEST_SUITE_P( + BatchDotStrengthReductionTestInstantiation, BatchDotStrengthReductionTest, + ::testing::Combine(::testing::Values(1, 2), ::testing::Values(1, 2), + ::testing::Values(1, 2), ::testing::Values(F32, BF16))); + class DotStrengthReductionTest : public AlgebraicSimplifierTest, public ::testing::WithParamInterface< @@ -3921,7 +4313,7 @@ TEST_P(DotStrengthReductionTest, DotStrengthReduction) { EXPECT_EQ(has_no_dot, dot_should_be_transformed); } -INSTANTIATE_TEST_CASE_P( +INSTANTIATE_TEST_SUITE_P( DotStrengthReductionTestInstantiation, DotStrengthReductionTest, ::testing::Combine(::testing::Values(1, 2), ::testing::Values(1, 2), ::testing::Values(1, 2), ::testing::Bool(), @@ -3989,11 +4381,12 @@ TEST_P(DotOfConcatSimplificationTest, ConstantLHS) { EXPECT_TRUE( ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape)); - auto match_dot_0 = op::Dot(op::Slice(op::Constant()), op::Parameter(0)); 
- auto match_dot_1 = op::Dot(op::Slice(op::Constant()), op::Parameter(1)); - auto match_dot_2 = op::Dot(op::Slice(op::Constant()), op::Parameter(2)); - EXPECT_THAT(computation->root_instruction(), - op::Add(op::Add(match_dot_0, match_dot_1), match_dot_2)); + auto match_dot_0 = m::Dot(m::Slice(m::Constant()), m::Parameter(0)); + auto match_dot_1 = m::Dot(m::Slice(m::Constant()), m::Parameter(1)); + auto match_dot_2 = m::Dot(m::Slice(m::Constant()), m::Parameter(2)); + EXPECT_THAT( + computation->root_instruction(), + GmockMatch(m::Add(m::Add(match_dot_0, match_dot_1), match_dot_2))); } // Test that we transform @@ -4052,13 +4445,14 @@ TEST_P(DotOfConcatSimplificationTest, ConstantRHS) { EXPECT_TRUE( ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape)); - auto match_dot_0 = op::Dot(op::Parameter(0), op::Slice(op::Constant())); - auto match_dot_1 = op::Dot(op::Parameter(1), op::Slice(op::Constant())); - auto match_dot_2 = op::Dot(op::Parameter(2), op::Slice(op::Constant())); - auto match_dot_3 = op::Dot(op::Parameter(3), op::Slice(op::Constant())); - EXPECT_THAT(computation->root_instruction(), - op::Add(op::Add(op::Add(match_dot_0, match_dot_1), match_dot_2), - match_dot_3)); + auto match_dot_0 = m::Dot(m::Parameter(0), m::Slice(m::Constant())); + auto match_dot_1 = m::Dot(m::Parameter(1), m::Slice(m::Constant())); + auto match_dot_2 = m::Dot(m::Parameter(2), m::Slice(m::Constant())); + auto match_dot_3 = m::Dot(m::Parameter(3), m::Slice(m::Constant())); + EXPECT_THAT( + computation->root_instruction(), + GmockMatch(m::Add(m::Add(m::Add(match_dot_0, match_dot_1), match_dot_2), + match_dot_3))); } DotOfConcatTestSpec kDotOfConcatTestSpecs[] = { @@ -4081,9 +4475,10 @@ TEST_F(AlgebraicSimplifierTest, DynamicUpdateSliceZeroUpdate) { HloInstruction* const update = builder.AddInstruction( HloInstruction::CreateParameter(1, update_shape, "update")); HloInstruction* const start_indices = builder.AddInstruction( - 
HloInstruction::CreateConstant(LiteralUtil::CreateR1({0}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0({}))); builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - dslice_shape, operand, update, start_indices)); + dslice_shape, operand, update, + std::initializer_list({start_indices}))); const HloComputation* const computation = m->AddEntryComputation(builder.Build()); @@ -4092,9 +4487,9 @@ TEST_F(AlgebraicSimplifierTest, DynamicUpdateSliceZeroUpdate) { EXPECT_THAT(computation->root_instruction(), operand); } -INSTANTIATE_TEST_CASE_P(DotOfConcatSimplificationTestInstantiation, - DotOfConcatSimplificationTest, - ::testing::ValuesIn(kDotOfConcatTestSpecs)); +INSTANTIATE_TEST_SUITE_P(DotOfConcatSimplificationTestInstantiation, + DotOfConcatSimplificationTest, + ::testing::ValuesIn(kDotOfConcatTestSpecs)); struct DotOfGatherTestSpec { int64 m; @@ -4136,14 +4531,17 @@ TEST_P(DotOfGatherSimplificationTest, ConstantRHS) { int32 start_row = (spec.lcd == 0) ? 0 : spec.s; int32 start_col = (spec.lcd == 0) ? spec.s : 0; - const auto start_indices = + std::vector start_indices = { builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({start_row, start_col}))); + LiteralUtil::CreateR0(start_row))), + builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(start_col)))}; int64 slice_row_size = (spec.lcd == 0) ? spec.k : 1; int64 slice_col_size = (spec.lcd == 0) ? 1 : spec.k; - Shape ds_shape = ShapeUtil::MakeShape(F32, {slice_row_size, slice_col_size}); + std::vector slice_sizes = {slice_row_size, slice_col_size}; + Shape ds_shape = ShapeUtil::MakeShape(F32, slice_sizes); auto* ds = builder.AddInstruction(HloInstruction::CreateDynamicSlice( - ds_shape, lhs, start_indices, {slice_row_size, slice_col_size})); + ds_shape, lhs, start_indices, slice_sizes)); int64 rhs_rows = (spec.rcd == 0) ? spec.k : spec.n; int64 rhs_cols = (spec.rcd == 0) ? 
spec.n : spec.k; @@ -4175,8 +4573,8 @@ TEST_P(DotOfGatherSimplificationTest, ConstantRHS) { HloOpcode::kDynamicSlice); } else { EXPECT_THAT(computation->root_instruction(), - op::DynamicSlice(op::Dot(op::Constant(), op::Constant()), - op::Concatenate())); + GmockMatch(m::DynamicSlice(m::Dot(m::Constant(), m::Constant()), + m::Constant(), m::Constant()))); } } @@ -4214,14 +4612,17 @@ TEST_P(DotOfGatherSimplificationTest, ConstantLHS) { int32 start_row = (spec.rcd == 0) ? 0 : spec.s; int32 start_col = (spec.rcd == 0) ? spec.s : 0; - const auto start_indices = + std::vector start_indices = { + builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(start_row))), builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({start_row, start_col}))); + LiteralUtil::CreateR0(start_col)))}; int64 slice_row_size = (spec.rcd == 0) ? spec.k : 1; int64 slice_col_size = (spec.rcd == 0) ? 1 : spec.k; - Shape ds_shape = ShapeUtil::MakeShape(F32, {slice_row_size, slice_col_size}); + std::vector slice_sizes = {slice_row_size, slice_col_size}; + Shape ds_shape = ShapeUtil::MakeShape(F32, slice_sizes); auto* ds = builder.AddInstruction(HloInstruction::CreateDynamicSlice( - ds_shape, rhs, start_indices, {slice_row_size, slice_col_size})); + ds_shape, rhs, start_indices, slice_sizes)); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(spec.lcd); @@ -4245,8 +4646,8 @@ TEST_P(DotOfGatherSimplificationTest, ConstantLHS) { HloOpcode::kDynamicSlice); } else { EXPECT_THAT(computation->root_instruction(), - op::DynamicSlice(op::Dot(op::Constant(), op::Constant()), - op::Concatenate())); + GmockMatch(m::DynamicSlice(m::Dot(m::Constant(), m::Constant()), + m::Constant(), m::Constant()))); } } @@ -4294,9 +4695,102 @@ std::vector DotOfGatherPositiveNegativeTests() { return all; } -INSTANTIATE_TEST_CASE_P( +INSTANTIATE_TEST_SUITE_P( DotOfGatherSimplificationTestInstantiation, DotOfGatherSimplificationTest, 
::testing::ValuesIn(DotOfGatherPositiveNegativeTests())); +TEST_F(AlgebraicSimplifierTest, TupleReduceReshape) { + const char* hlo_string = R"( +HloModule module + +reducer { + parameter.1 = f32[] parameter(0) + parameter.3 = f32[] parameter(2) + add.2 = f32[] add(parameter.1, parameter.3) + parameter.0 = f32[] parameter(1) + parameter.2 = f32[] parameter(3) + add.3 = f32[] add(parameter.0, parameter.2) + ROOT tuple.4 = (f32[], f32[]) tuple(add.2, add.3) +} + +ENTRY entry { + parameter.6 = (f32[], f32[]) parameter(0) + get-tuple-element.10 = f32[] get-tuple-element(parameter.6), index=0 + get-tuple-element.11 = f32[] get-tuple-element(parameter.6), index=1 + constant = f32[] constant(0) + ROOT reduce = (f32[], f32[]) reduce(get-tuple-element.10, get-tuple-element.11, constant, constant), dimensions={}, to_apply=reducer +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AlgebraicSimplifierOptions options; + AlgebraicSimplifier simplifier(options); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::Tuple( + m::Reshape(m::GetTupleElement(m::Parameter(), 0)), + m::Reshape(m::GetTupleElement(m::Parameter(), 1))))); +} + +TEST_F(AlgebraicSimplifierTest, TupleReduceBroadcast) { + const char* hlo_string = R"( +HloModule module + +reducer { + parameter.1 = f32[] parameter(0) + parameter.3 = f32[] parameter(2) + mul.2 = f32[] add(parameter.1, parameter.3) + parameter.0 = f32[] parameter(1) + parameter.2 = f32[] parameter(3) + add.3 = f32[] add(parameter.0, parameter.2) + ROOT tuple.4 = (f32[], f32[]) tuple(mul.2, add.3) +} + +ENTRY entry { + parameter.6 = (f32[0, 10, 10], f32[0, 10, 10]) parameter(0) + get-tuple-element.10 = f32[0, 10, 10] get-tuple-element(parameter.6), index=0 + get-tuple-element.11 = f32[0, 10, 10] get-tuple-element(parameter.6), index=1 + constant.0 = f32[] constant(0) + constant.1 = f32[] constant(1) + ROOT 
reduce = (f32[10, 10], f32[10, 10]) reduce(get-tuple-element.10, get-tuple-element.11, constant.0, constant.1), dimensions={0}, to_apply=reducer +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AlgebraicSimplifierOptions options; + AlgebraicSimplifier simplifier(options); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::Tuple(m::Broadcast(m::ConstantScalar(0)), + m::Broadcast(m::ConstantScalar(1))))); +} + +TEST_F(AlgebraicSimplifierTest, ZeroSizedReshapeWithoutLayout) { + auto builder = HloComputation::Builder(TestName()); + HloInstruction* param = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {1}), "param")); + HloInstruction* broadcast = + builder.AddInstruction(HloInstruction::CreateBroadcast( + ShapeUtil::MakeShape(F32, {0, 1}), param, {1})); + + // Create a reshape with zero sized result and without layout. 
+ Shape reshaped_shape = ShapeUtil::MakeShape(F32, {0}); + reshaped_shape.clear_layout(); + builder.AddInstruction( + HloInstruction::CreateReshape(reshaped_shape, broadcast)); + + std::unique_ptr module = CreateNewVerifiedModule(); + module->AddEntryComputation(builder.Build()); + + AlgebraicSimplifierOptions options; + AlgebraicSimplifier simplifier(options); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Constant()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/allocation_tracker.cc b/tensorflow/compiler/xla/service/allocation_tracker.cc index ef5e211646e7b0b66b8e6c09948be58063422943..6cb0e985e57016e5a22fba50c3e3ad6970f1b178 100644 --- a/tensorflow/compiler/xla/service/allocation_tracker.cc +++ b/tensorflow/compiler/xla/service/allocation_tracker.cc @@ -142,13 +142,13 @@ StatusOr> AllocationTracker::DeconstructTuple( // We only need to care about replica id 0 here, since the GlobalDataHandle is // the same for all buffers across replicas. const ShapedBuffer* shaped_buffer = replicated_buffers[0]; - if (!ShapeUtil::IsTuple(shaped_buffer->on_host_shape())) { + if (!shaped_buffer->on_host_shape().IsTuple()) { return InvalidArgument("global data handle %d is not a tuple", data.handle()); } // If the on-host representation is a tuple, then the on-device one should be // as well. 
- TF_RET_CHECK(ShapeUtil::IsTuple(shaped_buffer->on_device_shape())); + TF_RET_CHECK(shaped_buffer->on_device_shape().IsTuple()); if (ShapeUtil::IsNestedTuple(shaped_buffer->on_device_shape())) { return Unimplemented("Deconstructing nested tuples is not implemented."); diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner.cc b/tensorflow/compiler/xla/service/ar_crs_combiner.cc index c11452a6fbd49a1fc382d11d24a7d7b7eeab0bcc..f8dff6a700cc9d5843053e3d451a7b005539ca26 100644 --- a/tensorflow/compiler/xla/service/ar_crs_combiner.cc +++ b/tensorflow/compiler/xla/service/ar_crs_combiner.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -36,31 +37,67 @@ namespace { namespace m = match; -// If the argument instruction is a CRS in the sequence -// AR -> Convert -> Add -> CRS -// then return the AR in the sequence. -// TODO(b/117554291): Rewrite this to recognize more general patterns, -// not just the specific one of AR -> Add -> Convert -> CRS. +// Checks if the argument instruction is an AllReduce, followed by a certain +// sequence of instructions and then a CRS. It must be possible to move +// the AR past each instruction in the sequence. Returns the CRS, which is the +// last instruction in the sequence. 
absl::optional MatchesArCrsPattern( HloInstruction* instruction) { - HloInstruction *ar, *convert, *add, *crs; - if (Match(instruction, - m::CrossReplicaSum( - &crs, m::Add(&add, m::Op(), - m::Convert(&convert, - m::CrossReplicaSum(&ar, m::Op()))))) && - ar->users().size() == 1 && ar->shape().element_type() == BF16 && - convert->shape().element_type() == F32 && !crs->all_reduce_id()) { - return ar; + auto can_ar_move_past_instruction = [](HloInstruction* instruction) -> bool { + if (instruction->user_count() != 1) { + return false; + } + switch (instruction->opcode()) { + case HloOpcode::kBitcast: + case HloOpcode::kTranspose: + case HloOpcode::kReshape: + return true; + case HloOpcode::kConvert: + // Can be moved across if both input and output is either float or + // integer (e.g. S32<->U32 or F32<->BF16) + return ShapeUtil::ElementIsFloating(instruction->shape()) == + ShapeUtil::ElementIsFloating(instruction->operand(0)->shape()); + case HloOpcode::kAdd: + case HloOpcode::kSubtract: + case HloOpcode::kMultiply: + // Only supported for floating point operands. 
+ return ShapeUtil::ElementIsFloating(instruction->shape()); + default: + return false; + } + }; + + auto computation_is_addition = [](HloComputation* c) { + return c->instruction_count() == 3 && + Match(c->root_instruction(), m::Add(m::Parameter(), m::Parameter())); + }; + + if (!instruction->IsCrossModuleAllReduce() || + !computation_is_addition(instruction->called_computations()[0]) || + instruction->user_count() != 1) { + return absl::nullopt; + } + auto next = instruction->users()[0]; + while (!next->IsCrossReplicaAllReduce()) { + if (can_ar_move_past_instruction(next)) { + next = next->users()[0]; + } else { + return absl::nullopt; + } + } + if (!Cast(next)->IsNoop() && + computation_is_addition(next->called_computations()[0])) { + return absl::optional(next); + } else { + return absl::nullopt; } - return absl::optional(); } } // namespace absl::optional ArCrsCombiner::WhileFromBodyParameter( HloInstruction* instruction) { - CHECK(HloOpcode::kParameter == instruction->opcode()); + CHECK_EQ(HloOpcode::kParameter, instruction->opcode()); HloComputation* computation = instruction->parent(); auto caller_instructions = call_graph_->GetComputationCallers(computation); if (caller_instructions.size() == 1) { @@ -69,7 +106,7 @@ absl::optional ArCrsCombiner::WhileFromBodyParameter( return caller_instruction; } } - return absl::optional(); + return absl::nullopt; } std::vector ArCrsCombiner::GetAllTuples( @@ -120,7 +157,7 @@ bool ArCrsCombiner::TupleElementsComputeSameValue( return false; } for (auto tuple : tuples) { - CHECK(tuple->opcode() == HloOpcode::kTuple); + CHECK_EQ(tuple->opcode(), HloOpcode::kTuple); if (!InstructionsComputeSameValue(tuple->mutable_operand(i1), tuple->mutable_operand(i2), visited_pairs)) { @@ -160,12 +197,14 @@ bool ArCrsCombiner::InstructionsComputeSameValue( if (opcode1 != i2->opcode() || operands1.size() != i2->operands().size()) { return false; } - if (opcode1 == HloOpcode::kConstant || i1->IsCrossModuleAllReduce()) { - return 
i1->Identical( - *i2, - /*eq_operands=*/std::equal_to(), - /*eq_computations=*/std::equal_to(), - /*layout_sensitive=*/false); + auto eq_computations = [](const HloComputation* a, const HloComputation* b) { + return *a == *b; + }; + if (i1->IsCrossModuleAllReduce()) { + return i1->Identical(*i2, + /*eq_operands=*/std::equal_to(), + eq_computations, + /*layout_sensitive=*/false); } visited_pairs->emplace(min_uid, max_uid); for (int i = 0; i < operands1.size(); ++i) { @@ -175,22 +214,38 @@ bool ArCrsCombiner::InstructionsComputeSameValue( return false; } } + if (opcode1 == HloOpcode::kParameter) { + // In the general case, we don't try to prove equality of parameters. + // We only try in the context of get-tuple-element + // (see TupleElementsComputeSameValue). + return false; + } if (opcode1 == HloOpcode::kGetTupleElement) { - if (i1->tuple_index() == i2->tuple_index()) { - return true; - } - return TupleElementsComputeSameValue(operands1[0], i1->tuple_index(), + return i1->tuple_index() == i2->tuple_index() || + TupleElementsComputeSameValue(operands1[0], i1->tuple_index(), i2->tuple_index(), visited_pairs); } - return true; + // Don't check that the operands are identical, because Identical can + // return false for instructions that compute the same value but are not + // identical, which we don't want. We have checked the arguments with + // InstructionsComputeSameValue earlier. 
+ auto eq_instructions = [](const HloInstruction* i1, + const HloInstruction* i2) -> bool { return true; }; + return i1->Identical(*i2, eq_instructions, eq_computations, + /*layout_sensitive=*/false); } void ArCrsCombiner::GroupAllReducesById(HloModule* module) { for (HloComputation* computation : module->MakeNonfusionComputations()) { for (HloInstruction* instruction : computation->instructions()) { - auto ar = MatchesArCrsPattern(instruction); - if (ar) { - all_reduce_map_[*((*ar)->all_reduce_id())].push_back(*ar); + auto maybe_crs = MatchesArCrsPattern(instruction); + if (maybe_crs) { + auto crs = *maybe_crs; + int64 ar_id = *(instruction->all_reduce_id()); + if (crs_reserved_map_.find(crs) == crs_reserved_map_.end()) { + all_reduce_map_[ar_id].push_back(instruction); + crs_reserved_map_[crs] = ar_id; + } } } } @@ -198,20 +253,25 @@ void ArCrsCombiner::GroupAllReducesById(HloModule* module) { void ArCrsCombiner::KeepProvablyEqualInstructionGroups() { for (auto it : all_reduce_map_) { + auto all_reduce_id = it.first; auto instruction_vec = it.second; CHECK_EQ(instruction_vec.size(), num_spatial_partitions_); - auto instr_0 = instruction_vec[0]; - auto add_0 = instr_0->users()[0]->users()[0]; - CHECK(HloOpcode::kAdd == add_0->opcode()); - for (int i = 1; i < instruction_vec.size(); ++i) { auto instr_i = instruction_vec[i]; - auto add_i = instr_i->users()[0]->users()[0]; - CHECK(HloOpcode::kAdd == add_i->opcode()); + auto next_0 = instr_0->users()[0]; + auto next_i = instr_i->users()[0]; absl::flat_hash_map visited_pairs; - if (!InstructionsComputeSameValue(add_0, add_i, &visited_pairs)) { - all_reduce_map_.erase(it.first); + while (true) { + if (!InstructionsComputeSameValue(next_0, next_i, &visited_pairs)) { + all_reduce_map_.erase(all_reduce_id); + break; + } + if (next_0->IsCrossReplicaAllReduce()) { + break; + } + next_0 = next_0->users()[0]; + next_i = next_i->users()[0]; } } } @@ -221,55 +281,51 @@ StatusOr ArCrsCombiner::RewriteGraph() { if 
(all_reduce_map_.empty()) { return false; } - - auto computation_is_addition = [](HloComputation* c) { - return c->instruction_count() == 3 && - Match(c->root_instruction(), m::Add(m::Parameter(), m::Parameter())); - }; - for (auto it : all_reduce_map_) { auto instruction_vec = it.second; for (auto all_reduce : instruction_vec) { auto parent_computation = all_reduce->parent(); - auto convert = all_reduce->users()[0]; - auto add = convert->users()[0]; - auto crs = add->users()[0]; - - if (!computation_is_addition(all_reduce->called_computations()[0]) || - !computation_is_addition(crs->called_computations()[0])) { - continue; - } - HloInstruction* other_summand = (add->operands()[0] == convert) - ? add->operands()[1] - : add->operands()[0]; - // Remove the AllReduce and replace the CRS with: - // AllReduce - (other_summand * (num_spatial_partitions_ - 1)) - TF_CHECK_OK( - all_reduce->ReplaceAllUsesWith(all_reduce->mutable_operand(0))); - crs->set_all_reduce_id(all_reduce->all_reduce_id()); - auto new_shape = crs->shape(); - HloInstruction* to_subtract; - if (num_spatial_partitions_ == 2) { - to_subtract = other_summand; - } else { - Literal partitions_minus_1_lit = Literal(new_shape); - partitions_minus_1_lit.PopulateWithValue( - num_spatial_partitions_ - 1); - auto partitions_minus_1_const = parent_computation->AddInstruction( - HloInstruction::CreateConstant(partitions_minus_1_lit.Clone())); - to_subtract = - parent_computation->AddInstruction(HloInstruction::CreateBinary( - new_shape, HloOpcode::kMultiply, other_summand, - partitions_minus_1_const)); - } - auto sub = - parent_computation->AddInstruction(HloInstruction::CreateBinary( - new_shape, HloOpcode::kSubtract, crs, to_subtract)); - TF_CHECK_OK(crs->ReplaceAllUsesWith(sub)); + auto all_reduce_id = all_reduce->all_reduce_id(); + auto prev = all_reduce->mutable_operand(0); + auto next = all_reduce->users()[0]; + TF_CHECK_OK(all_reduce->ReplaceUseWith(next, prev)); 
TF_CHECK_OK(parent_computation->RemoveInstruction(all_reduce)); + while (!next->IsCrossReplicaAllReduce()) { + switch (next->opcode()) { + case HloOpcode::kBitcast: + case HloOpcode::kTranspose: + case HloOpcode::kReshape: + case HloOpcode::kConvert: + case HloOpcode::kMultiply: + break; + case HloOpcode::kAdd: + case HloOpcode::kSubtract: { + auto other_operand = (next->operands()[0] == prev) + ? next->operands()[1] + : next->operands()[0]; + // To move the AR past the addition/subtraction, we need to divide + // other_operand by the number of spatial partitions. + auto shape = other_operand->shape(); + Literal lit(shape); + lit.PopulateWithValue(num_spatial_partitions_); + auto divisor = parent_computation->AddInstruction( + HloInstruction::CreateConstant(lit.Clone())); + auto division = + parent_computation->AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kDivide, other_operand, divisor)); + TF_CHECK_OK(other_operand->ReplaceUseWith(next, division)); + break; + } + default: + LOG(FATAL) << "Unexpected instruction: " << next->ToShortString(); + } + prev = next; + next = next->users()[0]; + } + // The AllReduce and the CRS are combined to an all-core AllReduce. + next->set_all_reduce_id(all_reduce_id); } } - return true; } diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner.h b/tensorflow/compiler/xla/service/ar_crs_combiner.h index f6a7ef76ec3b76972d1b2c7fb548cecfb9423160..e61ef5d4f9072979a6c356a9456c91e19405b01e 100644 --- a/tensorflow/compiler/xla/service/ar_crs_combiner.h +++ b/tensorflow/compiler/xla/service/ar_crs_combiner.h @@ -25,9 +25,12 @@ limitations under the License. namespace xla { -// Combine an AllReduce and a CrossReplicaSum when they are close to each other -// in the graph, to use an efficient CrossReplicaSum implementation that -// fully utilizes the interconnect bandwidth. 
+// When the HLO graph contains a cross-module AllReduce, followed by some simple +// linear operations, followed by a cross-replica AllReduce, we can combine the +// CMAR and the CRAR, to use an efficient AllReduce implementation that fully +// utilizes the interconnect bandwidth. +// Such sequences appear in spatially partitioned models. +// This pass must run right after spatial partitioning. class ArCrsCombiner : public HloModulePass { public: ArCrsCombiner(int num_spatial_partitions) @@ -80,6 +83,11 @@ class ArCrsCombiner : public HloModulePass { // Map from all-reduce ids to the all reduce instructions. absl::flat_hash_map> all_reduce_map_; + // Map from a CRS instruction to the all-reduce ID of the AR paired with the + // CRS. Sometimes, several ARs in the code could be paired with the same CRS. + // We use this map to pick a single AR/CRS path to rewrite. + absl::flat_hash_map crs_reserved_map_; + std::unique_ptr call_graph_; }; diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc b/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc index 9d5eaf63ccf32cd78b8c11f12f9bccdfd1fec3e0..5152f0dc884a153f9b0ade06acd479832d87ff25 100644 --- a/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc +++ b/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc @@ -32,8 +32,8 @@ HloModule foobar ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) { %p = f32[2,2] parameter(0) - %constant.f32.1 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) - %constant.f32.2 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + %constant.f32.1 = f32[2,2] constant({{1, 2}, {3, 4}}) + %constant.f32.2 = f32[2,2] constant({{1, 2}, {3, 4}}) ROOT %tuple = (f32[2,2], f32[2,2]) tuple(%constant.f32.1, %constant.f32.2) } )"; @@ -48,13 +48,50 @@ ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) { EXPECT_TRUE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); } +TEST_F(ArCrsCombinerTest, SameValueTestBasecase2) { + const char* module_str = R"( +HloModule foobar + +ENTRY 
%entrycomp (x: f32[]) -> (f32[], f32[]) { + %x = f32[] parameter(0) + ROOT %tuple = (f32[], f32[]) tuple(%x, %x) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto root_tuple = module->entry_computation()->root_instruction(); + auto i1 = root_tuple->operands()[0]; + auto i2 = root_tuple->operands()[1]; + EXPECT_TRUE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); +} + +TEST_F(ArCrsCombinerTest, SameValueTestBasecase3) { + const char* module_str = R"( +HloModule foobar + +ENTRY %entrycomp (x: f32[], y: f32[]) -> (f32[], f32[]) { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %tuple = (f32[], f32[]) tuple(%x, %y) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto root_tuple = module->entry_computation()->root_instruction(); + auto i1 = root_tuple->operands()[0]; + auto i2 = root_tuple->operands()[1]; + EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); +} + TEST_F(ArCrsCombinerTest, SameValueTestNumOperands) { const char* module_str = R"( HloModule foobar ENTRY %entrycomp (p: f32[2,2]) -> ((f32[2,2]), (f32[2,2], f32[2,2])) { %p = f32[2,2] parameter(0) - %constant.f32 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + %constant.f32 = f32[2,2] constant({{1, 2}, {3, 4}}) %tuple1 = (f32[2,2]) tuple(%constant.f32) %tuple2 = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32) ROOT %tuple = ((f32[2,2]), (f32[2,2], f32[2,2])) tuple(%tuple1, %tuple2) @@ -69,13 +106,53 @@ ENTRY %entrycomp (p: f32[2,2]) -> ((f32[2,2]), (f32[2,2], f32[2,2])) { EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); } +TEST_F(ArCrsCombinerTest, SameValueTestSliceIndicesMatch) { + const char* module_str = R"( +HloModule foobar + +ENTRY %entrycomp (p: f32[2]) -> (f32[1], f32[1]) { + %p = f32[2] parameter(0) + %slice.1 = f32[1] slice(f32[2] %p), slice={[0:1]} + %slice.2 = f32[1] slice(f32[2] %p), slice={[0:1]} + ROOT 
%tuple = (f32[1], f32[1]) tuple(%slice.1, %slice.2) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto root_tuple = module->entry_computation()->root_instruction(); + auto i1 = root_tuple->operands()[0]; + auto i2 = root_tuple->operands()[1]; + EXPECT_TRUE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); +} + +TEST_F(ArCrsCombinerTest, SameValueTestSliceIndicesDontMatch) { + const char* module_str = R"( +HloModule foobar + +ENTRY %entrycomp (p: f32[2]) -> (f32[1], f32[1]) { + %p = f32[2] parameter(0) + %slice.1 = f32[1] slice(f32[2] %p), slice={[0:1]} + %slice.2 = f32[1] slice(f32[2] %p), slice={[1:2]} + ROOT %tuple = (f32[1], f32[1]) tuple(%slice.1, %slice.2) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto root_tuple = module->entry_computation()->root_instruction(); + auto i1 = root_tuple->operands()[0]; + auto i2 = root_tuple->operands()[1]; + EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); +} + TEST_F(ArCrsCombinerTest, SameValueTestTupleElementSameIndex) { const char* module_str = R"( HloModule foobar ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) { %p = f32[2,2] parameter(0) - %constant.f32 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + %constant.f32 = f32[2,2] constant({{1, 2}, {3, 4}}) %tuple.1 = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32) %get-tuple-element.1 = f32[2,2] get-tuple-element(%tuple.1), index=0 %get-tuple-element.2 = f32[2,2] get-tuple-element(%tuple.1), index=0 @@ -97,7 +174,7 @@ HloModule foobar ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) { %p = f32[2,2] parameter(0) - %constant.f32 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + %constant.f32 = f32[2,2] constant({{1, 2}, {3, 4}}) %tuple.1 = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32) %get-tuple-element.1 = f32[2,2] get-tuple-element(%tuple.1), index=0 %get-tuple-element.2 = f32[2,2] 
get-tuple-element(%tuple.1), index=1 @@ -119,8 +196,8 @@ HloModule foobar ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) { %p = f32[2,2] parameter(0) - %constant.f32.1 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) - %constant.f32.2 = f32[2,2] constant(f32[2,2] {{2, 3}, {4, 5}}) + %constant.f32.1 = f32[2,2] constant({{1, 2}, {3, 4}}) + %constant.f32.2 = f32[2,2] constant({{2, 3}, {4, 5}}) %tuple.1 = (f32[2,2], f32[2,2]) tuple(%constant.f32.1, %constant.f32.2) %get-tuple-element.1 = f32[2,2] get-tuple-element(%tuple.1), index=0 %get-tuple-element.2 = f32[2,2] get-tuple-element(%tuple.1), index=1 @@ -149,7 +226,7 @@ HloModule foobar %body (x: (f32[2,2], f32[2,2])) -> (f32[2,2], f32[2,2]) { %x = (f32[2,2], f32[2,2]) parameter(0) - %constant.f32 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + %constant.f32 = f32[2,2] constant({{1, 2}, {3, 4}}) %get-tuple-element.1 = f32[2,2] get-tuple-element(%x), index=0 %get-tuple-element.2 = f32[2,2] get-tuple-element(%x), index=1 %add.1 = f32[2,2] add(%get-tuple-element.1, %constant.f32) @@ -158,7 +235,7 @@ HloModule foobar } ENTRY %WhileLoop () -> (f32[2,2], f32[2,2]) { - %constant.f32 = f32[2,2] constant(f32[2,2] {{3, 4}, {5, 6}}) + %constant.f32 = f32[2,2] constant({{3, 4}, {5, 6}}) %init.tuple = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32) ROOT %while = (f32[2,2], f32[2,2]) while(%init.tuple), condition=%condition, body=%body } @@ -186,7 +263,7 @@ HloModule foobar %body (x: (f32[2,2], f32[2,2])) -> (f32[2,2], f32[2,2]) { %x = (f32[2,2], f32[2,2]) parameter(0) - %constant.f32 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + %constant.f32 = f32[2,2] constant({{1, 2}, {3, 4}}) %get-tuple-element.1 = f32[2,2] get-tuple-element(%x), index=0 %get-tuple-element.2 = f32[2,2] get-tuple-element(%x), index=1 %add.1 = f32[2,2] add(%get-tuple-element.1, %constant.f32) @@ -195,8 +272,8 @@ HloModule foobar } ENTRY %WhileLoop () -> (f32[2,2], f32[2,2]) { - %constant.f32.1 = f32[2,2] constant(f32[2,2] {{3, 4}, {5, 6}}) - 
%constant.f32.2 = f32[2,2] constant(f32[2,2] {{3, 4}, {7, 8}}) + %constant.f32.1 = f32[2,2] constant({{3, 4}, {5, 6}}) + %constant.f32.2 = f32[2,2] constant({{3, 4}, {7, 8}}) %init.tuple = (f32[2,2], f32[2,2]) tuple(%constant.f32.1, %constant.f32.2) ROOT %while = (f32[2,2], f32[2,2]) while(%init.tuple), condition=%condition, body=%body } @@ -224,8 +301,8 @@ HloModule foobar %body (x: (f32[2,2], f32[2,2])) -> (f32[2,2], f32[2,2]) { %x = (f32[2,2], f32[2,2]) parameter(0) - %constant.f32.1 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) - %constant.f32.2 = f32[2,2] constant(f32[2,2] {{3, 4}, {1, 2}}) + %constant.f32.1 = f32[2,2] constant({{1, 2}, {3, 4}}) + %constant.f32.2 = f32[2,2] constant({{3, 4}, {1, 2}}) %get-tuple-element.1 = f32[2,2] get-tuple-element(%x), index=0 %get-tuple-element.2 = f32[2,2] get-tuple-element(%x), index=1 %add.1 = f32[2,2] add(%get-tuple-element.1, %constant.f32.1) @@ -234,7 +311,7 @@ HloModule foobar } ENTRY %WhileLoop () -> (f32[2,2], f32[2,2]) { - %constant.f32 = f32[2,2] constant(f32[2,2] {{3, 4}, {5, 6}}) + %constant.f32 = f32[2,2] constant({{3, 4}, {5, 6}}) %init.tuple = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32) ROOT %while = (f32[2,2], f32[2,2]) while(%init.tuple), condition=%condition, body=%body } @@ -249,11 +326,27 @@ ENTRY %WhileLoop () -> (f32[2,2], f32[2,2]) { EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); } -TEST_F(ArCrsCombinerTest, RewritePatternArConvertAddCrs) { +void CompareReplicaGroups(const std::vector& groups_before, + const std::vector& groups_after) { + ASSERT_EQ(groups_before.size(), groups_after.size()); + for (int i = 0; i < groups_before.size(); ++i) { + // Somewhat verbose way to compare the replica_ids, because EqualsProto + // is not available in the open-source build. 
+ auto group_before = groups_before[i]; + std::vector ids_before(group_before.replica_ids().begin(), + group_before.replica_ids().end()); + auto group_after = groups_after[i]; + std::vector ids_after(group_after.replica_ids().begin(), + group_after.replica_ids().end()); + EXPECT_EQ(ids_before, ids_after); + } +} + +TEST_F(ArCrsCombinerTest, RewriteArConvertCrs) { const char* module_str = R"( HloModule foobar -%binary_add (a: bf16[], b: bf16[]) -> bf16[] { +%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] { %a = bf16[] parameter(0) %b = bf16[] parameter(1) ROOT %add = bf16[] add(%a, %b) @@ -265,49 +358,258 @@ HloModule foobar ROOT %add = f32[] add(%x, %y) } -ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) { - %p = f32[2,2] parameter(0) - %constant.bf16 = bf16[2,2] constant(bf16[2,2] {{1, 2}, {3, 4}}) - %constant.f32 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) +ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) { + %p = bf16[] parameter(0) + %constant.bf16 = bf16[] constant(1) + + %all-reduce.ar.1 = bf16[] + all-reduce(%p), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%sum.bf16, + sharding={maximal device=0} + %convert.1 = f32[] + convert(%all-reduce.ar.1), + sharding={maximal device=0} + %all-reduce.1 = f32[] + all-reduce(%convert.1), + replica_groups={{0,1}}, + to_apply=%sum.f32, + sharding={maximal device=0} + + %all-reduce.ar.2 = bf16[] + all-reduce(%constant.bf16), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%sum.bf16, + sharding={maximal device=1} + %convert.2 = f32[] + convert(%all-reduce.ar.2), + sharding={maximal device=1} + %all-reduce.2 = f32[] + all-reduce(%convert.2), + replica_groups={{0,1}}, + to_apply=%sum.f32, + sharding={maximal device=1} + + ROOT %tuple = (f32[], f32[]) + tuple(%all-reduce.1, %all-reduce.2), + sharding={{maximal device=0}, {maximal device=1}} +} +)"; - %cross-replica-sum.ar.1 = bf16[2,2] - cross-replica-sum(%constant.bf16), + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + 
ParseAndReturnVerifiedModule(module_str)); + auto crs_before = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_before = crs_before->replica_groups(); + ArCrsCombiner combiner(2); + auto changed = combiner.Run(module.get()).ValueOrDie(); + EXPECT_TRUE(changed); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Tuple(op::AllReduce(op::Convert(op::Parameter())), + op::AllReduce(op::Convert(op::Constant())))); + auto crs_after = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_after = crs_after->replica_groups(); + CompareReplicaGroups(replica_groups_before, replica_groups_after); +} + +TEST_F(ArCrsCombinerTest, RewriteArBitcastCrs) { + const char* module_str = R"( +HloModule foobar + +%sum.1 (a: f32[2,1], b: f32[2,1]) -> f32[2,1] { + %a = f32[2,1] parameter(0) + %b = f32[2,1] parameter(1) + ROOT %add = f32[2,1] add(%a, %b) +} + +%sum.2 (x: f32[2], y: f32[2]) -> f32[2] { + %x = f32[2] parameter(0) + %y = f32[2] parameter(1) + ROOT %add = f32[2] add(%x, %y) +} + +ENTRY %entrycomp (p: f32[2,1]) -> (f32[2], f32[2]) { + %p = f32[2,1] parameter(0) + + %all-reduce.ar.1 = f32[2,1] + all-reduce(%p), replica_groups={{0},{1}}, all_reduce_id=1, - to_apply=%binary_add, + to_apply=%sum.1, sharding={maximal device=0} - %convert.1 = f32[2,2] - convert(%cross-replica-sum.ar.1), + %bitcast.1 = f32[2]{0} bitcast(f32[2,1]{1,0} %all-reduce.ar.1) + %all-reduce.1 = f32[2] + all-reduce(%bitcast.1), + replica_groups={{0,1}}, + to_apply=%sum.2, sharding={maximal device=0} - %add.1 = f32[2,2] + + %all-reduce.ar.2 = f32[2,1] + all-reduce(%p), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%sum.1, + sharding={maximal device=1} + %bitcast.2 = f32[2]{0} bitcast(f32[2,1]{1,0} %all-reduce.ar.2) + %all-reduce.2 = f32[2] + all-reduce(%bitcast.2), + replica_groups={{0,1}}, + to_apply=%sum.2, + sharding={maximal device=1} + + ROOT %tuple = (f32[], f32[]) + tuple(%all-reduce.1, %all-reduce.2), + 
sharding={{maximal device=0}, {maximal device=1}} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto crs_before = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_before = crs_before->replica_groups(); + ArCrsCombiner combiner(2); + auto changed = combiner.Run(module.get()).ValueOrDie(); + EXPECT_TRUE(changed); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Tuple(op::AllReduce(op::Bitcast(op::Parameter())), + op::AllReduce(op::Bitcast(op::Parameter())))); + auto crs_after = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_after = crs_after->replica_groups(); + CompareReplicaGroups(replica_groups_before, replica_groups_after); +} + +TEST_F(ArCrsCombinerTest, RewriteArMultiplyCrs) { + const char* module_str = R"( +HloModule foobar + +%sum.f32 (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { + %p = f32[] parameter(0) + %constant.f32 = f32[] constant(123) + + %all-reduce.ar.1 = f32[] + all-reduce(%p), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%sum.f32, + sharding={maximal device=0} + %multiply.1 = f32[] + multiply(%all-reduce.ar.1, %constant.f32), + sharding={maximal device=0} + %all-reduce.1 = f32[] + all-reduce(%multiply.1), + replica_groups={{0,1}}, + to_apply=%sum.f32, + sharding={maximal device=0} + + %all-reduce.ar.2 = f32[] + all-reduce(%p), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%sum.f32, + sharding={maximal device=1} + %multiply.2 = f32[] + multiply(%all-reduce.ar.2, %constant.f32), + sharding={maximal device=1} + %all-reduce.2 = f32[] + all-reduce(%multiply.2), + replica_groups={{0,1}}, + to_apply=%sum.f32, + sharding={maximal device=1} + + ROOT %tuple = (f32[], f32[]) + tuple(%all-reduce.1, %all-reduce.2), + sharding={{maximal device=0}, {maximal 
device=1}} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto crs_before = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_before = crs_before->replica_groups(); + ArCrsCombiner combiner(2); + auto changed = combiner.Run(module.get()).ValueOrDie(); + EXPECT_TRUE(changed); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + op::Tuple(op::AllReduce(op::Multiply(op::Parameter(), op::Constant())), + op::AllReduce(op::Multiply(op::Parameter(), op::Constant())))); + auto crs_after = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_after = crs_after->replica_groups(); + CompareReplicaGroups(replica_groups_before, replica_groups_after); +} + +TEST_F(ArCrsCombinerTest, RewriteArConvertAddCrs) { + const char* module_str = R"( +HloModule foobar + +%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] { + %a = bf16[] parameter(0) + %b = bf16[] parameter(1) + ROOT %add = bf16[] add(%a, %b) +} + +%sum.f32 (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { + %p = f32[] parameter(0) + %constant.bf16 = bf16[] constant(1) + %constant.f32 = f32[] constant(2) + + %all-reduce.ar.1 = bf16[] + all-reduce(%constant.bf16), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%sum.bf16, + sharding={maximal device=0} + %convert.1 = f32[] + convert(%all-reduce.ar.1), + sharding={maximal device=0} + %add.1 = f32[] add(%constant.f32, %convert.1), sharding={maximal device=0} - %cross-replica-sum.1 = f32[2,2] - cross-replica-sum(%add.1), + %all-reduce.1 = f32[] + all-reduce(%add.1), replica_groups={{0,1}}, to_apply=%sum.f32, sharding={maximal device=0} - %cross-replica-sum.ar.2 = bf16[2,2] - cross-replica-sum(%constant.bf16), + %all-reduce.ar.2 = bf16[] + all-reduce(%constant.bf16), replica_groups={{0},{1}}, all_reduce_id=1, - 
to_apply=%binary_add, + to_apply=%sum.bf16, sharding={maximal device=1} - %convert.2 = f32[2,2] - convert(%cross-replica-sum.ar.2), + %convert.2 = f32[] + convert(%all-reduce.ar.2), sharding={maximal device=1} - %add.2 = f32[2,2] + %add.2 = f32[] add(%constant.f32, %convert.2), sharding={maximal device=1} - %cross-replica-sum.2 = f32[2,2] - cross-replica-sum(%add.2), + %all-reduce.2 = f32[] + all-reduce(%add.2), replica_groups={{0,1}}, to_apply=%sum.f32, sharding={maximal device=1} - ROOT %tuple = (f32[2,2], f32[2,2]) - tuple(%cross-replica-sum.1, %cross-replica-sum.2), + ROOT %tuple = (f32[], f32[]) + tuple(%all-reduce.1, %all-reduce.2), sharding={{maximal device=0}, {maximal device=1}} } )"; @@ -320,31 +622,24 @@ ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) { ArCrsCombiner combiner(2); auto changed = combiner.Run(module.get()).ValueOrDie(); EXPECT_TRUE(changed); - EXPECT_THAT(module->entry_computation()->root_instruction(), - op::Tuple(op::Subtract(op::CrossReplicaSum(), op::Constant()), - op::Subtract(op::CrossReplicaSum(), op::Constant()))); - auto sub = module->entry_computation()->root_instruction()->operands()[0]; - auto crs_after = sub->operands()[0]; + EXPECT_THAT( + module->entry_computation()->root_instruction(), + op::Tuple( + op::AllReduce(op::Add(op::Divide(op::Constant(), op::Constant()), + op::Convert())), + op::AllReduce(op::Add(op::Divide(op::Constant(), op::Constant()), + op::Convert())))); + auto crs_after = + module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_after = crs_after->replica_groups(); - ASSERT_EQ(replica_groups_before.size(), replica_groups_after.size()); - for (int i = 0; i < replica_groups_before.size(); ++i) { - // Somewhat verbose way to compare the replica_ids, because EqualsProto - // is not available in the open-source build. 
- auto group_before = replica_groups_before[i]; - std::vector ids_before(group_before.replica_ids().begin(), - group_before.replica_ids().end()); - auto group_after = replica_groups_after[i]; - std::vector ids_after(group_after.replica_ids().begin(), - group_after.replica_ids().end()); - EXPECT_EQ(ids_before, ids_after); - } + CompareReplicaGroups(replica_groups_before, replica_groups_after); } TEST_F(ArCrsCombinerTest, OtherSummandNotTheSameDontRewrite) { const char* module_str = R"( HloModule foobar -%binary_add (a: bf16[], b: bf16[]) -> bf16[] { +%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] { %a = bf16[] parameter(0) %b = bf16[] parameter(1) ROOT %add = bf16[] add(%a, %b) @@ -356,50 +651,515 @@ HloModule foobar ROOT %add = f32[] add(%x, %y) } -ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) { - %p = f32[2,2] parameter(0) - %constant.bf16 = bf16[2,2] constant(bf16[2,2] {{1, 2}, {3, 4}}) - %constant.f32.1 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) - %constant.f32.2 = f32[2,2] constant(f32[2,2] {{3, 4}, {5, 6}}) +ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { + %p = f32[] parameter(0) + %constant.bf16 = bf16[] constant(1) + %constant.f32.1 = f32[] constant(2) + %constant.f32.2 = f32[] constant(3) - %cross-replica-sum.ar.1 = bf16[2,2] - cross-replica-sum(%constant.bf16), + %all-reduce.ar.1 = bf16[] + all-reduce(%constant.bf16), replica_groups={{0},{1}}, all_reduce_id=1, - to_apply=%binary_add, + to_apply=%sum.bf16, sharding={maximal device=0} - %convert.1 = f32[2,2] - convert(%cross-replica-sum.ar.1), + %convert.1 = f32[] + convert(%all-reduce.ar.1), sharding={maximal device=0} - %add.1 = f32[2,2] + %add.1 = f32[] add(%constant.f32.1, %convert.1), sharding={maximal device=0} - %cross-replica-sum.1 = f32[2,2] - cross-replica-sum(%add.1), + %all-reduce.1 = f32[] + all-reduce(%add.1), replica_groups={{0,1}}, to_apply=%sum.f32, sharding={maximal device=0} - %cross-replica-sum.ar.2 = bf16[2,2] - cross-replica-sum(%constant.bf16), + %all-reduce.ar.2 = bf16[] + 
all-reduce(%constant.bf16), replica_groups={{0},{1}}, all_reduce_id=1, - to_apply=%binary_add, + to_apply=%sum.bf16, sharding={maximal device=1} - %convert.2 = f32[2,2] - convert(%cross-replica-sum.ar.2), + %convert.2 = f32[] + convert(%all-reduce.ar.2), sharding={maximal device=1} - %add.2 = f32[2,2] + %add.2 = f32[] add(%constant.f32.2, %convert.2), sharding={maximal device=1} - %cross-replica-sum.2 = f32[2,2] - cross-replica-sum(%add.2), + %all-reduce.2 = f32[] + all-reduce(%add.2), replica_groups={{0,1}}, to_apply=%sum.f32, sharding={maximal device=1} - ROOT %tuple = (f32[2,2], f32[2,2]) - tuple(%cross-replica-sum.1, %cross-replica-sum.2), + ROOT %tuple = (f32[], f32[]) + tuple(%all-reduce.1, %all-reduce.2), + sharding={{maximal device=0}, {maximal device=1}} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + ArCrsCombiner combiner(2); + auto changed = combiner.Run(module.get()).ValueOrDie(); + EXPECT_FALSE(changed); +} + +TEST_F(ArCrsCombinerTest, ArThenCrsDontCrash) { + const char* module_str = R"( +HloModule foobar + +%sum.1 (a: f32[], b: f32[]) -> f32[] { + %a = f32[] parameter(0) + %b = f32[] parameter(1) + ROOT %add = f32[] add(%a, %b) +} + +ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { + %p = f32[] parameter(0) + %constant.f32 = f32[] constant(123) + + %all-reduce.ar.1 = f32[] + all-reduce(%p), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%sum.1, + sharding={maximal device=0} + %all-reduce.1 = f32[] + all-reduce(%all-reduce.ar.1), + replica_groups={{0,1}}, + to_apply=%sum.1, + sharding={maximal device=0} + %multiply.1 = f32[] + multiply(%all-reduce.1, %constant.f32), + sharding={maximal device=0} + + %all-reduce.ar.2 = f32[] + all-reduce(%p), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%sum.1, + sharding={maximal device=1} + %all-reduce.2 = f32[] + all-reduce(%all-reduce.ar.2), + replica_groups={{0,1}}, + to_apply=%sum.1, + sharding={maximal device=1} + %multiply.2 = 
f32[] + multiply(%all-reduce.2, %constant.f32), + sharding={maximal device=1} + + ROOT %tuple = (f32[], f32[]) + tuple(%all-reduce.1, %all-reduce.2), + sharding={{maximal device=0}, {maximal device=1}} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto crs_before = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_before = crs_before->replica_groups(); + ArCrsCombiner combiner(2); + auto changed = combiner.Run(module.get()).ValueOrDie(); + EXPECT_TRUE(changed); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Tuple(op::AllReduce(op::Parameter()), + op::AllReduce(op::Parameter()))); + auto crs_after = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_after = crs_after->replica_groups(); + CompareReplicaGroups(replica_groups_before, replica_groups_after); +} + +TEST_F(ArCrsCombinerTest, RewriteMultipleAdds) { + const char* module_str = R"( +HloModule foobar + +%sum (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { + %p = f32[] parameter(0) + %constant.1 = f32[] constant(1) + %constant.2 = f32[] constant(2) + + %all-reduce.ar.1 = f32[] + all-reduce(%p), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%sum, + sharding={maximal device=0} + %add.11 = f32[] + add(%constant.1, %all-reduce.ar.1), + sharding={maximal device=0} + %add.12 = f32[] + add(%constant.2, %add.11), + sharding={maximal device=0} + %all-reduce.1 = f32[] + all-reduce(%add.12), + replica_groups={{0,1}}, + to_apply=%sum, + sharding={maximal device=0} + + %all-reduce.ar.2 = f32[] + all-reduce(%p), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%sum, + sharding={maximal device=0} + %add.21 = f32[] + add(%constant.1, %all-reduce.ar.2), + sharding={maximal device=0} + %add.22 = f32[] + add(%constant.2, %add.21), + 
sharding={maximal device=0} + %all-reduce.2 = f32[] + all-reduce(%add.22), + replica_groups={{0,1}}, + to_apply=%sum, + sharding={maximal device=0} + + ROOT %tuple = (f32[], f32[]) + tuple(%all-reduce.1, %all-reduce.2), + sharding={{maximal device=0}, {maximal device=1}} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto crs_before = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_before = crs_before->replica_groups(); + ArCrsCombiner combiner(2); + auto changed = combiner.Run(module.get()).ValueOrDie(); + EXPECT_TRUE(changed); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Tuple(op::AllReduce(op::Add( + op::Divide(op::Constant(), op::Constant()), + op::Add(op::Divide(op::Constant(), op::Constant()), + op::Parameter()))), + op::AllReduce(op::Add( + op::Divide(op::Constant(), op::Constant()), + op::Add(op::Divide(op::Constant(), op::Constant()), + op::Parameter()))))); + auto crs_after = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_after = crs_after->replica_groups(); + CompareReplicaGroups(replica_groups_before, replica_groups_after); +} + +TEST_F(ArCrsCombinerTest, RewriteArSubtractCrs) { + const char* module_str = R"( +HloModule foobar + +%sum.f32 (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { + %p = f32[] parameter(0) + %constant.f32 = f32[] constant(123) + + %all-reduce.ar.1 = f32[] + all-reduce(%p), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%sum.f32, + sharding={maximal device=0} + %sub.1 = f32[] + subtract(%constant.f32, %all-reduce.ar.1), + sharding={maximal device=0} + %all-reduce.1 = f32[] + all-reduce(%sub.1), + replica_groups={{0,1}}, + to_apply=%sum.f32, + sharding={maximal device=0} + + %all-reduce.ar.2 = f32[] + all-reduce(%p), + 
replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%sum.f32, + sharding={maximal device=1} + %sub.2 = f32[] + subtract(%constant.f32, %all-reduce.ar.2), + sharding={maximal device=1} + %all-reduce.2 = f32[] + all-reduce(%sub.2), + replica_groups={{0,1}}, + to_apply=%sum.f32, + sharding={maximal device=1} + + ROOT %tuple = (f32[], f32[]) + tuple(%all-reduce.1, %all-reduce.2), + sharding={{maximal device=0}, {maximal device=1}} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto crs_before = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_before = crs_before->replica_groups(); + ArCrsCombiner combiner(2); + auto changed = combiner.Run(module.get()).ValueOrDie(); + EXPECT_TRUE(changed); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + op::Tuple( + op::AllReduce(op::Subtract(op::Divide(op::Constant(), op::Constant()), + op::Parameter())), + op::AllReduce(op::Subtract(op::Divide(op::Constant(), op::Constant()), + op::Parameter())))); + auto crs_after = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_after = crs_after->replica_groups(); + CompareReplicaGroups(replica_groups_before, replica_groups_after); +} + +TEST_F(ArCrsCombinerTest, RewriteMultipleARsLeft) { + const char* module_str = R"( +HloModule foobar + +%sum (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { + %p = f32[] parameter(0) + %const1 = f32[] constant(1) + %const2 = f32[] constant(2) + + %ar11 = f32[] + all-reduce(%p), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%sum, + sharding={maximal device=0} + %add11 = f32[] + add(%ar11, %const1), + sharding={maximal device=0} + %ar12 = f32[] + all-reduce(%p), + replica_groups={{0},{1}}, + all_reduce_id=2, + to_apply=%sum, + sharding={maximal device=0} + %add12 = f32[] + 
add(%add11, %ar12), + sharding={maximal device=0} + %crs1 = f32[] + all-reduce(%add12), + replica_groups={{0,1}}, + to_apply=%sum, + sharding={maximal device=0} + + %ar21 = f32[] + all-reduce(%p), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%sum, + sharding={maximal device=1} + %add21 = f32[] + add(%ar21, %const1), + sharding={maximal device=1} + %ar22 = f32[] + all-reduce(%p), + replica_groups={{0},{1}}, + all_reduce_id=2, + to_apply=%sum, + sharding={maximal device=1} + %add22 = f32[] + add(%add21, %ar22), + sharding={maximal device=1} + %crs2 = f32[] + all-reduce(%add22), + replica_groups={{0,1}}, + to_apply=%sum, + sharding={maximal device=1} + + ROOT %tuple = (f32[], f32[]) + tuple(%crs1, %crs2), + sharding={{maximal device=0}, {maximal device=1}} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto crs_before = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_before = crs_before->replica_groups(); + ArCrsCombiner combiner(2); + auto changed = combiner.Run(module.get()).ValueOrDie(); + EXPECT_TRUE(changed); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Tuple(op::AllReduce(op::Add( + op::Add(op::Parameter(), + op::Divide(op::Constant(), op::Constant())), + op::Divide(op::AllReduce(), op::Constant()))), + op::AllReduce(op::Add( + op::Add(op::Parameter(), + op::Divide(op::Constant(), op::Constant())), + op::Divide(op::AllReduce(), op::Constant()))))); + auto crs_after = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_after = crs_after->replica_groups(); + CompareReplicaGroups(replica_groups_before, replica_groups_after); +} + +TEST_F(ArCrsCombinerTest, RewriteMultipleARsRight) { + const char* module_str = R"( +HloModule foobar + +%sum (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +ENTRY %entrycomp (p: f32[]) -> 
(f32[], f32[]) { + %p = f32[] parameter(0) + %const1 = f32[] constant(1) + %const2 = f32[] constant(2) + + %ar11 = f32[] + all-reduce(%p), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%sum, + sharding={maximal device=0} + %ar12 = f32[] + all-reduce(%p), + replica_groups={{0},{1}}, + all_reduce_id=2, + to_apply=%sum, + sharding={maximal device=0} + %add11 = f32[] + add(%ar12, %const1), + sharding={maximal device=0} + %add12 = f32[] + add(%ar11, %add11), + sharding={maximal device=0} + %crs1 = f32[] + all-reduce(%add12), + replica_groups={{0,1}}, + to_apply=%sum, + sharding={maximal device=0} + + %ar21 = f32[] + all-reduce(%p), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%sum, + sharding={maximal device=1} + %ar22 = f32[] + all-reduce(%p), + replica_groups={{0},{1}}, + all_reduce_id=2, + to_apply=%sum, + sharding={maximal device=1} + %add21 = f32[] + add(%ar22, %const1), + sharding={maximal device=1} + %add22 = f32[] + add(%ar21, %add21), + sharding={maximal device=1} + %crs2 = f32[] + all-reduce(%add22), + replica_groups={{0,1}}, + to_apply=%sum, + sharding={maximal device=1} + + ROOT %tuple = (f32[], f32[]) + tuple(%crs1, %crs2), + sharding={{maximal device=0}, {maximal device=1}} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto crs_before = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_before = crs_before->replica_groups(); + ArCrsCombiner combiner(2); + auto changed = combiner.Run(module.get()).ValueOrDie(); + EXPECT_TRUE(changed); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Tuple(op::AllReduce(op::Add( + op::Parameter(), + op::Divide(op::Add(op::AllReduce(), op::Constant()), + op::Constant()))), + op::AllReduce(op::Add( + op::Parameter(), + op::Divide(op::Add(op::AllReduce(), op::Constant()), + op::Constant()))))); + auto crs_after = + module->entry_computation()->root_instruction()->operands()[0]; + auto 
replica_groups_after = crs_after->replica_groups(); + CompareReplicaGroups(replica_groups_before, replica_groups_after); +} + +TEST_F(ArCrsCombinerTest, OneReplicaDontRewrite) { + const char* module_str = R"( +HloModule foobar + +%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] { + %a = bf16[] parameter(0) + %b = bf16[] parameter(1) + ROOT %add = bf16[] add(%a, %b) +} + +%sum.f32 (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) { + %p = bf16[] parameter(0) + %constant.bf16 = bf16[] constant(1) + + %all-reduce.ar.1 = bf16[] + all-reduce(%p), + replica_groups={{0}}, + all_reduce_id=1, + to_apply=%sum.bf16, + sharding={maximal device=0} + %convert.1 = f32[] + convert(%all-reduce.ar.1), + sharding={maximal device=0} + %all-reduce.1 = f32[] + all-reduce(%convert.1), + replica_groups={{0}}, + to_apply=%sum.f32, + sharding={maximal device=0} + + %all-reduce.ar.2 = bf16[] + all-reduce(%constant.bf16), + replica_groups={{0}}, + all_reduce_id=1, + to_apply=%sum.bf16, + sharding={maximal device=1} + %convert.2 = f32[] + convert(%all-reduce.ar.2), + sharding={maximal device=1} + %all-reduce.2 = f32[] + all-reduce(%convert.2), + replica_groups={{0}}, + to_apply=%sum.f32, + sharding={maximal device=1} + + ROOT %tuple = (f32[], f32[]) + tuple(%all-reduce.1, %all-reduce.2), sharding={{maximal device=0}, {maximal device=1}} } )"; diff --git a/tensorflow/compiler/xla/service/backend.cc b/tensorflow/compiler/xla/service/backend.cc index 5c180cbdd492031e133b81149f0f4698619b7788..215e8ced4bb3f98a26ac4eb9912a7fd4d917852f 100644 --- a/tensorflow/compiler/xla/service/backend.cc +++ b/tensorflow/compiler/xla/service/backend.cc @@ -57,6 +57,16 @@ int BackendOptions::intra_op_parallelism_threads() const { return intra_op_parallelism_threads_; } +BackendOptions& BackendOptions::set_allowed_devices( + const absl::optional>& allowed_devices) { + allowed_devices_ = allowed_devices; + 
return *this; +} + +const absl::optional>& BackendOptions::allowed_devices() const { + return allowed_devices_; +} + // Define this in .cc file to avoid having to include eigen or forward declare // these types in the header. struct Backend::EigenThreadPoolWrapper { @@ -76,8 +86,9 @@ struct Backend::EigenThreadPoolWrapper { const BackendOptions& options) { se::Platform* platform = options.platform(); TF_ASSIGN_OR_RETURN(auto compiler, Compiler::GetForPlatform(platform)); - TF_ASSIGN_OR_RETURN(auto stream_executors, - PlatformUtil::GetStreamExecutors(platform)); + TF_ASSIGN_OR_RETURN( + auto stream_executors, + PlatformUtil::GetStreamExecutors(platform, options.allowed_devices())); TF_ASSIGN_OR_RETURN(auto transfer_manager, TransferManager::GetForPlatform(platform)); TF_ASSIGN_OR_RETURN(auto computation_placer, @@ -104,12 +115,10 @@ StatusOr Backend::BorrowStream(int device_ordinal) { StatusOr Backend::BorrowStream(se::StreamExecutor* executor) { tensorflow::mutex_lock l(mu_); - if (0 == stream_pools_.count(executor)) { - stream_pools_.emplace(std::piecewise_construct, - std::forward_as_tuple(executor), - std::forward_as_tuple()); + if (!stream_pools_.contains(executor)) { + stream_pools_.emplace(executor, absl::make_unique()); } - return stream_pools_.at(executor).BorrowStream(executor); + return stream_pools_.at(executor)->BorrowStream(executor); } Backend::Backend(se::Platform* platform, Compiler* compiler, diff --git a/tensorflow/compiler/xla/service/backend.h b/tensorflow/compiler/xla/service/backend.h index a2dafbe803f8bd5f23e4e9f3f6d3e6f744c9fab9..c35f033dc0180409ae3888c2050021da83f5c72a 100644 --- a/tensorflow/compiler/xla/service/backend.h +++ b/tensorflow/compiler/xla/service/backend.h @@ -18,9 +18,11 @@ limitations under the License. 
#include #include +#include #include #include +#include "absl/container/flat_hash_map.h" #include "absl/strings/str_cat.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/service/compiler.h" @@ -53,9 +55,16 @@ class BackendOptions { BackendOptions& set_intra_op_parallelism_threads(int num_threads); int intra_op_parallelism_threads() const; + // Sets the allowed_devices for selectively constructing stream executors + // on the platform. + BackendOptions& set_allowed_devices( + const absl::optional>& allowed_devices); + const absl::optional>& allowed_devices() const; + private: se::Platform* platform_ = nullptr; int intra_op_parallelism_threads_ = -1; + absl::optional> allowed_devices_; }; // Class which encapsulates an XLA backend. It includes everything necessary @@ -167,7 +176,8 @@ class Backend { tensorflow::mutex mu_; // Mapping from stream executor to stream pools, used by `BorrowStream` above. - std::map stream_pools_ GUARDED_BY(mu_); + absl::flat_hash_map> + stream_pools_ GUARDED_BY(mu_); // The default memory allocator to use. 
std::unique_ptr memory_allocator_; diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.cc b/tensorflow/compiler/xla/service/batchnorm_expander.cc index f70f6ddfec69c0113a1afe2073a2392098f49456..e5f5c3edb2ac0c217317fbf809463aa31af9af59 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander.cc +++ b/tensorflow/compiler/xla/service/batchnorm_expander.cc @@ -107,19 +107,37 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { } std::unique_ptr Mean( - int64 element_count, HloInstruction* operand, + HloInstruction* element_count, HloInstruction* operand, const std::function)>& add_instruction) { - HloInstruction* elem_count_recip = - add_instruction(HloInstruction::CreateBroadcast( - operand->shape(), - add_instruction(HloInstruction::CreateConvert( - ShapeUtil::MakeShape(operand->shape().element_type(), {}), - add_instruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR0(1.0 / element_count))))), - {})); - return HloInstruction::CreateBinary(operand->shape(), HloOpcode::kMultiply, - operand, elem_count_recip); + auto broadcast = add_instruction( + HloInstruction::CreateBroadcast(operand->shape(), element_count, {})); + return HloInstruction::CreateBinary(operand->shape(), HloOpcode::kDivide, + operand, broadcast); + } + + std::unique_ptr DynamicElementCountPerFeature( + HloInstruction* operand, int64 feature_index, + const std::function)>& + add_instruction) { + auto elements_per_feature_u32 = add_instruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); + + for (int64 i = 0; i < operand->shape().rank(); ++i) { + if (i == feature_index) { + continue; + } + auto dynamic_dimension_size = + add_instruction(HloInstruction::CreateGetDimensionSize( + ShapeUtil::MakeShape(U32, {}), operand, i)); + elements_per_feature_u32 = add_instruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(U32, {}), HloOpcode::kMultiply, + dynamic_dimension_size, elements_per_feature_u32)); + } + + return 
HloInstruction::CreateConvert( + ShapeUtil::MakeShape(operand->shape().element_type(), {}), + elements_per_feature_u32); } // Replaces the existing HLO instruction old_instruction, with @@ -195,9 +213,6 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining( const Shape operand_shape = operand->shape(); PrimitiveType ptype = operand_shape.element_type(); int64 feature_index = batch_norm->feature_index(); - const int64 feature_count = operand_shape.dimensions(feature_index); - const int64 size_in_elements = ShapeUtil::ElementsIn(operand_shape); - int64 elements_per_feature_int64 = size_in_elements / feature_count; HloInstruction* scale = batch_norm->mutable_operand(1); HloInstruction* offset = batch_norm->mutable_operand(2); @@ -214,12 +229,15 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining( add(HloInstruction::CreateConstant(std::move(epsilon_literal))), {})); std::vector dimensions_without_feature; - for (int64 i = 0; i < ShapeUtil::Rank(operand_shape); ++i) { + for (int64 i = 0; i < operand_shape.rank(); ++i) { if (i != feature_index) { dimensions_without_feature.push_back(i); } } + auto elements_per_feature = + add(DynamicElementCountPerFeature(operand, feature_index, add)); + auto scale_broadcasted = add( HloInstruction::CreateBroadcast(operand_shape, scale, {feature_index})); @@ -243,13 +261,13 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining( add_reduce_computation)); // E[X]. - auto mean = add(Mean(elements_per_feature_int64, sum, add)); + auto mean = add(Mean(elements_per_feature, sum, add)); auto mean_broadcasted = add( HloInstruction::CreateBroadcast(operand_shape, mean, {feature_index})); // E[X^2]. - auto square_mean = add(Mean(elements_per_feature_int64, squared_sum, add)); + auto square_mean = add(Mean(elements_per_feature, squared_sum, add)); // E^2[X]. 
auto mean_square = @@ -339,7 +357,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference( std::vector dimensions_without_feature; - for (int64 i = 0; i < ShapeUtil::Rank(operand_shape); ++i) { + for (int64 i = 0; i < operand_shape.rank(); ++i) { if (i != feature_index) { dimensions_without_feature.push_back(i); } @@ -458,9 +476,8 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( int64 feature_index = batch_norm->feature_index(); - const int64 size_in_elements = ShapeUtil::ElementsIn(activation_shape); - const int64 feature_count = activation_shape.dimensions(feature_index); - const int64 elements_per_feature_int64 = size_in_elements / feature_count; + auto elements_per_feature = + add(DynamicElementCountPerFeature(activation, feature_index, add)); auto zero_literal = LiteralUtil::CreateR0(0.0f); TF_ASSIGN_OR_RETURN(zero_literal, zero_literal.Convert(ptype)); @@ -477,7 +494,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( std::vector dimensions_without_feature; - for (int64 i = 0; i < ShapeUtil::Rank(activation_shape); ++i) { + for (int64 i = 0; i < activation_shape.rank(); ++i) { if (i != feature_index) { dimensions_without_feature.push_back(i); } @@ -553,15 +570,9 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( add_binary(activation_shape, HloOpcode::kMultiply, scale_broadcasted, rsqrt_var_add_epsilon_broadcasted); - scale_times_rsqrt_var_add_epsilon = add( - Mean(elements_per_feature_int64, scale_times_rsqrt_var_add_epsilon, add)); + scale_times_rsqrt_var_add_epsilon = + add(Mean(elements_per_feature, scale_times_rsqrt_var_add_epsilon, add)); - auto elements_per_feature_literal = - LiteralUtil::CreateR0(elements_per_feature_int64); - TF_ASSIGN_OR_RETURN(elements_per_feature_literal, - elements_per_feature_literal.Convert(ptype)); - auto elements_per_feature = add( - HloInstruction::CreateConstant(std::move(elements_per_feature_literal))); auto i1 = add_binary(activation_shape, HloOpcode::kMultiply, grad_output, 
add(HloInstruction::CreateBroadcast( activation_shape, elements_per_feature, {}))); diff --git a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc index 08cf8026177d77ff98cca5e5d168ac3194936b35..8e8fbbd935b154e5a77d68e60d861601d740bf03 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc +++ b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc @@ -36,7 +36,21 @@ limitations under the License. namespace xla { namespace { -using BatchNormExpanderTest = HloTestBase; +class BatchNormExpanderTest : public HloTestBase { + protected: + // BatchNorm should have a dynamic sized dividor for mean operations. + int64 CountGetDimensionSize(const HloModule& module) { + int64 count = 0; + for (HloComputation* comp : module.computations()) { + for (HloInstruction* inst : comp->instructions()) { + if (inst->opcode() == HloOpcode::kGetDimensionSize) { + count++; + } + } + } + return count; + } +}; // Test that we expand BatchNormTraining. TEST_F(BatchNormExpanderTest, BatchNormTraining) { @@ -68,6 +82,7 @@ TEST_F(BatchNormExpanderTest, BatchNormTraining) { /*rewrite_grad_op=*/true); ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie()); root = computation->root_instruction(); + EXPECT_EQ(CountGetDimensionSize(*module), 3); // Make sure this operation is expanded. EXPECT_EQ(root->opcode(), HloOpcode::kTuple); } @@ -110,6 +125,7 @@ TEST_F(BatchNormExpanderTest, BatchNormGrad) { /*rewrite_grad_op=*/true); ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie()); root = computation->root_instruction(); + EXPECT_EQ(CountGetDimensionSize(*module), 3); // Make sure this operation is expanded. 
EXPECT_EQ(root->opcode(), HloOpcode::kTuple); } diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc index e9d30fc03c1c3194de577e6683b36a95641694d9..e62d72b323bd1d113e9d87bf8602bfb434c40d61 100644 --- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc @@ -34,8 +34,8 @@ class BFloat16ConversionFoldingVisitor : public DfsHloVisitorWithDefault { Status DefaultAction(HloInstruction* hlo) override; - // Special handling for cross-replica-sum which can have a tuple output. - Status HandleCrossReplicaSum(HloInstruction* crs) override; + // Special handling for all-reduce which can have a tuple output. + Status HandleAllReduce(HloInstruction* crs) override; static bool Run(HloComputation* computation, const BFloat16Support* bfloat16_support) { @@ -176,8 +176,7 @@ Status BFloat16ConversionFoldingVisitor::DefaultAction(HloInstruction* hlo) { return TryFoldBF16Conversions(hlo); } -Status BFloat16ConversionFoldingVisitor::HandleCrossReplicaSum( - HloInstruction* crs) { +Status BFloat16ConversionFoldingVisitor::HandleAllReduce(HloInstruction* crs) { if (crs->IsCrossModuleAllReduce()) { // Cross-module all-reduce has side effect. return Status::OK(); @@ -191,7 +190,7 @@ Status BFloat16ConversionFoldingVisitor::HandleCrossReplicaSum( } // If the output is not a tuple, we don't need special handling. 
- if (!ShapeUtil::IsTuple(crs->shape())) { + if (!crs->shape().IsTuple()) { return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc index 4ce351acc2c359773e618da70360c96faf5ca379..2232a2cbdfe0cf64dc4fb10d4598c0ad8b51ee5e 100644 --- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc @@ -38,7 +38,7 @@ class TestBFloat16Support : public BFloat16Support { hlo.opcode() == HloOpcode::kSubtract || hlo.opcode() == HloOpcode::kTuple || hlo.opcode() == HloOpcode::kGetTupleElement || - hlo.opcode() == HloOpcode::kCrossReplicaSum) { + hlo.opcode() == HloOpcode::kAllReduce) { return true; } return false; @@ -49,7 +49,7 @@ class TestBFloat16Support : public BFloat16Support { hlo.opcode() == HloOpcode::kSubtract || hlo.opcode() == HloOpcode::kTuple || hlo.opcode() == HloOpcode::kGetTupleElement || - hlo.opcode() == HloOpcode::kCrossReplicaSum) { + hlo.opcode() == HloOpcode::kAllReduce) { return true; } return false; @@ -58,7 +58,7 @@ class TestBFloat16Support : public BFloat16Support { bool SupportsMixedPrecisions(const HloInstruction& hlo) const override { if (hlo.opcode() == HloOpcode::kAdd || hlo.opcode() == HloOpcode::kTuple || hlo.opcode() == HloOpcode::kGetTupleElement || - hlo.opcode() == HloOpcode::kCrossReplicaSum) { + hlo.opcode() == HloOpcode::kAllReduce) { return true; } return false; @@ -213,7 +213,7 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldTuple) { EXPECT_EQ(tuple->operand(1), convert0); } -TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) { +TEST_F(BFloat16ConversionFoldingTest, FoldAllReduceTupleOutput) { auto builder = HloComputation::Builder(TestName()); auto module = CreateNewVerifiedModule(); @@ -236,11 +236,10 @@ TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) { HloInstruction* b = 
builder.AddInstruction( HloInstruction::CreateParameter(1, f32_shape, "b")); - HloInstruction* crs = - builder.AddInstruction(HloInstruction::CreateCrossReplicaSum( - ShapeUtil::MakeTupleShape({f32_shape, f32_shape}), {convert_a, b}, - sum, /*replica_groups=*/{}, /*barrier=*/"", - /*all_reduce_id=*/absl::nullopt)); + HloInstruction* crs = builder.AddInstruction(HloInstruction::CreateAllReduce( + ShapeUtil::MakeTupleShape({f32_shape, f32_shape}), {convert_a, b}, sum, + /*replica_groups=*/{}, /*barrier=*/"", + /*all_reduce_id=*/absl::nullopt)); HloInstruction* gte_a = builder.AddInstruction( HloInstruction::CreateGetTupleElement(f32_shape, crs, 0)); HloInstruction* gte_b = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization.cc b/tensorflow/compiler/xla/service/bfloat16_normalization.cc index b8a8f844eff17a95d4073f53495e0027c481f558..d1b14d604f0559b6b18f7d1fba127669c241c8a3 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization.cc @@ -362,8 +362,8 @@ Status BFloat16NormalizationVisitor::DefaultAction(HloInstruction* hlo) { } // TODO(b/112040122): Correctly normalize variadic reduce. 
if ((hlo->opcode() == HloOpcode::kSort || - hlo->opcode() == HloOpcode::kCrossReplicaSum) && - ShapeUtil::IsTuple(hlo->shape())) { + hlo->opcode() == HloOpcode::kAllReduce) && + hlo->shape().IsTuple()) { return HandleMultipleOutputs(hlo); } return HandleInstruction(hlo); diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc index 9f97d18c565c7915b9f9346f0c6330cdc3c707e9..551ac4be73a7630d213a53ca3606aa7f890cd794 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc @@ -232,7 +232,7 @@ TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionReduce) { EXPECT_EQ(reduce->operand(1)->shape().element_type(), F32); } -TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) { +TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleAllReduce) { auto module = CreateNewVerifiedModule(); HloComputation::Builder sum_builder("sum"); auto x = sum_builder.AddInstruction(HloInstruction::CreateParameter( @@ -253,11 +253,10 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) { HloInstruction* b = builder.AddInstruction( HloInstruction::CreateParameter(1, bf16_shape, "b")); - HloInstruction* crs = - builder.AddInstruction(HloInstruction::CreateCrossReplicaSum( - ShapeUtil::MakeTupleShape({f32_shape, bf16_shape}), {a, b}, reduction, - /*replica_groups=*/{}, /*barrier=*/"", - /*all_reduce_id=*/absl::nullopt)); + HloInstruction* crs = builder.AddInstruction(HloInstruction::CreateAllReduce( + ShapeUtil::MakeTupleShape({f32_shape, bf16_shape}), {a, b}, reduction, + /*replica_groups=*/{}, /*barrier=*/"", + /*all_reduce_id=*/absl::nullopt)); HloInstruction* gte = builder.AddInstruction( HloInstruction::CreateGetTupleElement(bf16_shape, crs, 1)); diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.cc 
b/tensorflow/compiler/xla/service/bfloat16_propagation.cc index 63d4572f2028c462df1cac9d5e4ee616e407f37b..bab63f66d83b712d756078bef84926eed235f6b5 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.cc @@ -276,8 +276,8 @@ bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo, if (bfloat16_support_->EffectiveOperandPrecisionIsOutputPrecision( *use.instruction, use.operand_number)) { if (use.instruction->opcode() == HloOpcode::kTuple || - (use.instruction->opcode() == HloOpcode::kCrossReplicaSum && - ShapeUtil::IsTuple(use.instruction->shape()))) { + (use.instruction->opcode() == HloOpcode::kAllReduce && + use.instruction->shape().IsTuple())) { ShapeIndex use_output_index{use.operand_number}; for (int64 i : use.operand_index) { use_output_index.push_back(i); diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc index 5be7141aae423adb4fe2f39262e463ff25ae8234..a9b5d9916e400b39039248098c22a715e44ccfd2 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc @@ -209,7 +209,7 @@ TEST_F(BFloat16PropagationTest, DoNotChangeAllReduce) { rb.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1")))); auto reduction = module->AddEmbeddedComputation(rb.Build()); HloInstruction* all_reduce = - builder.AddInstruction(HloInstruction::CreateCrossReplicaSum( + builder.AddInstruction(HloInstruction::CreateAllReduce( ShapeUtil::MakeTupleShape({shape, shape}), {a, b}, reduction, /*replica_groups=*/{}, /*barrier=*/"", /*all_reduce_id=*/1)); HloInstruction* gte0 = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index 8d7c62447852fd946440c41389300a92377c471f..e1b91b500191c7756f3d1a4b160a0dd1e09cfe7d 100644 --- 
a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -86,10 +86,9 @@ std::vector ColorInterferenceGraph( // first, but it would be good to investigate other ordering heuristics too. std::vector nodes(node_count); std::iota(nodes.begin(), nodes.end(), 0); - std::sort(nodes.begin(), nodes.end(), - [&interference_map](const int64 i, const int64 j) { - return interference_map[i].size() > interference_map[j].size(); - }); + absl::c_sort(nodes, [&interference_map](const int64 i, const int64 j) { + return interference_map[i].size() > interference_map[j].size(); + }); const int64 kColorUnassigned = -1; std::vector assigned_colors(node_count, kColorUnassigned); @@ -138,8 +137,8 @@ Status GatherComputationsByAllocationType( worklist.pop_front(); const HloComputation* computation = worklist_front.first; bool is_thread_local = worklist_front.second; - bool in_thread_local_set = thread_local_set.count(computation) > 0; - bool in_global_set = global_set.count(computation) > 0; + bool in_thread_local_set = thread_local_set.contains(computation); + bool in_global_set = global_set.contains(computation); // If the computation has already been added to the respective set, then // nothing to do. @@ -186,7 +185,7 @@ Status GatherComputationsByAllocationType( worklist.push_back(std::make_pair(subcomputation, false)); // Not thread local. break; - case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllReduce: case HloOpcode::kMap: case HloOpcode::kReduce: case HloOpcode::kReduceWindow: @@ -207,9 +206,9 @@ Status GatherComputationsByAllocationType( // Add the computations to the vectors in post order. 
for (auto* computation : module->MakeComputationPostOrder()) { - if (thread_local_set.count(computation) > 0) { + if (thread_local_set.contains(computation)) { thread_local_computations->push_back(computation); - } else if (global_set.count(computation) > 0) { + } else if (global_set.contains(computation)) { global_computations->push_back(computation); } // If the computation is not reachable from the entry computation, then it @@ -219,13 +218,6 @@ Status GatherComputationsByAllocationType( return Status::OK(); } -size_t BufferAllocation::Slice::Hasher::operator()(Slice s) const { - uint64 h = std::hash()(s.index()); - h = tensorflow::Hash64Combine(h, std::hash()(s.offset())); - h = tensorflow::Hash64Combine(h, std::hash()(s.size())); - return h; -} - string BufferAllocation::Slice::ToString() const { return absl::StrCat("{index:", index(), ", offset:", offset_, ", size:", size_, "}"); @@ -240,7 +232,7 @@ BufferAllocation::Slice BufferAllocation::GetSlice( void BufferAllocation::AddAssignment(const LogicalBuffer& buffer, int64 offset, int64 size) { VLOG(4) << "Trying to add " << buffer << " to allocation #" << index(); - CHECK(assigned_buffers_.count(&buffer) == 0) + CHECK(!assigned_buffers_.contains(&buffer)) << "LogicalBuffer " << buffer << " already assigned to allocation " << index_; CHECK_LE(offset, size_) << "LogicalBuffer " << buffer @@ -279,11 +271,12 @@ BufferAllocationProto BufferAllocation::ToProto() const { proto_assigned->set_offset(buffer_offset_size.second.offset); proto_assigned->set_size(buffer_offset_size.second.size); } - std::sort(proto.mutable_assigned()->begin(), proto.mutable_assigned()->end(), - [](const BufferAllocationProto::Assigned& assign1, - const BufferAllocationProto::Assigned& assign2) { - return assign1.logical_buffer_id() < assign2.logical_buffer_id(); - }); + absl::c_sort(*proto.mutable_assigned(), + [](const BufferAllocationProto::Assigned& assign1, + const BufferAllocationProto::Assigned& assign2) { + return 
assign1.logical_buffer_id() < + assign2.logical_buffer_id(); + }); return proto; } @@ -315,10 +308,10 @@ string BufferAllocation::ToString() const { for (const auto& buffer_offset_size : assigned_buffers_) { sorted_buffers.push_back(buffer_offset_size.first); } - std::sort(sorted_buffers.begin(), sorted_buffers.end(), - [](const LogicalBuffer* a, const LogicalBuffer* b) { - return a->id() < b->id(); - }); + absl::c_sort(sorted_buffers, + [](const LogicalBuffer* a, const LogicalBuffer* b) { + return a->id() < b->id(); + }); for (const LogicalBuffer* buffer : sorted_buffers) { const OffsetSize& offset_size = FindOrDie(assigned_buffers_, buffer); StrAppend(&output, absl::StrFormat( @@ -346,7 +339,7 @@ const PointsToSet& BufferAssignment::GetPointsToSet( bool BufferAssignment::HasAllocation(const LogicalBuffer& buffer) const { TF_CHECK_OK(points_to_analysis().VerifyBuffer(buffer)); - return allocation_index_for_buffer_.count(&buffer) > 0; + return allocation_index_for_buffer_.contains(&buffer); } const BufferAllocation& BufferAssignment::GetAssignedAllocation( @@ -401,7 +394,7 @@ bool BufferAssignment::HasAllocationAt(const HloInstruction* instruction, const ShapeIndex& index) const { for (const LogicalBuffer* buffer : GetPointsToSet(instruction).element(index)) { - if (allocation_index_for_buffer_.count(buffer) > 0) { + if (allocation_index_for_buffer_.contains(buffer)) { return true; } } @@ -459,8 +452,7 @@ bool BufferAssignment::SharesSliceAtIndex( bool BufferAssignment::HaveDisjointSlices(const HloInstruction* hlo_a, const HloInstruction* hlo_b) const { - using SliceSet = - flat_hash_set; + using SliceSet = flat_hash_set; // Gets the slices all of instr's subshapes. If any subshape doesn't have an // assigned slice, returns the empty set. 
auto collect_slices = [&](const HloInstruction* instr) -> SliceSet { @@ -487,10 +479,9 @@ bool BufferAssignment::HaveDisjointSlices(const HloInstruction* hlo_a, // didn't return the empty set) for both HLOs, and the two resulting sets of // slices are disjoint. return !slices_a.empty() && !slices_b.empty() && - std::none_of(slices_a.begin(), slices_a.end(), - [&](const BufferAllocation::Slice& slice) { - return slices_b.count(slice) > 0; - }); + absl::c_none_of(slices_a, [&](const BufferAllocation::Slice& slice) { + return slices_b.contains(slice); + }); } StatusOr @@ -519,7 +510,7 @@ BufferAllocation* BufferAssignment::NewAllocation(const LogicalBuffer& buffer, void BufferAssignment::AddAssignment(BufferAllocation* allocation, const LogicalBuffer& buffer, int64 offset, int64 size) { - CHECK_EQ(0, allocation_index_for_buffer_.count(&buffer)) + CHECK(!allocation_index_for_buffer_.contains(&buffer)) << "LogicalBuffer " << buffer << " already has an allocation."; CHECK(allocation->is_reusable() || allocation->assigned_buffers().empty()) << "Non-reusable allocation already assigned a buffer: " @@ -761,7 +752,8 @@ namespace { bool MayInterfereAcrossSubcomputations(BufferAssignment* assignment, const LogicalBuffer& a_buffer, const LogicalBuffer& b_buffer) { - auto call_graph = assignment->liveness().hlo_ordering().call_graph(); + const CallGraph& call_graph = + assignment->liveness().hlo_ordering().call_graph(); const HloInstruction* a_ancestor; const HloInstruction* b_ancestor; std::tie(a_ancestor, b_ancestor) = @@ -960,35 +952,35 @@ Status BufferAssigner::AssignBuffersForComputation( // operands (assuming operands are the same/larger size) enabling the // important reuse case where an elementwise instruction reuses one of its // operand's buffer. This improves locality. 
- std::sort(sorted_buffers.begin(), sorted_buffers.end(), - [has_sequential_order, &liveness, &post_order_position, assignment]( - const LogicalBuffer* a, const LogicalBuffer* b) { - // Primary sort is by decreasing buffer size. - const int64 a_size = assignment->buffer_size_(*a); - const int64 b_size = assignment->buffer_size_(*b); - if (a_size != b_size) { - return a_size > b_size; // use ">" for decreasing size. - } - // Otherwise live out buffers come before others, if the - // instructions are sequentially ordered. - if (has_sequential_order) { - const bool a_live_out = liveness.MaybeLiveOut(*a); - const bool b_live_out = liveness.MaybeLiveOut(*b); - if (a_live_out != b_live_out) { - return a_live_out; - } - } - // Final tiebreaker is in instruction post order. - return post_order_position.at(a->instruction()) < - post_order_position.at(b->instruction()); - }); + absl::c_sort(sorted_buffers, + [has_sequential_order, &liveness, &post_order_position, + assignment](const LogicalBuffer* a, const LogicalBuffer* b) { + // Primary sort is by decreasing buffer size. + const int64 a_size = assignment->buffer_size_(*a); + const int64 b_size = assignment->buffer_size_(*b); + if (a_size != b_size) { + return a_size > b_size; // use ">" for decreasing size. + } + // Otherwise live out buffers come before others, if the + // instructions are sequentially ordered. + if (has_sequential_order) { + const bool a_live_out = liveness.MaybeLiveOut(*a); + const bool b_live_out = liveness.MaybeLiveOut(*b); + if (a_live_out != b_live_out) { + return a_live_out; + } + } + // Final tiebreaker is in instruction post order. + return post_order_position.at(a->instruction()) < + post_order_position.at(b->instruction()); + }); // BufferAllocations are necessarily created in decreasing size order. Keep // indices of previously created BufferAllocations in allocation_indices. 
std::vector allocation_indices; for (const LogicalBuffer* buffer : sorted_buffers) { VLOG(3) << "Assigning allocation to: " << *buffer; - if (colocated_buffers.count(buffer) > 0) { + if (colocated_buffers.contains(buffer)) { // Colocated buffers are currently assigned in an earlier pass. VLOG(3) << "Skipping colocated buffer: " << *buffer; continue; @@ -1020,10 +1012,14 @@ Status BufferAssigner::AssignBuffersForComputation( // callers. BufferAllocation* allocation = assignment->NewAllocation(*buffer, buffer_size); + bool parameter_has_alias = + assignment->module().input_output_alias_config().ParameterHasAlias( + instruction->parameter_number(), buffer->index()); allocation->set_entry_computation_parameter( - instruction->parameter_number(), buffer->index()); - VLOG(3) << "New allocation #" << allocation->index() - << " for entry computation parameter: " << *buffer; + instruction->parameter_number(), buffer->index(), + parameter_has_alias); + VLOG(3) << "Mark allocation #" << allocation->index() + << " as entry computation parameter: " << *buffer; continue; } @@ -1036,7 +1032,7 @@ Status BufferAssigner::AssignBuffersForComputation( continue; } - if (ShapeUtil::IsTuple(buffer->shape())) { + if (buffer->shape().IsTuple()) { BufferAllocation* allocation = assignment->NewAllocation(*buffer, buffer_size); allocation->set_is_tuple(true); @@ -1056,7 +1052,7 @@ Status BufferAssigner::AssignBuffersForComputation( assignment->GetAllSlices(operand, /*index=*/{})) { BufferAllocation* allocation = assignment->GetMutableAllocation(operand_slice.index()); - if (colocated_allocations.count(allocation->index()) == 0) { + if (!colocated_allocations.contains(allocation->index())) { // TODO(b/32491382) Colocated buffers are currently assigned in an // earlier pass, and so can break the "increasing allocation size" // invariant in this function (causing this CHECK to fail). 
However, @@ -1087,7 +1083,7 @@ Status BufferAssigner::AssignBuffersForComputation( // Instructions are iterated in increasing buffer size, so any // previously create allocation must be large enough to hold this // instruction's output (with the exception of colocated buffers). - if (colocated_allocations.count(allocation->index()) == 0) { + if (!colocated_allocations.contains(allocation->index())) { // TODO(b/32491382) Colocated buffers are currently assigned in an // earlier pass, and so can break the "increasing allocation size" // invariant in this function (causing this CHECK to fail). However, @@ -1313,10 +1309,10 @@ std::vector ComputePeakMemoryLogicalBuffers( live_buffers.end()); // Stabily sort the live buffers. - std::sort(live_buffers_vector.begin(), live_buffers_vector.end(), - [](const LogicalBuffer* a, const LogicalBuffer* b) { - return a->id() < b->id(); - }); + absl::c_sort(live_buffers_vector, + [](const LogicalBuffer* a, const LogicalBuffer* b) { + return a->id() < b->id(); + }); return live_buffers_vector; } @@ -1376,7 +1372,7 @@ void BufferAssigner::AddSetToColocatedBufferSets( std::vector overlap_set_indices; for (size_t index = 0; index < colocated_buffer_sets->size(); ++index) { for (const LogicalBuffer* buffer : colocated_set) { - if ((*colocated_buffer_sets)[index].count(buffer) > 0) { + if ((*colocated_buffer_sets)[index].contains(buffer)) { VLOG(5) << "Found overlap with existing set on buffer " << buffer->ToString() << "\n" << ColocatedBufferSetsToString((*colocated_buffer_sets)[index], @@ -1425,12 +1421,14 @@ BufferAssigner::MergeColocatedBufferSets( << colocated_buffer_sets.size(); // Returns true if the given buffer is for the entry parameter. 
- auto is_entry_parameter = [](const LogicalBuffer& buffer) { + auto is_readonly_entry_parameter = [](const LogicalBuffer& buffer) { auto* instruction = buffer.instruction(); auto* computation = instruction->parent(); auto* module = computation->parent(); return instruction->opcode() == HloOpcode::kParameter && - computation == module->entry_computation(); + computation == module->entry_computation() && + !module->input_output_alias_config().ParameterHasAlias( + instruction->parameter_number(), buffer.index()); }; std::vector set_can_be_merged(colocated_buffer_sets.size(), true); @@ -1452,7 +1450,7 @@ BufferAssigner::MergeColocatedBufferSets( for (int64 i = 0; i < colocated_buffer_sets.size(); ++i) { for (auto& buffer : colocated_buffer_sets[i]) { if (buffer_liveness.MaybeLiveOut(*buffer) || - is_entry_parameter(*buffer) || + is_readonly_entry_parameter(*buffer) || buffer->instruction()->opcode() == HloOpcode::kConstant) { set_can_be_merged[i] = false; break; @@ -1539,15 +1537,16 @@ void BufferAssigner::BuildColocatedBufferSets( VLOG(4) << "Input/Output Alias Config: "; VLOG(4) << module->input_output_alias_config(); module->input_output_alias_config().ForEachAlias( - [&](const ShapeIndex& output_index, int64 param_number, - const ShapeIndex& param_index) { + [&](const ShapeIndex& output_index, + const HloInputOutputAliasConfig::Alias& alias) { std::vector colocated_set; AddBufferToColocatedSet(module->entry_computation()->root_instruction(), output_index, points_to_analysis, &colocated_set); AddBufferToColocatedSet( - module->entry_computation()->parameter_instruction(param_number), - param_index, points_to_analysis, &colocated_set); + module->entry_computation()->parameter_instruction( + alias.parameter_number), + alias.parameter_index, points_to_analysis, &colocated_set); AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets); }); @@ -1741,10 +1740,6 @@ void BufferAssigner::AssignColocatedBufferSets( // module-level scope, we can allow buffers to be 
shared across // computations (in some cases). allocation = assignment->NewAllocation(*buffer, buffer_size); - if (entry_parameter_number >= 0) { - allocation->set_entry_computation_parameter( - entry_parameter_number, *entry_parameter_shape_idx); - } if (is_constant) { allocation->set_constant(true); } @@ -1758,6 +1753,16 @@ void BufferAssigner::AssignColocatedBufferSets( } colocated_buffers->insert(buffer); } + + // If an allocation contains a parameter, set corresponding fields. + if (entry_parameter_number >= 0) { + bool parameter_has_alias = + assignment->module().input_output_alias_config().ParameterHasAlias( + entry_parameter_number, *entry_parameter_shape_idx); + allocation->set_entry_computation_parameter(entry_parameter_number, + *entry_parameter_shape_idx, + parameter_has_alias); + } } } diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h index 0a9fdede803e84ca42472259084615c031b206eb..448dec3b1aa0c0f85e1060a70e965fcf3952c320 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.h +++ b/tensorflow/compiler/xla/service/buffer_assignment.h @@ -96,7 +96,11 @@ class BufferAllocation { // Whether this allocation is readonly i.e. backed by memory we cannot write // to. bool is_readonly() const { - return is_entry_computation_parameter() || is_constant(); + // Entry parameters are generally readonly, except when they are aliased + // with any output. 
+ return (is_entry_computation_parameter() && + !is_parameter_aliased_with_output_) || + is_constant(); } bool is_tuple() const { return is_tuple_; } @@ -186,9 +190,10 @@ class BufferAllocation { end > other.offset_; } - struct Hasher { - size_t operator()(Slice s) const; - }; + template + friend H AbslHashValue(H h, const Slice& s) { + return H::combine(std::move(h), s.index(), s.offset(), s.size()); + } string ToString() const; @@ -273,8 +278,10 @@ class BufferAllocation { void AddAssignment(const LogicalBuffer& buffer, int64 offset, int64 size); void set_entry_computation_parameter(int64 parameter_number, - ShapeIndex param_shape_index) { + ShapeIndex param_shape_index, + bool parameter_aliased_with_output) { is_entry_computation_parameter_ = true; + is_parameter_aliased_with_output_ = parameter_aliased_with_output; parameter_number_ = parameter_number; param_shape_index_ = std::move(param_shape_index); } @@ -304,6 +311,9 @@ class BufferAllocation { // outlast the computation. bool is_entry_computation_parameter_ = false; + // Whether this entry computation parameter is aliased with output. + bool is_parameter_aliased_with_output_ = false; + // If this allocation holds an entry computation parameter, this field // indicates the index (starting from 0) of the parameter. int64 parameter_number_ = 0; diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index 8f482e6ba8c3e71c9980be5e6947ea61f3b4ef29..580bc2f43384006eab8711490689a200fc887d37 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -21,6 +21,7 @@ limitations under the License. 
#include #include +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/buffer_value.h" @@ -309,7 +310,7 @@ class BufferAssignmentTest : public HloTestBase { static bool BuffersDistinct(const std::vector& a, const std::vector& b, const BufferAssignment& assignment) { - std::set a_slices; + absl::flat_hash_set a_slices; for (const HloInstruction* instruction : a) { if (assignment.HasTopLevelAllocation(instruction)) { a_slices.insert( @@ -319,8 +320,8 @@ static bool BuffersDistinct(const std::vector& a, for (const HloInstruction* instruction : b) { if (assignment.HasTopLevelAllocation(instruction)) { - if (a_slices.count(assignment.GetUniqueTopLevelSlice(instruction) - .ConsumeValueOrDie())) { + if (a_slices.contains(assignment.GetUniqueTopLevelSlice(instruction) + .ConsumeValueOrDie())) { return false; } } @@ -464,6 +465,40 @@ TEST_F(BufferAssignmentTest, Basic) { GetAssignedOutputAllocation(*buffers, sub); } +TEST_F(BufferAssignmentTest, AliasedParamCanBeReused) { + // If an input buffer and output buffer aliases, the input buffer can be + // reused for other intermediate results. 
+ // + // param0[100] ----- (neg1) -- (neg2) + // | | + // + -------- Aliased ---------+ + + auto builder = HloComputation::Builder(TestName()); + + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32vec100_, "p0")); + auto neg_1 = builder.AddInstruction( + HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param)); + auto neg_2 = builder.AddInstruction( + HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, neg_1)); + + auto module = CreateNewVerifiedModule(); + module->AddEntryComputation(builder.Build()); + + TF_ASSERT_OK(module->input_output_alias_config().SetUpAlias( + {}, 0, {}, HloInputOutputAliasConfig::kUserAlias)); + + auto buffers = RunBufferAssignment(module.get()); + + BufferAllocation param_buffer = GetAssignedInputAllocation(*buffers, param); + BufferAllocation neg_1_buffer = GetAllocation(*buffers, neg_1, {}); + BufferAllocation neg_2_buffer = GetAllocation(*buffers, neg_2, {}); + + // Everything use one buffer. + EXPECT_EQ(param_buffer.index(), neg_1_buffer.index()); + EXPECT_EQ(neg_2_buffer.index(), neg_1_buffer.index()); +} + TEST_F(BufferAssignmentTest, AddCannotReuse) { // Pass in a special rule to indicate that "add" cannot reuse any buffer. 
// @@ -2485,9 +2520,9 @@ while_body { get-tuple-element.3 = s32[] get-tuple-element(state), index=0 constant.2 = s32[] constant(128) add.5 = s32[] add(get-tuple-element.3, constant.2) - constant.3 = s32[3]{0} constant({0, 0, 0}) - dynamic-update-slice.5 = f32[1280,1,128]{2,1,0} dynamic-update-slice(get-tuple-element.4, broadcast.6, constant.3) - dynamic-update-slice.9 = f32[1280,1,128]{2,1,0} dynamic-update-slice(dynamic-update-slice.5, broadcast.6, constant.3) + constant.3 = s32[] constant(0) + dynamic-update-slice.5 = f32[1280,1,128]{2,1,0} dynamic-update-slice(get-tuple-element.4, broadcast.6, constant.3, constant.3, constant.3) + dynamic-update-slice.9 = f32[1280,1,128]{2,1,0} dynamic-update-slice(dynamic-update-slice.5, broadcast.6, constant.3, constant.3, constant.3) ROOT tuple.85 = (s32[], s32[], s32[2]{0}, f32[1280,1,128]{2,1,0}) tuple(add.5, dynamic-update-slice.9) } diff --git a/tensorflow/compiler/xla/service/buffer_liveness_test.cc b/tensorflow/compiler/xla/service/buffer_liveness_test.cc index 40825a78716b1c0b9fb0121787977d275891c0f8..23b9af0281b0d5ee1ef6ca2315f0cc1042285609 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness_test.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness_test.cc @@ -52,8 +52,8 @@ class BufferLivenessTest : public HloTestBase { // interfere. Precondition: 'a' and 'b' are array-shaped. bool InstructionsMayInterfere(const BufferLiveness& liveness, HloInstruction* a, HloInstruction* b) { - EXPECT_FALSE(ShapeUtil::IsTuple(a->shape())); - EXPECT_FALSE(ShapeUtil::IsTuple(b->shape())); + EXPECT_FALSE(a->shape().IsTuple()); + EXPECT_FALSE(b->shape().IsTuple()); return liveness.MayInterfere( GetBuffer(liveness, /*instruction=*/a, /*index=*/{}), GetBuffer(liveness, /*instruction=*/b, /*index=*/{})); @@ -66,8 +66,8 @@ class BufferLivenessTest : public HloTestBase { HloInstruction* a, HloInstruction* b, const ShapeIndex& index) { // Check that top-level shapes are tuple and tuple element shapes are equal. 
- EXPECT_TRUE(ShapeUtil::IsTuple(a->shape())); - EXPECT_TRUE(ShapeUtil::IsTuple(b->shape())); + EXPECT_TRUE(a->shape().IsTuple()); + EXPECT_TRUE(b->shape().IsTuple()); EXPECT_TRUE( ShapeUtil::Compatible(ShapeUtil::GetSubshape(a->shape(), index), ShapeUtil::GetSubshape(b->shape(), index))); @@ -638,10 +638,10 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest { } // Create a DynamicUpdateSlice instruction of tuple element 1 with 'update'. auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({2}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2))); auto dynamic_update_slice = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - data_shape, gte1, update, starts)); + data_shape, gte1, update, {starts})); // Create output tuple. builder.AddInstruction( HloInstruction::CreateTuple({gte0, dynamic_update_slice})); @@ -794,10 +794,10 @@ class DynamicUpdateSliceLivenessTest : public BufferLivenessTest { } // Create a DynamicUpdateSlice instruction of tuple element 1 with 'update'. auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({2}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2))); auto dynamic_update_slice = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - data_shape, gte1, update, starts)); + data_shape, gte1, update, {starts})); // Create output tuple. 
auto tuple_root = builder.AddInstruction( HloInstruction::CreateTuple({gte0, dynamic_update_slice})); diff --git a/tensorflow/compiler/xla/service/buffer_value.cc b/tensorflow/compiler/xla/service/buffer_value.cc index fdf822c666b15afbc7553ca89d4f92ab08201869..b1abba20689915b03304aacd7a5fcca5443c2c60 100644 --- a/tensorflow/compiler/xla/service/buffer_value.cc +++ b/tensorflow/compiler/xla/service/buffer_value.cc @@ -29,8 +29,8 @@ BufferValue::BufferValue(HloInstruction* instruction, const ShapeIndex& index, Id id) : id_(id) { const Shape& shape = ShapeUtil::GetSubshape(instruction->shape(), index); - is_array_ = ShapeUtil::IsArray(shape); - is_tuple_ = ShapeUtil::IsTuple(shape); + is_array_ = shape.IsArray(); + is_tuple_ = shape.IsTuple(); } BufferValue::~BufferValue() {} diff --git a/tensorflow/compiler/xla/service/call_graph.cc b/tensorflow/compiler/xla/service/call_graph.cc index 7987343bfaf1069fd550909d127e4b11f2124701..94af788c54f6c722997311bec50da3ed93aa3cee 100644 --- a/tensorflow/compiler/xla/service/call_graph.cc +++ b/tensorflow/compiler/xla/service/call_graph.cc @@ -58,7 +58,7 @@ CallContext GetInstructionCallContext(HloOpcode opcode) { case HloOpcode::kConditional: case HloOpcode::kWhile: return CallContext::kSequential; - case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllReduce: case HloOpcode::kMap: case HloOpcode::kReduce: case HloOpcode::kReduceWindow: @@ -236,6 +236,41 @@ void CallGraph::SetCallContexts() { } } +void CallGraph::SetNodeDepths() { + std::queue worklist; + + // Initialize node depths to -1. + for (CallGraphNode& node : nodes_) { + node.set_depth(-1); + } + + // Initialize worklist with all roots of the call graph (computations without + // callers). 
+ for (const HloComputation* computation : module_->computations()) { + CallGraphNode& node = GetNode(computation); + if (node.callers().empty()) { + node.set_depth(0); + worklist.push(&node); + } + } + + while (!worklist.empty()) { + CallGraphNode* node = worklist.front(); + worklist.pop(); + for (const HloComputation* callee : node->callees()) { + CallGraphNode& callee_node = GetNode(callee); + if (callee_node.depth() < node->depth() + 1) { + callee_node.set_depth(node->depth() + 1); + worklist.push(&callee_node); + } + } + } + + for (CallGraphNode& node : nodes_) { + CHECK_NE(node.depth(), -1); + } +} + /* static */ std::unique_ptr CallGraph::Build(const HloModule* module) { // Constructor for CallGraph is private so absl::make_unique can't be used. @@ -271,6 +306,8 @@ std::unique_ptr CallGraph::Build(const HloModule* module) { } call_graph->SetCallContexts(); + call_graph->SetNodeDepths(); + XLA_VLOG_LINES(1, call_graph->ToString()); return call_graph; @@ -352,15 +389,38 @@ CallGraph::NearestAncestorsInSameComputation(HloInstruction* a, // Iterate through the callee->caller chains and find the earliest common // element. - for (HloInstruction* a_ancestor = a; a_ancestor != nullptr; - a_ancestor = next_caller(a_ancestor)) { - for (HloInstruction* b_ancestor = b; b_ancestor != nullptr; - b_ancestor = next_caller(b_ancestor)) { - if (a_ancestor->parent() == b_ancestor->parent()) { - return {a_ancestor, b_ancestor}; + HloInstruction* a_ancestor = a; + HloInstruction* b_ancestor = b; + int a_depth = GetNode(a->parent()).depth(); + int b_depth = GetNode(b->parent()).depth(); + + // Advance a_ancestor (b_ancestor) up the call chain until the call depth of + // a_ancestor or b_ancestor are the same. Necessarily each call to next_caller + // reduces the depth by exactly one. 
+ if (a_depth > b_depth) { + for (int i = 0; i < a_depth - b_depth; ++i) { + a_ancestor = next_caller(a_ancestor); + if (a_ancestor == nullptr) { + return {nullptr, nullptr}; + } + } + } else if (b_depth > a_depth) { + for (int i = 0; i < b_depth - a_depth; ++i) { + b_ancestor = next_caller(b_ancestor); + if (b_ancestor == nullptr) { + return {nullptr, nullptr}; } } } + + while ((a_ancestor != nullptr) && (b_ancestor != nullptr)) { + if (a_ancestor->parent() == b_ancestor->parent()) { + return {a_ancestor, b_ancestor}; + } + + a_ancestor = next_caller(a_ancestor); + b_ancestor = next_caller(b_ancestor); + } return {nullptr, nullptr}; } diff --git a/tensorflow/compiler/xla/service/call_graph.h b/tensorflow/compiler/xla/service/call_graph.h index 05c7c998738f861ee804d1ec87bfa5fb17ddfb74..c02ffda575278905f6549b362e5e7d94f5713b36 100644 --- a/tensorflow/compiler/xla/service/call_graph.h +++ b/tensorflow/compiler/xla/service/call_graph.h @@ -121,6 +121,11 @@ class CallGraphNode { // Returns the context in which this computation is called. CallContext context() const { return context_; } + // Returns the depth of this node in the call graph. The depth is defined as + // the length of the longest call chain from a computation with no callers + // (usually the entry computation node) to this node. + int depth() const { return depth_; } + string ToString() const; private: @@ -130,6 +135,9 @@ class CallGraphNode { // Sets the context in which this computation is called. void set_context(CallContext value) { context_ = value; } + // Sets the depth of this node in the graph. + void set_depth(int value) { depth_ = value; } + // Adds a callsite which calls this computation. Updates callers to include // the calling computation. void AddCallerCallSite(const CallSite& caller_callsite); @@ -164,6 +172,9 @@ class CallGraphNode { // The context in which this computation is called. CallContext context_ = CallContext::kNone; + + // The depth of this node in the call graph. 
+ int depth_ = 0; }; // The call graph for an HLO module. The graph includes a node for each @@ -245,9 +256,16 @@ class CallGraph { private: CallGraph(const HloModule* module); + // Not copyable. + CallGraph(const CallGraph&) = delete; + CallGraph& operator=(const CallGraph&) = delete; + // Sets the call contexts for every node in the graph. void SetCallContexts(); + // Sets the call node depths for every node in the graph. + void SetNodeDepths(); + // Helper method for VisitNodes(). Traverses the call graph from 'node' in DFS // post order (callee before caller) calling visitor_func on each node. Adds // nodes to 'visited' as each node is visited. Skips nodes already in diff --git a/tensorflow/compiler/xla/service/call_graph_test.cc b/tensorflow/compiler/xla/service/call_graph_test.cc index a3ac2568b0f3eec8556a42dbe3c2c64bd8564468..5de724f8924b78008ba4c56603b61bf93fbc5e7c 100644 --- a/tensorflow/compiler/xla/service/call_graph_test.cc +++ b/tensorflow/compiler/xla/service/call_graph_test.cc @@ -102,6 +102,7 @@ TEST_F(CallGraphTest, SingletonComputation) { const CallGraphNode& node = call_graph->GetNode(computation); EXPECT_EQ(computation, node.computation()); + EXPECT_EQ(node.depth(), 0); EXPECT_TRUE(node.callsites().empty()); EXPECT_TRUE(node.callees().empty()); EXPECT_TRUE(node.caller_callsites().empty()); @@ -122,11 +123,13 @@ TEST_F(CallGraphTest, UnreachableComputation) { EXPECT_EQ(2, call_graph->nodes().size()); const CallGraphNode& entry_node = call_graph->GetNode(entry_computation); + EXPECT_EQ(entry_node.depth(), 0); EXPECT_EQ(entry_computation, entry_node.computation()); EXPECT_EQ(CallContext::kSequential, entry_node.context()); const CallGraphNode& unreachable_node = call_graph->GetNode(unreachable_computation); + EXPECT_EQ(unreachable_node.depth(), 0); EXPECT_EQ(unreachable_computation, unreachable_node.computation()); EXPECT_EQ(CallContext::kSequential, unreachable_node.context()); } @@ -145,6 +148,7 @@ TEST_F(CallGraphTest, ParallelComputation) { 
const CallGraphNode& entry_node = call_graph->GetNode(entry_computation); EXPECT_EQ(entry_computation, entry_node.computation()); + EXPECT_EQ(entry_node.depth(), 0); EXPECT_EQ(CallContext::kSequential, entry_node.context()); EXPECT_EQ(5, entry_node.callsites().size()); EXPECT_EQ(1, entry_node.callees().size()); @@ -153,6 +157,7 @@ TEST_F(CallGraphTest, ParallelComputation) { const CallGraphNode& map_node = call_graph->GetNode(map_computation); EXPECT_EQ(map_computation, map_node.computation()); + EXPECT_EQ(map_node.depth(), 1); EXPECT_EQ(CallContext::kParallel, map_node.context()); EXPECT_TRUE(map_node.callsites().empty()); EXPECT_TRUE(map_node.callees().empty()); @@ -234,6 +239,7 @@ TEST_F(CallGraphTest, ContextBothComputations) { EXPECT_EQ(entry_node.GetCallSite(map), &map_callsite); const CallGraphNode& sub_node = call_graph->GetNode(subcomputation); + EXPECT_EQ(sub_node.depth(), 1); EXPECT_EQ(CallContext::kBoth, sub_node.context()); } @@ -264,6 +270,7 @@ TEST_F(CallGraphTest, ComputationWithConditional) { EXPECT_EQ(3, call_graph->nodes().size()); const CallGraphNode& entry_node = call_graph->GetNode(entry_computation); + EXPECT_EQ(entry_node.depth(), 0); EXPECT_EQ(entry_computation, entry_node.computation()); EXPECT_EQ(1, entry_node.callsites().size()); @@ -275,11 +282,13 @@ TEST_F(CallGraphTest, ComputationWithConditional) { EXPECT_EQ(entry_node.GetCallSite(conditional), &conditional_callsite); const CallGraphNode& true_node = call_graph->GetNode(true_computation); + EXPECT_EQ(true_node.depth(), 1); EXPECT_TRUE(true_node.callees().empty()); EXPECT_EQ(1, true_node.callers().size()); EXPECT_EQ(entry_computation, true_node.callers()[0]); const CallGraphNode& false_node = call_graph->GetNode(false_computation); + EXPECT_EQ(false_node.depth(), 1); EXPECT_TRUE(false_node.callees().empty()); EXPECT_EQ(1, false_node.callers().size()); EXPECT_EQ(entry_computation, false_node.callers()[0]); @@ -332,9 +341,21 @@ TEST_F(CallGraphTest, ComplexGraph) { EXPECT_EQ(5, 
call_graph->nodes().size()); EXPECT_FALSE(call_graph->IsFlattened()); + const CallGraphNode& entry_node = call_graph->GetNode(entry_computation); + const CallGraphNode& a_node = call_graph->GetNode(a_computation); + const CallGraphNode& b_node = call_graph->GetNode(b_computation); + const CallGraphNode& c_node = call_graph->GetNode(c_computation); + const CallGraphNode& cond_node = call_graph->GetNode(cond_computation); + + // Verify depths. + EXPECT_EQ(entry_node.depth(), 0); + EXPECT_EQ(a_node.depth(), 1); + EXPECT_EQ(b_node.depth(), 2); + EXPECT_EQ(c_node.depth(), 3); + EXPECT_EQ(cond_node.depth(), 2); + // Entry computation has one while instruction calling two computations // (cond_computation and a_computation). - const CallGraphNode& entry_node = call_graph->GetNode(entry_computation); ASSERT_EQ(1, entry_node.callsites().size()); const std::vector& called_computations = entry_node.callsites()[0].called_computations(); @@ -342,7 +363,6 @@ TEST_F(CallGraphTest, ComplexGraph) { UnorderedElementsAre(cond_computation, a_computation)); EXPECT_EQ(CallContext::kSequential, entry_node.context()); - const CallGraphNode& c_node = call_graph->GetNode(c_computation); EXPECT_TRUE(c_node.callsites().empty()); EXPECT_THAT(c_node.callers(), UnorderedElementsAre(a_computation, b_computation)); @@ -364,7 +384,7 @@ TEST_F(CallGraphTest, ComplexGraph) { // Verify visitation order of some computations in the graph. 
auto index_of = [&visited](const HloComputation* comp) { - auto it = std::find(visited.begin(), visited.end(), comp); + auto it = absl::c_find(visited, comp); EXPECT_NE(it, visited.end()); return std::distance(visited.begin(), it); }; diff --git a/tensorflow/compiler/xla/service/channel_tracker.cc b/tensorflow/compiler/xla/service/channel_tracker.cc index 3c2d1ae6d82ebc6c10d52194fd1cec5e291025f7..b517495f2ea0c75679685c67f757ff586f8c79e3 100644 --- a/tensorflow/compiler/xla/service/channel_tracker.cc +++ b/tensorflow/compiler/xla/service/channel_tracker.cc @@ -72,7 +72,7 @@ ChannelHandle ChannelTracker::AllocateHandle(ChannelHandle::ChannelType type) { } Status ChannelTracker::RegisterSendInternal(const ChannelHandle& handle) { - if (opaque_to_channel_.count(handle.handle()) == 0) { + if (!opaque_to_channel_.contains(handle.handle())) { return NotFound("channel handle not found: %d", handle.handle()); } Channel& channel = opaque_to_channel_[handle.handle()]; @@ -94,7 +94,7 @@ Status ChannelTracker::RegisterSendInternal(const ChannelHandle& handle) { } Status ChannelTracker::RegisterRecvInternal(const ChannelHandle& handle) { - if (opaque_to_channel_.count(handle.handle()) == 0) { + if (!opaque_to_channel_.contains(handle.handle())) { return NotFound("channel handle not found: %d", handle.handle()); } Channel& channel = opaque_to_channel_[handle.handle()]; diff --git a/tensorflow/compiler/xla/service/channel_tracker.h b/tensorflow/compiler/xla/service/channel_tracker.h index 52037bf9b52556c6aa2e66dd3209e25cf085cfe3..89e17eba36f23077ce4cf0704e7455b76bee68d1 100644 --- a/tensorflow/compiler/xla/service/channel_tracker.h +++ b/tensorflow/compiler/xla/service/channel_tracker.h @@ -18,6 +18,7 @@ limitations under the License. 
#include +#include "absl/container/flat_hash_map.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/status.h" @@ -83,7 +84,8 @@ class ChannelTracker { // Mapping from ChannelHandle value to the corresponding registered // Channel object. - std::map opaque_to_channel_ GUARDED_BY(channel_mutex_); + absl::flat_hash_map opaque_to_channel_ + GUARDED_BY(channel_mutex_); TF_DISALLOW_COPY_AND_ASSIGN(ChannelTracker); }; diff --git a/tensorflow/compiler/xla/service/compiler.cc b/tensorflow/compiler/xla/service/compiler.cc index 8f08c244908efb823b3870c19bdc3491fa87d44f..653f4555a77cc82e91fb1cd26206b93826375732 100644 --- a/tensorflow/compiler/xla/service/compiler.cc +++ b/tensorflow/compiler/xla/service/compiler.cc @@ -98,10 +98,17 @@ Compiler::GetPlatformCompilers() { auto* factories = GetPlatformCompilerFactories(); auto it = factories->find(platform->id()); if (it == factories->end()) { + string hint; + if (platform->Name() == "Host") { + hint = " (hint: try linking in tensorflow/compiler/jit:xla_cpu_jit)"; + } else if (platform->Name() == "CUDA") { + hint = " (hint: try linking in tensorflow/compiler/jit:xla_gpu_jit)"; + } + return NotFound( "could not find registered compiler for platform %s -- check " - "target linkage", - platform->Name()); + "target linkage%s", + platform->Name(), hint); } // And then we invoke the factory, placing the result into the mapping. 
diff --git a/tensorflow/compiler/xla/service/computation_layout.cc b/tensorflow/compiler/xla/service/computation_layout.cc index efc893818d03a20d6bd65b7dc1da72ea5da5ceb0..92d1ca4ba5da802a5f1c544017ac52dda38e9b1d 100644 --- a/tensorflow/compiler/xla/service/computation_layout.cc +++ b/tensorflow/compiler/xla/service/computation_layout.cc @@ -42,8 +42,8 @@ void ComputationLayout::SetToDefaultLayout() { } bool ComputationLayout::LayoutIsSet() const { - return std::all_of(parameter_layouts_.begin(), parameter_layouts_.end(), - [](const ShapeLayout& s) { return s.LayoutIsSet(); }) && + return absl::c_all_of(parameter_layouts_, + [](const ShapeLayout& s) { return s.LayoutIsSet(); }) && result_layout_.LayoutIsSet(); } diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc b/tensorflow/compiler/xla/service/convolution_group_converter.cc similarity index 58% rename from tensorflow/compiler/xla/service/convolution_feature_group_converter.cc rename to tensorflow/compiler/xla/service/convolution_group_converter.cc index 09c3f32860b3176ee5afbb147872ddafc51af256..f11f9e5fc2949a92f83ff66506a9b162ffda1c92 100644 --- a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc +++ b/tensorflow/compiler/xla/service/convolution_group_converter.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/convolution_feature_group_converter.h" +#include "tensorflow/compiler/xla/service/convolution_group_converter.h" #include #include @@ -50,8 +50,12 @@ class ConvolutionVisitor : public DfsHloVisitorWithDefault { Status HandleConvolution(HloInstruction* convolution) override; + Status HandleBatchGroupCount(HloInstruction* convolution); + // Runs the visitor on a computation. 
static bool Run(HloComputation* computation, + std::function is_cost_viable, + bool convert_batch_groups_only, bool canonicalize_depthwise_filter); // Returns whether any convolution ops were rewritten. @@ -60,10 +64,15 @@ class ConvolutionVisitor : public DfsHloVisitorWithDefault { ~ConvolutionVisitor() override = default; private: - explicit ConvolutionVisitor(HloComputation* computation, - bool canonicalize_depthwise_filter = false) + explicit ConvolutionVisitor( + HloComputation* computation, + std::function is_cost_viable, + bool convert_batch_groups_only, + bool canonicalize_depthwise_filter = false) : computation_(computation), - filter_expansion_(!canonicalize_depthwise_filter) {} + filter_expansion_(!canonicalize_depthwise_filter), + convert_batch_groups_only_(convert_batch_groups_only), + is_cost_viable_(is_cost_viable) {} // Current HloComputation instance the ConvolutionVisitor is traversing. HloComputation* computation_; @@ -73,11 +82,21 @@ class ConvolutionVisitor : public DfsHloVisitorWithDefault { // Whether filter expansion is required. bool filter_expansion_; + + // Decides whether to convert batch groups or feature groups. 
+ bool convert_batch_groups_only_; + + // Lambda containing cost model that decides whether to expand batch_group_count. + std::function is_cost_viable_; }; -bool ConvolutionVisitor::Run(HloComputation* computation, - bool canonicalize_depthwise_filter) { - ConvolutionVisitor visitor(computation, canonicalize_depthwise_filter); +bool ConvolutionVisitor::Run( + HloComputation* computation, + std::function is_cost_viable, + bool convert_batch_groups_only, bool canonicalize_depthwise_filter) { + ConvolutionVisitor visitor(computation, is_cost_viable, + convert_batch_groups_only, + canonicalize_depthwise_filter); TF_CHECK_OK(computation->Accept(&visitor)); return visitor.changed_; } @@ -176,18 +195,143 @@ HloInstruction* GetExpandedFilterMask( predicate_shape, HloOpcode::kEq, broadcasted_mask1, broadcasted_mask2)); } -Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { - int64 group_count = convolution->feature_group_count(); - if (group_count == 1) { +// This function handles batch_group_counts which are relevant only for +// depthwise backprop filter convolutions.
+Status ConvolutionVisitor::HandleBatchGroupCount(HloInstruction* convolution) { + auto dim_numbers = convolution->convolution_dimension_numbers(); + auto activation = convolution->mutable_operand(0); + auto filter = convolution->mutable_operand(1); + int64 batch_group_count = convolution->batch_group_count(); + + if (batch_group_count == 1) { return Status::OK(); } - auto filter = convolution->mutable_operand(1); - changed_ = true; + + VLOG(2) << "Dealing with batch_group_count " << batch_group_count + << " for convolution " << convolution->ToString() << "\n"; + + auto add = [&](std::unique_ptr inst) { + return computation_->AddInstruction(std::move(inst)); + }; + + int64 input_batch_dimension = dim_numbers.input_batch_dimension(); + int64 output_batch_dimension = dim_numbers.output_batch_dimension(); + int64 output_feature_dimension = dim_numbers.output_feature_dimension(); + + int64 input_batch = activation->shape().dimensions(input_batch_dimension); + + // We are not yet supporting batch_group of sizes greater than 1. + TF_RET_CHECK(input_batch == batch_group_count); + + if (!is_cost_viable_(convolution) || filter_expansion_) { + // We first obtain the expanded filter (which is the convolution + // output). The batch dimension is the expanded one (which originally + // represents kernel input feature dimension). We mask the filter to zero + // out the expanded regions. Next we reduce the filter in the batch + // dimension to obtain the original filter size.
+ + HloInstruction* filter_mask = + GetExpandedFilterMask(convolution->shape(), output_batch_dimension, + output_feature_dimension, batch_group_count, add); + auto expanded_filter_shape = ExpandedFilterShape( + convolution->shape(), batch_group_count, output_batch_dimension); + + auto new_convolution = add(HloInstruction::CreateConvolve( + expanded_filter_shape, activation, filter, + /*feature_group_count=*/1, /*batch_group_count=*/1, + convolution->window(), dim_numbers, convolution->precision_config())); + + auto zero = add(HloInstruction::CreateConstant( + LiteralUtil::Zero(expanded_filter_shape.element_type()))); + auto zero_filter = + add(HloInstruction::CreateBroadcast(expanded_filter_shape, zero, {})); + + auto new_filter = add(HloInstruction::CreateTernary( + expanded_filter_shape, HloOpcode::kSelect, filter_mask, new_convolution, + zero_filter)); + + PrimitiveType reduce_type = new_filter->shape().element_type(); + auto reduce_window_shape = new_convolution->shape(); + reduce_window_shape.set_dimensions(output_batch_dimension, 1); + + // Ensure that data input to reduce window uses at least 32 bits. 
+ if (primitive_util::BitWidth(reduce_type) < primitive_util::BitWidth(F32)) { + reduce_type = F32; + reduce_window_shape.set_element_type(F32); + Shape convert_shape = new_filter->shape(); + convert_shape.set_element_type(F32); + new_filter = + add(HloInstruction::CreateConvert(convert_shape, new_filter)); + } + + auto zero_literal = LiteralUtil::Zero(reduce_type); + auto zero_scalar = + add(HloInstruction::CreateConstant(std::move(zero_literal))); + + auto reduce_function = [&]() -> HloComputation* { + HloComputation::Builder b("add_computation"); + Shape shape = ShapeUtil::MakeShape(reduce_type, {}); + auto lhs = + b.AddInstruction(HloInstruction::CreateParameter(0, shape, "lhs")); + auto rhs = + b.AddInstruction(HloInstruction::CreateParameter(1, shape, "rhs")); + auto scalar_op = b.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, lhs, rhs)); + return computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op)); + }; + + // Create the reduce window. + Window window; + for (int64 i = 0; i < new_convolution->shape().dimensions_size(); ++i) { + auto* dim = window.add_dimensions(); + dim->set_padding_low(0); + dim->set_padding_high(0); + dim->set_window_dilation(1); + dim->set_base_dilation(1); + if (i == output_batch_dimension) { + dim->set_stride(batch_group_count); + dim->set_size(batch_group_count); + } else { + dim->set_stride(1); + dim->set_size(1); + } + } + auto reduce_window = add(HloInstruction::CreateReduceWindow( + reduce_window_shape, new_filter, zero_scalar, window, + reduce_function())); + + Shape convert_back_shape = reduce_window->shape(); + convert_back_shape.set_element_type(activation->shape().element_type()); + + // Convert reduced data back to the original data type. 
+ auto reduce_window_converted = + HloInstruction::CreateConvert(convert_back_shape, reduce_window); + + TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction( + convolution, std::move(reduce_window_converted))); + changed_ = true; + } + + return Status::OK(); +} + +Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { + if (convert_batch_groups_only_) { + return HandleBatchGroupCount(convolution); + } + auto add = [&](std::unique_ptr inst) { return computation_->AddInstruction(std::move(inst)); }; + int64 group_count = convolution->feature_group_count(); + if (group_count == 1) { + return Status::OK(); + } + + changed_ = true; auto dim_numbers = convolution->convolution_dimension_numbers(); + auto filter = convolution->mutable_operand(1); int64 kernel_input_feature_dim = dim_numbers.kernel_input_feature_dimension(); int64 group_size = filter->shape().dimensions(kernel_input_feature_dim); int64 kernel_output_feature_dim = @@ -205,38 +349,7 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { // If the code generator handles depthwise separable convolutions // inherently, then no filter expansion is needed. if (!filter_expansion_ && depthwise_separable) { - const int64 old_kernel_input_feature_dimension = - dim_numbers.kernel_input_feature_dimension(); - const int64 old_kernel_output_feature_dimension = - dim_numbers.kernel_output_feature_dimension(); - - // For depthwise convolutions, we want the kernel input feature dimension - // to be smaller than the output feature dimension. If that's not the - // case, we swap the dimensions. 
- if (old_kernel_input_feature_dimension > - old_kernel_output_feature_dimension) { - Shape reshaped_filter_shape = filter->shape(); - auto& dimensions = *reshaped_filter_shape.mutable_dimensions(); - std::swap(dimensions[old_kernel_input_feature_dimension], - dimensions[old_kernel_output_feature_dimension]); - - auto reshaped_filter = - add(HloInstruction::CreateReshape(reshaped_filter_shape, filter)); - - dim_numbers.set_kernel_input_feature_dimension( - old_kernel_output_feature_dimension); - - dim_numbers.set_kernel_output_feature_dimension( - old_kernel_input_feature_dimension); - - auto new_convolution = HloInstruction::CreateConvolve( - convolution->shape(), convolution->mutable_operand(0), - reshaped_filter, group_count, convolution->window(), dim_numbers, - convolution->precision_config()); - - TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction( - convolution, std::move(new_convolution))); - } + changed_ = false; return Status::OK(); } // We want to repeat 'filter' in the 'input_feature_dim' dimension @@ -265,136 +378,79 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { auto new_convolution = HloInstruction::CreateConvolve( convolution->shape(), convolution->mutable_operand(0), new_filter, - /*feature_group_count=*/1, convolution->window(), dim_numbers, - convolution->precision_config()); + /*feature_group_count=*/1, /*batch_group_count=*/1, + convolution->window(), dim_numbers, convolution->precision_config()); TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction( convolution, std::move(new_convolution))); } else { int64 activation_input_feature_dim = dim_numbers.input_feature_dimension(); - auto activation = convolution->mutable_operand(0); int64 output_feature = filter->shape().dimensions(kernel_output_feature_dim); - int64 input_feature = - activation->shape().dimensions(activation_input_feature_dim); - // If group_count == output_feature, then we map those grouped convolutions - // onto depthwise convolution 
+ reduce. E.g., we would turn + // onto depthwise convolution. This is done by adding an additional spatial + // dimension to the activations, kernel, and the output. + // E.g., we would turn // [2, 12]{B, IF} conv [3, 4]{IF, OF} into - // [2, 12]{B, IF} depth conv [1, 12]{IF, OF}, and then use a reduce window - // of {1, 3} on the generated [2, 12] output to produce the final result of - // [2, 4]. + // [3, 2, 4]{S, B, IF} depth conv [3, 1, 4]{S, IF, OF}, where S is the + // additional spatial dimension. The generated convolution output will be + // [1, 2, 4]{S, B, OF} and then reshape the output back to [2, 4] {B, OF}. + if (group_count == output_feature && !filter_expansion_) { - Shape reshaped_filter_shape = filter->shape(); + auto filter = convolution->mutable_operand(1); + auto activation = convolution->mutable_operand(0); - if (kernel_input_feature_dim < kernel_output_feature_dim) { - // Transpose IF and OF on the kernel. - std::vector filter_dims; - for (int64 i = 0; i < dim_numbers.kernel_spatial_dimensions().size(); - ++i) { - filter_dims.push_back(dim_numbers.kernel_spatial_dimensions(i)); - } - filter_dims.push_back(kernel_output_feature_dim); - filter_dims.push_back(kernel_input_feature_dim); - - Shape transposed_filter = filter->shape(); - auto& dimensions = *transposed_filter.mutable_dimensions(); - std::swap(dimensions[kernel_input_feature_dim], - dimensions[kernel_output_feature_dim]); - - filter = add(HloInstruction::CreateTranspose(transposed_filter, filter, - filter_dims)); - } else { - // For depthwise convolutions, we want the kernel input feature - // dimension to be smaller than the output feature dimension. If that's - // not the case, we swap the dimensions. + // Add spatial dimension to the activation, and reshape. 
+ Shape reshaped_activation_shape = activation->shape(); + ShapeUtil::AppendMajorDimension(group_size, &reshaped_activation_shape); - auto& dimensions = *reshaped_filter_shape.mutable_dimensions(); - std::swap(dimensions[kernel_input_feature_dim], - dimensions[kernel_output_feature_dim]); + int64 new_spatial_dim = reshaped_activation_shape.dimensions().size() - 1; - dim_numbers.set_kernel_input_feature_dimension( - kernel_output_feature_dim); + reshaped_activation_shape.set_dimensions(activation_input_feature_dim, + group_count); + activation = add( + HloInstruction::CreateReshape(reshaped_activation_shape, activation)); - dim_numbers.set_kernel_output_feature_dimension( - kernel_input_feature_dim); - std::swap(kernel_output_feature_dim, kernel_input_feature_dim); - } + // Add spatial dimension to the filter, and reshape. + Shape reshaped_filter_shape = filter->shape(); + ShapeUtil::AppendMajorDimension(1, &reshaped_filter_shape); - reshaped_filter_shape.set_dimensions(kernel_input_feature_dim, 1); - reshaped_filter_shape.set_dimensions(kernel_output_feature_dim, - group_count * group_size); - auto reshaped_filter = + filter = add(HloInstruction::CreateReshape(reshaped_filter_shape, filter)); - Shape reshaped_convolution_shape = convolution->shape(); - reshaped_convolution_shape.set_dimensions( - dim_numbers.output_feature_dimension(), group_count * group_size); - auto new_convolution = add(HloInstruction::CreateConvolve( - reshaped_convolution_shape, convolution->mutable_operand(0), - reshaped_filter, /*feature_group_count=*/input_feature, - convolution->window(), dim_numbers, convolution->precision_config())); - - // Create the reduce window. 
- Window window; - for (int64 i = 0; i < new_convolution->shape().dimensions_size(); ++i) { - auto* dim = window.add_dimensions(); - dim->set_padding_low(0); - dim->set_padding_high(0); - dim->set_window_dilation(1); - dim->set_base_dilation(1); - if (i == dim_numbers.output_feature_dimension()) { - dim->set_stride(group_size); - dim->set_size(group_size); - } else { - dim->set_stride(1); - dim->set_size(1); - } - } - - auto reduce_window_shape = new_convolution->shape(); - reduce_window_shape.set_dimensions(dim_numbers.output_feature_dimension(), - group_count); - - auto zero_literal = LiteralUtil::CreateR0(0.0f); - TF_ASSIGN_OR_RETURN(zero_literal, zero_literal.Convert(F32)); - auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal))); - - auto reduce_function = [&]() -> HloComputation* { - HloComputation::Builder b("add_computation"); - Shape shape = ShapeUtil::MakeShape(F32, {}); - auto lhs = - b.AddInstruction(HloInstruction::CreateParameter(0, shape, "lhs")); - auto rhs = - b.AddInstruction(HloInstruction::CreateParameter(1, shape, "rhs")); - auto scalar_op = b.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kAdd, lhs, rhs)); - return computation_->parent()->AddEmbeddedComputation( - b.Build(scalar_op)); - }; - - // Ensure that data input to reduce window is of type F32. - if (primitive_util::BitWidth(new_convolution->shape().element_type()) < - primitive_util::BitWidth(F32)) { - Shape convert_shape = new_convolution->shape(); - convert_shape.set_element_type(F32); - new_convolution = add(HloInstruction::CreateBitcastConvert( - convert_shape, new_convolution)); - } + Shape new_output_shape = convolution->shape(); + ShapeUtil::AppendMajorDimension(1, &new_output_shape); + + // Edit convolution dimension numbers. Note that kernel_input_feature_dim + // now becomes a spatial dimension, and the newly added dimension of size + // 1 is the new kernel_input_feature_dim. 
+ dim_numbers.add_input_spatial_dimensions(new_spatial_dim); + dim_numbers.add_kernel_spatial_dimensions(kernel_input_feature_dim); + dim_numbers.set_kernel_input_feature_dimension(new_spatial_dim); + dim_numbers.add_output_spatial_dimensions(new_spatial_dim); + + // Add window for the new spatial dimension. + Window new_window = convolution->window(); + auto* dim = new_window.add_dimensions(); + dim->set_window_dilation(1); + dim->set_base_dilation(1); + dim->set_stride(1); + dim->set_size(group_size); - auto reduce_window = add(HloInstruction::CreateReduceWindow( - reduce_window_shape, new_convolution, zero, window, - reduce_function())); + auto new_convolution = add(HloInstruction::CreateConvolve( + new_output_shape, activation, filter, group_count, + /*batch_group_count=*/1, new_window, dim_numbers, + convolution->precision_config())); - Shape convert_back_shape = reduce_window->shape(); - convert_back_shape.set_element_type(activation->shape().element_type()); + // Delete the extra spatial dimension, and reshape. + Shape reshaped_convolution_shape = + ShapeUtil::DeleteDimension(new_spatial_dim, new_convolution->shape()); + auto reshaped_convolution = HloInstruction::CreateReshape( + reshaped_convolution_shape, new_convolution); - // Convert reduced data back to the original data type. - auto reduce_window_converted = HloInstruction::CreateBitcastConvert( - convert_back_shape, reduce_window); TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction( - convolution, std::move(reduce_window_converted))); + convolution, std::move(reshaped_convolution))); } else { // The filter expansion mechanism adds zeroes in the kernel. 
@@ -462,7 +518,8 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { auto new_convolution = add(HloInstruction::CreateConvolve( conv_slice_shape, activation_slice, filter_slice, - /*feature_group_count=*/1, convolution->window(), dim_numbers, + /*feature_group_count=*/1, /*batch_group_count=*/1, + convolution->window(), dim_numbers, convolution->precision_config())); sliced_convolutions.push_back(new_convolution); @@ -480,17 +537,19 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { } // namespace -StatusOr ConvolutionFeatureGroupConverter::Run(HloModule* module) { - XLA_VLOG_LINES(2, "ConvolutionFeatureGroupConverter::Run(), before:\n" + - module->ToString()); +StatusOr ConvolutionGroupConverter::Run(HloModule* module) { + XLA_VLOG_LINES( + 2, "ConvolutionGroupConverter::Run(), before:\n" + module->ToString()); bool changed = false; for (auto* comp : module->MakeNonfusionComputations()) { - if (ConvolutionVisitor::Run(comp, filter_expansion_)) { + if (ConvolutionVisitor::Run(comp, is_cost_viable_, + convert_batch_groups_only_, + filter_expansion_)) { changed = true; } } - XLA_VLOG_LINES(2, "ConvolutionFeatureGroupConverter::Run(), after:\n" + - module->ToString()); + XLA_VLOG_LINES( + 2, "ConvolutionGroupConverter::Run(), after:\n" + module->ToString()); return changed; } diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter.h b/tensorflow/compiler/xla/service/convolution_group_converter.h similarity index 58% rename from tensorflow/compiler/xla/service/convolution_feature_group_converter.h rename to tensorflow/compiler/xla/service/convolution_group_converter.h index cb6bc04c00a2ff10f970da2a07fb540a561dad5a..1caf1841119a965044502435fe0f5b38ca94f6a5 100644 --- a/tensorflow/compiler/xla/service/convolution_feature_group_converter.h +++ b/tensorflow/compiler/xla/service/convolution_group_converter.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and 
limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CONVOLUTION_FEATURE_GROUP_CONVERTER_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_CONVOLUTION_FEATURE_GROUP_CONVERTER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CONVOLUTION_GROUP_CONVERTER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CONVOLUTION_GROUP_CONVERTER_H_ #include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -25,23 +25,34 @@ namespace xla { // A pass which rewrites convolutions with feature_group_count > 1 into // convolutions with feature_group_count = 1. -class ConvolutionFeatureGroupConverter : public HloModulePass { +class ConvolutionGroupConverter : public HloModulePass { public: - ConvolutionFeatureGroupConverter(bool canonicalize_depthwise_filter = false) - : filter_expansion_(canonicalize_depthwise_filter) {} + ConvolutionGroupConverter(std::function is_cost_viable, + bool convert_batch_groups_only, + bool canonicalize_depthwise_filter = false) + : is_cost_viable_(is_cost_viable), + convert_batch_groups_only_(convert_batch_groups_only), + filter_expansion_(canonicalize_depthwise_filter) {} absl::string_view name() const override { - return "convolution-feature-group-converter"; + return "convolution-group-converter"; } // Run convolution rewriting on the given computation. Returns whether the // computation was changed. StatusOr Run(HloModule* module) override; + // Lambda containing cost model that decides whether to expand + // batch_group_count. + std::function is_cost_viable_; + + // Decides whether to convert batch groups or feature groups. + bool convert_batch_groups_only_; + // Tells whether filter expansion is required. 
bool filter_expansion_; }; } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CONVOLUTION_FEATURE_GROUP_CONVERTER_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CONVOLUTION_GROUP_CONVERTER_H_ diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter_test.cc b/tensorflow/compiler/xla/service/convolution_group_converter_test.cc similarity index 68% rename from tensorflow/compiler/xla/service/convolution_feature_group_converter_test.cc rename to tensorflow/compiler/xla/service/convolution_group_converter_test.cc index e6bf2143a21bd5001d3530fe8727c88504be1d43..9cee3eda95252d6c7d725fbb03030bd58f52e71f 100644 --- a/tensorflow/compiler/xla/service/convolution_feature_group_converter_test.cc +++ b/tensorflow/compiler/xla/service/convolution_group_converter_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/convolution_feature_group_converter.h" +#include "tensorflow/compiler/xla/service/convolution_group_converter.h" #include #include @@ -30,10 +30,10 @@ limitations under the License. 
namespace xla { namespace { -using ConvolutionFeatureGroupConverterTest = HloTestBase; +using ConvolutionGroupConverterTest = HloTestBase; namespace op = testing::opcode_matchers; -TEST_F(ConvolutionFeatureGroupConverterTest, +TEST_F(ConvolutionGroupConverterTest, ConvertFeatureGroupCountEqualToInputFeatureDim) { string hlo_string = R"(HloModule Convolve1D1Window_0_module @@ -49,7 +49,8 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,2], filter: f32[1,1,2]) -> f32[1,2 auto computation = module->entry_computation(); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kConvolution); - ConvolutionFeatureGroupConverter converter; + ConvolutionGroupConverter converter(nullptr, /*convert_batch_groups_only=*/ + false); ASSERT_TRUE(converter.Run(module.get()).ValueOrDie()); root = computation->root_instruction(); // Make sure the convolution is converted to one with feature_group_count = 1. @@ -63,7 +64,7 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,2], filter: f32[1,1,2]) -> f32[1,2 op::Broadcast(op::Constant()))); } -TEST_F(ConvolutionFeatureGroupConverterTest, +TEST_F(ConvolutionGroupConverterTest, ConvertFeatureGroupCountDivisorOfInputFeatureDim) { string hlo_string = R"(HloModule Convolve1D1Window_0_module @@ -79,7 +80,8 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,4], filter: f32[1,2,2]) -> f32[1,2 auto computation = module->entry_computation(); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kConvolution); - ConvolutionFeatureGroupConverter converter; + ConvolutionGroupConverter converter(nullptr, /*convert_batch_groups_only=*/ + false); ASSERT_TRUE(converter.Run(module.get()).ValueOrDie()); root = computation->root_instruction(); // Make sure the convolution is replaced with a concatenate. 
@@ -92,5 +94,32 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,4], filter: f32[1,2,2]) -> f32[1,2 EXPECT_EQ(root->operand(1)->feature_group_count(), 1); } +TEST_F(ConvolutionGroupConverterTest, + ConvertBatchGroupCountEqualToInputBatchDim) { + string hlo_string = R"(HloModule Convolve1D1Window_0_module + +ENTRY %Convolve1D1Window_0.v3 (input: f32[16,19,19,512]{3,2,1,0}, filter: f32[16,19,19,512]{3,2,1,0}) -> f32[3,3,512,1]{3,2,1,0} { + %input = f32[16,19,19,512]{3,2,1,0} parameter(0) + %filter = f32[16,19,19,512]{3,2,1,0} parameter(1) + ROOT %convolution = f32[3,3,512,1]{3,2,1,0} convolution(f32[16,19,19,512]{3,2,1,0} %input, f32[16,19,19,512]{3,2,1,0} %filter), window={size=19x19 pad=1_1x1_1}, dim_labels=f01b_i01o->01fb, batch_group_count=512 + })"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + + auto computation = module->entry_computation(); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kConvolution); + auto cost_model = [](HloInstruction* conv) { return false; }; + ConvolutionGroupConverter converter(cost_model, /*convert_batch_groups_only=*/ + true); + ASSERT_TRUE(converter.Run(module.get()).ValueOrDie()); + root = computation->root_instruction(); + + // Verify that the convolution is replaced by a convert. + EXPECT_EQ(root->opcode(), HloOpcode::kConvert); + // Make sure the convert is being fed by a reduce window. 
+ EXPECT_EQ(root->operand(0)->opcode(), HloOpcode::kReduceWindow); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc index df6059663876dfde71f4c75d3931b3d2de72c1df..5e26a63cebfa9b2e50f4b13335c10c246999d4df 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.cc +++ b/tensorflow/compiler/xla/service/copy_insertion.cc @@ -349,11 +349,12 @@ Status AddCopiesForAliasedInputOutputs(HloModule* module) { ShapeTree param_indices_to_copy(param->shape()); module->input_output_alias_config().ForEachAlias( - [&](const ShapeIndex& output_index, int64 param_number, - const ShapeIndex& param_index) { - if (param_number == param->parameter_number()) { + [&](const ShapeIndex& output_index, + const HloInputOutputAliasConfig::Alias& alias) { + if (alias.parameter_number == param->parameter_number()) { param_has_alias = true; - *(param_indices_to_copy.mutable_element(param_index)) = true; + *(param_indices_to_copy.mutable_element(alias.parameter_index)) = + true; *(output_indices_to_copy.mutable_element(output_index)) = true; } }); @@ -395,13 +396,14 @@ Status AddCopiesForAliasedInputOutputs(HloModule* module) { // Add control dependencies between the input/output copies. 
TF_RETURN_IF_ERROR(module->input_output_alias_config().ForEachAliasWithStatus( - [&](const ShapeIndex& output_index, int64 param_number, - const ShapeIndex& input_index) -> Status { - if (!copied_parameters[param_number]) { + [&](const ShapeIndex& output_index, + const HloInputOutputAliasConfig::Alias& alias) -> Status { + if (!copied_parameters[alias.parameter_number]) { return Status::OK(); } HloInstruction* from = - copied_parameters[param_number]->element(input_index); + copied_parameters[alias.parameter_number]->element( + alias.parameter_index); HloInstruction* to = output_copy_tree.element(output_index); TF_RET_CHECK(from != nullptr); @@ -522,7 +524,7 @@ class CopyRemover { // between copies added around aliased operations (kWhile) guarantees // this strict order. for (const HloValue* value_a : buffer.values()) { - if (ShapeUtil::IsToken(value_a->shape())) { + if (value_a->shape().IsToken()) { // Token values have no representation and cannot interfere. continue; } @@ -539,10 +541,9 @@ class CopyRemover { } std::vector values = buffer.values(); - std::sort(values.begin(), values.end(), - [this](const HloValue* a, const HloValue* b) { - return ordering_.IsDefinedBefore(*a, *b); - }); + absl::c_sort(values, [this](const HloValue* a, const HloValue* b) { + return ordering_.IsDefinedBefore(*a, *b); + }); // Create a list containing all of the values in the buffer. AddValueList(values, &value_to_node); @@ -842,12 +843,11 @@ class CopyRemover { copy_value_node->next->prev = operand_node; // Patch up uses. Remove use of copy from operand_node uses. 
- auto it = - std::find_if(operand_node->uses.begin(), operand_node->uses.end(), - [copy_value_node](const HloUse* use) { - return use->instruction == - copy_value_node->value->defining_instruction(); - }); + auto it = absl::c_find_if( + operand_node->uses, [copy_value_node](const HloUse* use) { + return use->instruction == + copy_value_node->value->defining_instruction(); + }); CHECK(it != operand_node->uses.end()); operand_node->uses.erase(it); diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc index e4e9d7ba05c115be9dd0eb53ebd7de208d514efb..4391bdcba532661a0fde789e2c4ed324c40bcd32 100644 --- a/tensorflow/compiler/xla/service/copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc @@ -1376,9 +1376,11 @@ TEST_F(CopyInsertionTest, CrossingParameters) { builder.AddInstruction(HloInstruction::CreateTuple({gte1, gte0})); module->AddEntryComputation(builder.Build()); ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( - /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( - /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1})); + /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); InsertCopies(module.get()); EXPECT_EQ(CountCopies(*module), 4); @@ -1409,9 +1411,11 @@ TEST_F(CopyInsertionTest, ParametersAliasing) { builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1})); module->AddEntryComputation(builder.Build()); ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( - /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); 
ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( - /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1})); + /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); InsertCopies(module.get()); EXPECT_EQ(CountCopies(*module), 0); @@ -1475,7 +1479,8 @@ TEST_F(CopyInsertionTest, ParameterWithPartialAliasing) { builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1})); module->AddEntryComputation(builder.Build()); ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( - /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); InsertCopies(module.get()); EXPECT_THAT(module->entry_computation()->root_instruction(), @@ -1516,7 +1521,8 @@ TEST_F(CopyInsertionTest, ParameterAndParallelOpsWithPartialAliasing) { builder.AddInstruction(HloInstruction::CreateTuple({negate0, negate1})); module->AddEntryComputation(builder.Build()); ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( - /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); InsertCopies(module.get()); EXPECT_EQ(CountCopies(*module), 0); @@ -1557,7 +1563,8 @@ TEST_F(CopyInsertionTest, ParameterAndOpsWithPartialAliasing) { builder.AddInstruction(HloInstruction::CreateTuple({add, negate1})); module->AddEntryComputation(builder.Build()); ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( - /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); InsertCopies(module.get()); EXPECT_EQ(CountCopies(*module), 0); @@ -1848,8 +1855,7 @@ ENTRY %TokensShouldNotBeCopied () -> s32[] { } )"; 
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - HloRunner::CreateModuleFromString( - module_string, GetDebugOptionsForTest())); + ParseAndReturnVerifiedModule(module_string)); InsertCopies(module.get()); // There should be no copies added because tokens should not be copied. @@ -2112,8 +2118,7 @@ ENTRY TestComputation { ROOT while.3 = (s32[], s32[], s32[], s32[], s32[]) while(arg_tuple.6), condition=cond_wrapper.v3.2, body=_functionalize_body_2__.v25 } )"; - auto module_or_status = - HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()); + auto module_or_status = ParseAndReturnVerifiedModule(hlo_string); auto module = module_or_status.ConsumeValueOrDie(); InsertCopies(module.get()); } @@ -2213,8 +2218,7 @@ ENTRY TestComputation { ROOT while.3 = (s32[], s32[], s32[], s32[], s32[]) while(arg_tuple.6), condition=cond_wrapper.v3.2, body=_functionalize_body_2__.v25 } )"; - auto module_or_status = - HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()); + auto module_or_status = ParseAndReturnVerifiedModule(hlo_string); auto module = module_or_status.ConsumeValueOrDie(); InsertCopies(module.get()); } @@ -2231,7 +2235,7 @@ cond.inner { body.inner { param.body.inner = pred[] parameter(0) - ROOT neg = pred[] negate(param.body.inner) + ROOT not = pred[] not(param.body.inner) } cond.outer { @@ -2248,9 +2252,8 @@ ENTRY TestComputation { ROOT while = pred[] while(entry_param), condition=cond.outer, body=body.outer } )"; - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr module, - HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest())); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); InsertCopies(module.get()); // There should only be a single copy inserted, and it's in the entry diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index ce4c2a9cc69240b9565b35a3f2504d7fc9373917..d4535b204d7f3ad8d4e24beea5d0dd79e7a15ab0 100644 --- 
a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -1,6 +1,14 @@ # Description: # LLVM-based CPU backend for XLA. +load("//tensorflow/compiler/xla:xla.bzl", "ORC_JIT_MEMORY_MAPPER_TARGETS") +load( + "//third_party/mkl:build_defs.bzl", + "mkl_deps", +) +load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test") +load(":build_defs.bzl", "runtime_copts") + licenses(["notice"]) # Apache 2.0 package( @@ -14,15 +22,6 @@ package_group( ], ) -load(":build_defs.bzl", "runtime_copts") -load("//tensorflow:tensorflow.bzl", "tf_cc_test") -load("//tensorflow:tensorflow.bzl", "tf_cc_binary") -load("//tensorflow/compiler/xla:xla.bzl", "ORC_JIT_MEMORY_MAPPER_TARGETS") -load( - "//third_party/mkl:build_defs.bzl", - "mkl_deps", -) - # Filegroup used to collect source files for dependency checking. filegroup( name = "c_srcs", @@ -95,6 +94,7 @@ cc_library( ":target_machine_features", "@com_google_absl//absl/types:span", "//tensorflow/compiler/tf2xla:cpu_function_runtime", + "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:map_inliner", "//tensorflow/compiler/xla/service:hlo_get_dimension_size_rewriter", "//tensorflow/compiler/xla/service:scatter_expander", @@ -112,8 +112,9 @@ cc_library( "//tensorflow/compiler/xla/service:buffer_liveness", "//tensorflow/compiler/xla/service:call_inliner", "//tensorflow/compiler/xla/service:conditional_simplifier", - "//tensorflow/compiler/xla/service:convolution_feature_group_converter", + "//tensorflow/compiler/xla/service:convolution_group_converter", "//tensorflow/compiler/xla/service:dot_decomposer", + "//tensorflow/compiler/xla/service:dynamic_index_splitter", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:flatten_call_graph", "//tensorflow/compiler/xla/service:hlo", @@ -133,6 +134,7 @@ cc_library( "//tensorflow/compiler/xla/service:llvm_compiler", "//tensorflow/compiler/xla/service:reduce_precision_insertion", 
"//tensorflow/compiler/xla/service:reshape_mover", + "//tensorflow/compiler/xla/service:sort_simplifier", "//tensorflow/compiler/xla/service:transpose_folding", "//tensorflow/compiler/xla/service:tuple_simplifier", "//tensorflow/compiler/xla/service:while_loop_constant_sinking", @@ -241,6 +243,7 @@ cc_library( "//tensorflow/compiler/xla/service:tuple_points_to_analysis", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/stream_executor/host:host_stream", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", @@ -364,15 +367,33 @@ cc_library( ], ) +cc_library( + name = "tiled_dot_emitter", + srcs = ["tiled_dot_emitter.cc"], + hdrs = ["tiled_dot_emitter.h"], + deps = [ + ":vector_support_library", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service/llvm_ir:kernel_support_library", + "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + "//tensorflow/core:lib", + "@llvm//:core", + ], +) + cc_library( name = "dot_op_emitter", srcs = ["dot_op_emitter.cc"], - hdrs = ["dot_op_emitter.h"], + hdrs = [ + "dot_op_emitter.h", + ], deps = [ ":cpu_options", ":cpu_runtime", ":ir_emission_utils", ":target_machine_features", + ":tiled_dot_emitter", ":vector_support_library", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -380,6 +401,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:hlo_module_config", "//tensorflow/compiler/xla/service/llvm_ir:ir_array", "//tensorflow/compiler/xla/service/llvm_ir:kernel_support_library", @@ -572,6 +594,7 @@ cc_library( ":runtime_matvec", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/core:framework_lite", + 
"//tensorflow/core/kernels:eigen_contraction_kernel", "//third_party/eigen3", ], ) @@ -630,6 +653,7 @@ cc_library( deps = [ ":runtime_matvec", "//tensorflow/core:framework_lite", + "//tensorflow/core/kernels:eigen_contraction_kernel", "//third_party/eigen3", ], ) @@ -1005,7 +1029,6 @@ tf_cc_test( size = "small", srcs = ["cpu_eigen_tensor_alignment_test.cc"], deps = [ - ":dot_op_emitter", ":ir_emission_utils", ":target_machine_features_fake", "//tensorflow/compiler/xla:test", diff --git a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc index 796a7cf94d02b0ad42366387a9d3f8d589b8840a..414eacddfc7ba3c295c027c64c445a2046235d36 100644 --- a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc +++ b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc @@ -66,9 +66,14 @@ class FilteredPassManager : public llvm::legacy::PassManager { explicit FilteredPassManager(bool disable_expensive_passes) : disable_expensive_passes_(disable_expensive_passes) {} void add(llvm::Pass* p) override { + llvm::StringRef PassName = p->getPassName(); + if (PassName.contains("Warn about non-applied transformations")) { + delete p; + return; + } if (disable_expensive_passes_) { - llvm::StringRef PassName = p->getPassName(); if (PassName.contains("Unroll loops")) { + delete p; return; } } diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc index 2d9978404cc9ec1e40fc61aaf794a8f1f06050bb..8e55267a67d330e7e721f9b5fb25451357a49a9d 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc @@ -132,7 +132,8 @@ StatusOr ConvCanonicalization::Run(HloModule* module) { HloInstruction* new_conv = module->entry_computation()->AddInstruction( HloInstruction::CreateConvolve( new_conv_shape, new_input, new_kernel, hlo->feature_group_count(), - hlo->window(), new_dnums, 
hlo->precision_config())); + hlo->batch_group_count(), hlo->window(), new_dnums, + hlo->precision_config())); // Reshape the output back to the shape of the original convolution. TF_RETURN_IF_ERROR(module->entry_computation()->ReplaceWithNewInstruction( diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc index c58175428fea6a2d38253c35de598b99a4281bf1..02085108a081358cd4f8aed6dc12557cbd8eea85 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc @@ -84,8 +84,8 @@ TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) { builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape( F32, {kOutputFeatureCount, kBatchSize, output_size, output_size}), - input, kernel, /*feature_group_count=*/1, conv_window_, dnums, - DefaultPrecisionConfig(2))); + input, kernel, /*feature_group_count=*/1, /*batch_group_count=*/1, + conv_window_, dnums, DefaultPrecisionConfig(2))); auto module = CreateNewVerifiedModule(); HloComputation* entry_computation = @@ -147,8 +147,8 @@ TEST_F(ConvCanonicalizationTest, CanonicalStaysTheSame) { builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape( F32, {kBatchSize, output_size, output_size, kOutputFeatureCount}), - input, kernel, /*feature_group_count=*/1, conv_window_, dnums, - DefaultPrecisionConfig(2))); + input, kernel, /*feature_group_count=*/1, /*batch_group_count=*/1, + conv_window_, dnums, DefaultPrecisionConfig(2))); auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 2bf24c15c1f050b200b1d9af2d95286f9a9dbe4c..eafda68510d93ee54f2aead60a84f3e97b3fe1f4 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ 
b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -51,7 +51,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/buffer_liveness.h" #include "tensorflow/compiler/xla/service/call_inliner.h" #include "tensorflow/compiler/xla/service/conditional_simplifier.h" -#include "tensorflow/compiler/xla/service/convolution_feature_group_converter.h" +#include "tensorflow/compiler/xla/service/convolution_group_converter.h" #include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h" #include "tensorflow/compiler/xla/service/cpu/compiler_functor.h" #include "tensorflow/compiler/xla/service/cpu/conv_canonicalization.h" @@ -69,6 +69,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/dot_decomposer.h" +#include "tensorflow/compiler/xla/service/dynamic_index_splitter.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -92,6 +93,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/reduce_precision_insertion.h" #include "tensorflow/compiler/xla/service/reshape_mover.h" #include "tensorflow/compiler/xla/service/scatter_expander.h" +#include "tensorflow/compiler/xla/service/sort_simplifier.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" #include "tensorflow/compiler/xla/service/tuple_simplifier.h" #include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h" @@ -244,21 +246,30 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( HloPassPipeline pipeline("HLO passes through layout assignment"); pipeline.AddInvariantChecker(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false); + pipeline.AddPass(); pipeline.AddPass(); ReducePrecisionInsertion::AddPasses( &pipeline, module->config().debug_options(), ReducePrecisionInsertion::PassTiming::BEFORE_OPTIMIZATION); - pipeline.AddPass(); pipeline.AddPass(); // TODO(b/65775800): Fix wrong output bug in Call and remove the CallInliner // pass. pipeline.AddPass(); pipeline.AddPass(); - pipeline.AddPass(); - pipeline.AddPass(); + pipeline.AddPass(/*decompose_batch_dot=*/false); + auto cost_model = [](HloInstruction* conv) { + // We need a cost model for CPUs. Currently, do nothing. 
+ return false; + }; + pipeline.AddPass( + cost_model, + /*convert_batch_groups_only=*/true); + pipeline.AddPass( + cost_model, + /*convert_batch_groups_only=*/false); pipeline.AddPass(target_machine_features); { auto& pass = @@ -270,10 +281,11 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( /*rewrite_training_op=*/true, /*rewrite_inference_op=*/true, /*rewrite_grad_op=*/true); - AlgebraicSimplifierOptions options( - [](const Shape&, const Shape&) { return false; }); + pipeline.AddPass(); + AlgebraicSimplifierOptions options; options.set_enable_dot_strength_reduction(false); pass.AddPass(options); + pass.AddPass(); pass.AddPass(); // BatchNormExpander can create zero-sized ops, so zero-sized HLO @@ -293,7 +305,8 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( pipeline.AddPass( [&](const HloInstruction& dot, const TransposeFolding::OperandIndices& candidate_operands) { - return PotentiallyImplementedAsEigenDot(dot, *target_machine_features) + return DotImplementationCanHandleTranspose(dot, + *target_machine_features) ? candidate_operands : TransposeFolding::OperandIndices{}; }, @@ -336,8 +349,7 @@ Status CpuCompiler::RunHloPassesAfterLayoutAssn( pass.AddInvariantChecker( /*layout_sensitive=*/true, /*allow_mixed_precision=*/false); - AlgebraicSimplifierOptions options( - [](const Shape&, const Shape&) { return true; }); + AlgebraicSimplifierOptions options; options.set_is_layout_sensitive(true); options.set_enable_dot_strength_reduction(false); pass.AddPass>(options); @@ -497,7 +509,7 @@ Status CreateHloProfilingArtifacts( auto shape_size_bytes = [](const Shape& shape) { // On the cpu, opaques are pointers. 
- if (ShapeUtil::IsOpaque(shape)) { + if (shape.IsOpaque()) { return static_cast(sizeof(void*)); } return ShapeUtil::ByteSizeOf(shape, sizeof(void*)); @@ -635,18 +647,17 @@ StatusOr> CpuCompiler::RunBackend( .EmitComputation( embedded_computation, embedded_computation->name(), /*is_top_level_computation=*/false, - &schedule.sequence(embedded_computation).instructions()) + schedule.sequence(embedded_computation).instructions()) .status()); } string function_name_prefix = entry_computation->name().empty() ? "__compute" : entry_computation->name(); - TF_ASSIGN_OR_RETURN( - llvm::Function * entry_function, - ir_emitter.EmitComputation( - entry_computation, function_name_prefix, - /*is_top_level_computation=*/true, - &schedule.sequence(entry_computation).instructions())); + TF_ASSIGN_OR_RETURN(llvm::Function * entry_function, + ir_emitter.EmitComputation( + entry_computation, function_name_prefix, + /*is_top_level_computation=*/true, + schedule.sequence(entry_computation).instructions())); string function_name = [&]() { llvm::SmallVector function_name_vector; @@ -835,7 +846,7 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr module_group, .EmitComputation( embedded_computation, embedded_computation->name(), /*is_top_level_computation=*/false, - &schedule.sequence(embedded_computation).instructions()) + schedule.sequence(embedded_computation).instructions()) .status()); } const string& entry_point_name = options.entry_point_name(); @@ -843,7 +854,7 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr module_group, ir_emitter.EmitComputation( computation, entry_point_name, /*is_top_level_computation=*/true, - &schedule.sequence(computation).instructions())); + schedule.sequence(computation).instructions())); CHECK(entry_function->getName() == llvm_ir::AsStringRef(entry_point_name)); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_eigen_tensor_alignment_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_eigen_tensor_alignment_test.cc index 
8727c72b6e42517b1859e98ecadb41bbceed761c..485769a373acf5ae70c471b1a5dfcfb20ff772ef 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_eigen_tensor_alignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_eigen_tensor_alignment_test.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h" #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" @@ -28,37 +27,6 @@ namespace { class CpuEigenTensorAlignmentTest : public ::testing::Test {}; -TEST_F(CpuEigenTensorAlignmentTest, EigenDotAlignment) { - string hlo_string = R"( -HloModule DotOperation - -ENTRY DotOperation { - arg0 = f32[5,256] parameter(0) - arg1 = f32[256,1024] parameter(1) - ROOT dot = f32[5,1024] dot(arg0, arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0} -} -)"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseHloString(hlo_string)); - - HloInstruction* dot = module->entry_computation()->root_instruction(); - - TargetMachineFeaturesWithFakeAlignmentLogic target_machine_with_no_alignment( - [](int64 size) { return 1; }); - - EXPECT_FALSE( - PotentiallyImplementedAsEigenDot(*dot, target_machine_with_no_alignment)); - - TargetMachineFeaturesWithFakeAlignmentLogic - target_machine_with_full_alignment([](int64 size) { - return TargetMachineFeatures::kEigenExpectedTensorAlignment; - }); - - EXPECT_TRUE(PotentiallyImplementedAsEigenDot( - *dot, target_machine_with_full_alignment)); -} - TEST_F(CpuEigenTensorAlignmentTest, EigenConvAlignment) { string hlo_string = R"( HloModule ConvOperation diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc index 
818b2b0d0db2893e11fa46c7867e6c74bbbb6905..23d0af34233858515af21df5e92346742a5b5dc3 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc @@ -213,6 +213,8 @@ StatusOr CpuExecutable::CreateResultShapedBuffer( /*on_host_shape=*/result_shape(), /*on_device_shape=*/result_shape(), run_options->allocator(), stream->parent()->device_ordinal()); + const HloInputOutputAliasConfig& input_output_alias = + module().input_output_alias_config(); // Move OwningDeviceMemory values which contain the array(s) of the result // into the respective location in ScopedShapedBuffer which is returned to the @@ -232,12 +234,31 @@ StatusOr CpuExecutable::CreateResultShapedBuffer( TF_ASSIGN_OR_RETURN( const BufferAllocation::Slice slice, this->assignment_->GetUniqueSlice(src, buffer_source->index())); - CHECK(!slice.allocation()->is_entry_computation_parameter()); - const BufferAllocation::Index buffer_index = slice.index(); OwningDeviceMemory& buffer = buffers[buffer_index]; - CHECK(!buffer.is_null() || buffer.size() == 0); - *device_memory = buffer.Forget(); + if (!slice.allocation()->is_entry_computation_parameter()) { + // If the buffer coming out of the result is from a parameter, the + // owning buffer will be null, and that means the caller aliased some + // parameter buffer to an output one (via the + // HloInputOutputAliasConfig API). If that is the case, the caller + // will receive a partially complete scoped shaped buffer, which they + // will have to fill up on return. Unfortunately the interface to the + // execute APIs are ShapedBuffer pointer based, which assumes caller + // ownership, and hence a buffer coming from there cannot be part of + // the new ScopedShapedBuffer we create for the result (which assumes + // ownership). 
+ *device_memory = buffer.Forget(); + } else { + auto output_alias = input_output_alias.GetAliasedOutput( + slice.allocation()->parameter_number(), + slice.allocation()->param_shape_index()); + CHECK(output_alias) + << "Ouput buffer is coming from parameter " + << slice.allocation()->parameter_number() << " at index " + << slice.allocation()->param_shape_index() + << ", but no alias exists"; + CHECK_EQ(*output_alias, index); + } return Status::OK(); })); return std::move(result_buffer); @@ -326,7 +347,7 @@ StatusOr CpuExecutable::ExecuteAsyncOnStreamImpl( /*static*/ int64 CpuExecutable::ShapeSizeBytes(const Shape& shape) { // On the cpu, opaques are pointers. - if (ShapeUtil::IsOpaque(shape)) { + if (shape.IsOpaque()) { return sizeof(void*); } return ShapeUtil::ByteSizeOf(shape, sizeof(void*)); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc index 527df0bd1c23bba74f32226e5622fed32f7dcf84..c4bde837e57e82584c2a007858ed8d55608acd3c 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc @@ -332,7 +332,7 @@ TEST_F(OpcodeFusionTest, Exponential_Reshape_Negate) { TEST_F(OpcodeFusionTest, Broadcast_Reshape_DynamicSlice_Tanh) { HloComputation::Builder builder(TestName()); Shape param_shape = ShapeUtil::MakeShape(F32, {8}); - Shape starts_shape = ShapeUtil::MakeShape(F32, {2}); + Shape starts_shape = ShapeUtil::MakeShape(F32, {}); Shape broadcast_shape = ShapeUtil::MakeShape(F32, {1, 8, 8}); Shape reshape_shape = ShapeUtil::MakeShape(F32, {8, 8}); Shape dynamic_slice_shape = ShapeUtil::MakeShape(F32, {4, 4}); @@ -340,13 +340,15 @@ TEST_F(OpcodeFusionTest, Broadcast_Reshape_DynamicSlice_Tanh) { HloInstruction::CreateParameter(0, param_shape, "param")); HloInstruction* param1 = builder.AddInstruction( HloInstruction::CreateParameter(1, starts_shape, "starts")); + HloInstruction* 
param2 = builder.AddInstruction( + HloInstruction::CreateParameter(2, starts_shape, "starts")); HloInstruction* broadcast2 = builder.AddInstruction( HloInstruction::CreateBroadcast(broadcast_shape, param0, {1})); HloInstruction* reshape3 = builder.AddInstruction( HloInstruction::CreateReshape(reshape_shape, broadcast2)); HloInstruction* dynamic_slice4 = builder.AddInstruction(HloInstruction::CreateDynamicSlice( - dynamic_slice_shape, reshape3, param1, {4, 4})); + dynamic_slice_shape, reshape3, {param1, param2}, {4, 4})); builder.AddInstruction(HloInstruction::CreateUnary( dynamic_slice_shape, HloOpcode::kTanh, dynamic_slice4)); @@ -356,7 +358,8 @@ TEST_F(OpcodeFusionTest, Broadcast_Reshape_DynamicSlice_Tanh) { RunFusionAndCheckOpcodesWereFused( module.get(), {HloOpcode::kTanh, HloOpcode::kDynamicSlice, HloOpcode::kReshape, - HloOpcode::kBroadcast, HloOpcode::kParameter, HloOpcode::kParameter}); + HloOpcode::kBroadcast, HloOpcode::kParameter, HloOpcode::kParameter, + HloOpcode::kParameter}); } TEST_F(OpcodeFusionTest, Broadcast_Negate) { @@ -381,14 +384,14 @@ TEST_F(OpcodeFusionTest, Broadcast_Negate) { TEST_F(OpcodeFusionTest, DynamicSlice_Negate) { HloComputation::Builder builder(TestName()); Shape param_shape = ShapeUtil::MakeShape(F32, {4}); - Shape slice_shape = ShapeUtil::MakeShape(F32, {1}); + Shape slice_shape = ShapeUtil::MakeShape(F32, {}); Shape result_shape = ShapeUtil::MakeShape(F32, {2}); HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, param_shape, "param")); HloInstruction* param1 = builder.AddInstruction( HloInstruction::CreateParameter(1, slice_shape, "starts")); HloInstruction* dynamic_slice2 = builder.AddInstruction( - HloInstruction::CreateDynamicSlice(result_shape, param0, param1, {2})); + HloInstruction::CreateDynamicSlice(result_shape, param0, {param1}, {2})); builder.AddInstruction(HloInstruction::CreateUnary( result_shape, HloOpcode::kNegate, dynamic_slice2)); @@ -548,28 +551,36 @@ TEST_F(OpcodeFusionTest, 
DynamicSliceWithDynamicUpdateSlice) { Shape full_shape = ShapeUtil::MakeShape(F32, {10, 100, 1000}); Shape slice_shape = ShapeUtil::MakeShape(F32, {10, 1, 1000}); + std::vector slice_indices, update_indices; + for (int i = 0; i < 3; ++i) { + slice_indices.push_back( + builder.AddInstruction(HloInstruction::CreateParameter( + 1 + i, ShapeUtil::MakeShape(U32, {}), "slice_indices"))); + update_indices.push_back( + builder.AddInstruction(HloInstruction::CreateParameter( + 5 + i, ShapeUtil::MakeShape(U32, {}), "update_indices"))); + } HloInstruction* slice = builder.AddInstruction(HloInstruction::CreateDynamicSlice( slice_shape, builder.AddInstruction( HloInstruction::CreateParameter(0, full_shape, "slice_from")), - builder.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(U32, {3}), "slice_indices")), + slice_indices, /*slice_sizes=*/{10, 1, 1000})); builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( full_shape, builder.AddInstruction( - HloInstruction::CreateParameter(2, full_shape, "to_update")), - slice, - builder.AddInstruction(HloInstruction::CreateParameter( - 3, ShapeUtil::MakeShape(U32, {3}), "update_indices")))); + HloInstruction::CreateParameter(4, full_shape, "to_update")), + slice, update_indices)); module->AddEntryComputation(builder.Build()); RunFusionAndCheckOpcodesWereFused( - module.get(), {HloOpcode::kDynamicSlice, HloOpcode::kDynamicUpdateSlice, - HloOpcode::kParameter, HloOpcode::kParameter, - HloOpcode::kParameter, HloOpcode::kParameter}); + module.get(), + {HloOpcode::kDynamicSlice, HloOpcode::kDynamicUpdateSlice, + HloOpcode::kParameter, HloOpcode::kParameter, HloOpcode::kParameter, + HloOpcode::kParameter, HloOpcode::kParameter, HloOpcode::kParameter, + HloOpcode::kParameter, HloOpcode::kParameter}); } TEST_F(OpcodeFusionTest, MessOfFusibleNodes) { @@ -578,49 +589,40 @@ TEST_F(OpcodeFusionTest, MessOfFusibleNodes) { Shape full_shape = ShapeUtil::MakeShape(F32, {4, 100, 10, 100, 50}); - auto loop_idx = 
builder.AddInstruction(HloInstruction::CreateReshape( - ShapeUtil::MakeShape(S32, {1}), - builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(S32, {}), "param0")))); - + auto loop_idx = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(S32, {}), "param0")); auto param1 = builder.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(S32, {1}), "param1")); - auto concat = builder.AddInstruction(HloInstruction::CreateConcatenate( - ShapeUtil::MakeShape(S32, {5}), - {loop_idx, param1, param1, param1, param1}, /*dimension=*/0)); + 1, ShapeUtil::MakeShape(S32, {}), "param1")); - auto idx_choice = builder.AddInstruction(HloInstruction::CreateDynamicSlice( - ShapeUtil::MakeShape(S32, {1}), - builder.AddInstruction(HloInstruction::CreateParameter( - 2, ShapeUtil::MakeShape(S32, {4}), "param2")), - loop_idx, - /*slice_sizes=*/{1})); - - PaddingConfig padding_config; - padding_config.add_dimensions()->set_edge_padding_high(4); - auto pad = builder.AddInstruction(HloInstruction::CreatePad( - ShapeUtil::MakeShape(S32, {5}), idx_choice, - builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))), - padding_config)); + auto idx_choice = builder.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(S32, {}), + builder.AddInstruction(HloInstruction::CreateDynamicSlice( + ShapeUtil::MakeShape(S32, {1}), + builder.AddInstruction(HloInstruction::CreateParameter( + 2, ShapeUtil::MakeShape(S32, {4}), "param2")), + {loop_idx}, + /*slice_sizes=*/{1})))); + auto zero = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); auto slice = builder.AddInstruction(HloInstruction::CreateDynamicSlice( ShapeUtil::MakeShape(F32, {1, 100, 10, 100, 50}), builder.AddInstruction(HloInstruction::CreateParameter( 3, ShapeUtil::MakeShape(F32, {100, 100, 10, 100, 50}), "param3")), - pad, /*slice_sizes=*/{1, 100, 10, 100, 50})); + {idx_choice, zero, zero, 
zero, zero}, + /*slice_sizes=*/{1, 100, 10, 100, 50})); builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( full_shape, builder.AddInstruction( HloInstruction::CreateParameter(4, full_shape, "param4")), - slice, concat)); + slice, {loop_idx, param1, param1, param1, param1})); module->AddEntryComputation(builder.Build()); RunFusionAndCheckOpcodesWereFused( module.get(), - {HloOpcode::kConcatenate, HloOpcode::kPad, HloOpcode::kDynamicSlice, - HloOpcode::kDynamicSlice, HloOpcode::kDynamicUpdateSlice, + {HloOpcode::kDynamicSlice, HloOpcode::kDynamicSlice, + HloOpcode::kDynamicUpdateSlice, HloOpcode::kReshape, HloOpcode::kParameter, HloOpcode::kParameter, HloOpcode::kParameter, HloOpcode::kParameter, HloOpcode::kParameter, HloOpcode::kParameter}); } @@ -930,9 +932,10 @@ ENTRY main { return result; } -INSTANTIATE_TEST_CASE_P(GatherLoopFusionTestInstantiation, GatherLoopFusionTest, - ::testing::ValuesIn(GetGatherLoopFusionTestSpecs()), - GatherLoopFusionTestSpec::Name); +INSTANTIATE_TEST_SUITE_P(GatherLoopFusionTestInstantiation, + GatherLoopFusionTest, + ::testing::ValuesIn(GetGatherLoopFusionTestSpecs()), + GatherLoopFusionTestSpec::Name); } // namespace } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc index c291bf2d1ba2eaff4192051840768c037bece86f..95b8025f873c56bea063ff258d4abd6614257d85 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc @@ -46,8 +46,7 @@ static bool ShouldMakeAllUsersColMajor(const HloInstruction* instruction) { for (auto* user : instruction->users()) { optional operand_idx = ProfitableToMakeDotOperandColumnMajor(*user); if (!operand_idx || user->operand(*operand_idx) != instruction || - std::count(user->operands().begin(), user->operands().end(), - instruction) != 1) { + absl::c_count(user->operands(), instruction) != 1) { 
return false; } } @@ -94,60 +93,38 @@ static Shape ColMajorShape(const Shape& old_shape) { return new_shape; } +static bool OperandsAndResultMustHaveRowMajorLayout( + const HloInstruction& instr, + const TargetMachineFeatures& target_machine_features) { + if (instr.opcode() == HloOpcode::kConvolution) { + return PotentiallyImplementedAsEigenConvolution(instr, + target_machine_features); + } else if (instr.opcode() == HloOpcode::kDot) { + return DotOperandsAndResultMustHaveRowMajorLayout(instr, + target_machine_features); + } + return false; +} + Status CpuLayoutAssignment::AddBackendConstraints( LayoutConstraints* constraints) { ShouldMakeOperandColMajorCache cache; const HloComputation* computation = constraints->computation(); for (auto* instruction : computation->instructions()) { - if (instruction->opcode() == HloOpcode::kConvolution && - PotentiallyImplementedAsEigenConvolution(*instruction, - target_machine_features_)) { - const HloInstruction* convolution = instruction; - const HloInstruction* lhs_instruction = convolution->operand(0); - const HloInstruction* rhs_instruction = convolution->operand(1); - - // In order to implement `convolution` with Eigen convolution, the layouts - // of the input, filter, and output need to be row-major. - // - // These constraints are not hard constraints. Ideally, we should decide - // which layouts to choose according to some cost model. - Shape output_shape(RowMajorShape(convolution->shape())); - Shape input_shape(RowMajorShape(lhs_instruction->shape())); - Shape filter_shape(RowMajorShape(rhs_instruction->shape())); - - // Set layouts of the instructions' shapes. 
- TF_RETURN_IF_ERROR( - constraints->SetOperandLayout(input_shape, convolution, 0)); - TF_RETURN_IF_ERROR( - constraints->SetOperandLayout(filter_shape, convolution, 1)); - TF_RETURN_IF_ERROR( - constraints->SetInstructionLayout(output_shape, convolution)); + if (OperandsAndResultMustHaveRowMajorLayout(*instruction, + target_machine_features_)) { + TF_RETURN_IF_ERROR(constraints->SetInstructionLayout( + RowMajorShape(instruction->shape()), instruction)); + for (int i = 0; i < instruction->operand_count(); i++) { + TF_RETURN_IF_ERROR(constraints->SetOperandLayout( + RowMajorShape(instruction->operand(i)->shape()), instruction, i)); + } } else if (optional op_idx = ShouldMakeOperandColumnMajor(&cache, *instruction)) { const HloInstruction* op = instruction->operand(*op_idx); TF_RETURN_IF_ERROR(constraints->SetOperandLayout( ColMajorShape(op->shape()), instruction, *op_idx)); - } else if (PotentiallyImplementedAsEigenDot(*instruction, - target_machine_features_)) { - const HloInstruction* dot = instruction; - // In order to implement `dot` with Eigen dot, the layouts of the lhs, - // rhs, and output need to be row-major. - // - // These constraints are not hard constraints. Ideally, we should decide - // which layouts to choose according to some cost model. - Shape output_shape(RowMajorShape(dot->shape())); - - const HloInstruction* lhs_instruction = dot->operand(0); - Shape lhs_shape(RowMajorShape(lhs_instruction->shape())); - TF_RETURN_IF_ERROR(constraints->SetOperandLayout(lhs_shape, dot, 0)); - - const HloInstruction* rhs_instruction = dot->operand(1); - Shape rhs_shape(RowMajorShape(rhs_instruction->shape())); - TF_RETURN_IF_ERROR(constraints->SetOperandLayout(rhs_shape, dot, 1)); - - // Set layouts of the instructions' shapes. 
- TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(output_shape, dot)); } else { for (int64 operand_no = 0; operand_no < instruction->operand_count(); ++operand_no) { @@ -160,7 +137,7 @@ Status CpuLayoutAssignment::AddBackendConstraints( continue; } // Skip operands with non-array shapes. - if (!ShapeUtil::IsArray(instruction->operand(operand_no)->shape())) { + if (!instruction->operand(operand_no)->shape().IsArray()) { continue; } Shape operand_shape( @@ -175,7 +152,7 @@ Status CpuLayoutAssignment::AddBackendConstraints( } // Skip instructions which don't produce array shapes (tuples, opaque, // etc.). - if (!ShapeUtil::IsArray(instruction->shape())) { + if (!instruction->shape().IsArray()) { continue; } } diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.cc b/tensorflow/compiler/xla/service/cpu/cpu_options.cc index 92debb83e33b1400a59e5eef0f90971392ab7b22..ff654c83d61e7cc09ac7839feccaf2bc9cb3c63c 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_options.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_options.cc @@ -23,8 +23,8 @@ namespace { const char* const kXlaOptimizeForSizeCpuOption = "xla_cpu_optimize_for_size"; const char* const kLlvmIrDotTilingFactor = "xla_llvm_dot_tiling_factor"; -const char* const kXlaEnableExperimentalLlvmIrGemm = - "xla_enable_experimental_llvm_ir_gemm"; +const char* const kXlaForceEnableExperimentalLlvmIrGemm = + "xla_force_enable_experimental_llvm_ir_gemm"; const char* const kLlvmIrGemmTileSize = "xla_llvm_ir_gemm_tile_size"; } // namespace @@ -57,10 +57,10 @@ absl::optional LlvmIrGemvTilingFactor(const HloModuleConfig& config) { return absl::nullopt; } -bool EnableExperimentalLlvmIrGemm(const HloModuleConfig& config) { +bool ForceEnableExperimentalLlvmIrGemm(const HloModuleConfig& config) { const auto& extra_options_map = config.debug_options().xla_backend_extra_options(); - return extra_options_map.count(kXlaEnableExperimentalLlvmIrGemm) > 0; + return 
extra_options_map.count(kXlaForceEnableExperimentalLlvmIrGemm) > 0; } static absl::string_view RemoveSuffix(absl::string_view str, diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.h b/tensorflow/compiler/xla/service/cpu/cpu_options.h index 47c7eb13b6e4cc05a23f82b8d2a25249f4b82ac0..99e6702d14aed8ffb148adec2bdd02dbc7c3c7e3 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_options.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_options.h @@ -26,7 +26,7 @@ namespace options { bool OptimizeForSizeRequested(const HloModuleConfig& config); bool VectorizedReduceDisabled(const HloModuleConfig& config); -bool EnableExperimentalLlvmIrGemm(const HloModuleConfig& config); +bool ForceEnableExperimentalLlvmIrGemm(const HloModuleConfig& config); absl::optional LlvmIrGemvTilingFactor(const HloModuleConfig& config); absl::optional> LlvmIrGemmTileSize( const HloModuleConfig& config); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc index a9febe891b5e9d1eb9e6b297952b50d1d26a3396..d8878e622c0500fc5328aa6c295a9e24a3a037f7 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc @@ -84,31 +84,8 @@ extern const char* const kReleaseOutfeedBufferAfterPopulationSymbolName = "__xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation"; extern const char* const kParallelForkJoinSymbolName = "__xla_cpu_runtime_ParallelForkJoin"; -extern const char* const kKeyValueSortPREDSymbolName = - "__xla_cpu_runtime_KeyValueSortPRED"; -extern const char* const kKeyValueSortS8SymbolName = - "__xla_cpu_runtime_KeyValueSortS8"; -extern const char* const kKeyValueSortU8SymbolName = - "__xla_cpu_runtime_KeyValueSortU8"; -extern const char* const kKeyValueSortS16SymbolName = - "__xla_cpu_runtime_KeyValueSortS16"; -extern const char* const kKeyValueSortU16SymbolName = - "__xla_cpu_runtime_KeyValueSortU16"; -extern const char* const kKeyValueSortF16SymbolName = - 
"__xla_cpu_runtime_KeyValueSortF16"; -extern const char* const kKeyValueSortS32SymbolName = - "__xla_cpu_runtime_KeyValueSortS32"; -extern const char* const kKeyValueSortU32SymbolName = - "__xla_cpu_runtime_KeyValueSortU32"; -extern const char* const kKeyValueSortF32SymbolName = - "__xla_cpu_runtime_KeyValueSortF32"; -extern const char* const kKeyValueSortS64SymbolName = - "__xla_cpu_runtime_KeyValueSortS64"; -extern const char* const kKeyValueSortU64SymbolName = - "__xla_cpu_runtime_KeyValueSortU64"; -extern const char* const kKeyValueSortF64SymbolName = - "__xla_cpu_runtime_KeyValueSortF64"; - +extern const char* const kKeyValueSortSymbolName = + "__xla_cpu_runtime_KeyValueSort"; extern const char* const kXlaCpuRuntimeSymbolNamePrefix = "__xla_cpu_runtime_"; } // namespace runtime } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h index b2e760a224ad8eaa61dae57b0f9cece04a7e54ae..3a2b44d8c1a80128d3577c374e751e73a89e9d59 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h @@ -64,18 +64,7 @@ extern const char* const kReleaseInfeedBufferAfterDequeueSymbolName; extern const char* const kAcquireOutfeedBufferForPopulationSymbolName; extern const char* const kReleaseOutfeedBufferAfterPopulationSymbolName; extern const char* const kParallelForkJoinSymbolName; -extern const char* const kKeyValueSortPREDSymbolName; -extern const char* const kKeyValueSortS8SymbolName; -extern const char* const kKeyValueSortU8SymbolName; -extern const char* const kKeyValueSortS16SymbolName; -extern const char* const kKeyValueSortU16SymbolName; -extern const char* const kKeyValueSortF16SymbolName; -extern const char* const kKeyValueSortS32SymbolName; -extern const char* const kKeyValueSortU32SymbolName; -extern const char* const kKeyValueSortF32SymbolName; -extern const char* const kKeyValueSortS64SymbolName; -extern const char* const 
kKeyValueSortU64SymbolName; -extern const char* const kKeyValueSortF64SymbolName; +extern const char* const kKeyValueSortSymbolName; // All symbol names for XLA CPU runtime functions need to start with this // prefix. diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc index 1ae3aa57111e3a3b7ac18b4907c5c282edf89b7e..4e8c98678309fa4d573f1aac1290c9afc87643a4 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc @@ -162,11 +162,12 @@ TEST_P(EigenMatMulTest, DoIt) { CheckMatrixMultiply(*a, *b, *c); } -INSTANTIATE_TEST_CASE_P(EigenMatMulTestInstantiaion, EigenMatMulTest, - ::testing::Combine(::testing::ValuesIn(MatMulShapes), - ::testing::Bool(), ::testing::Bool(), - ::testing::Bool()), - EigenMatMulTest::Name); +INSTANTIATE_TEST_SUITE_P(EigenMatMulTestInstantiaion, EigenMatMulTest, + ::testing::Combine(::testing::ValuesIn(MatMulShapes), + ::testing::Bool(), + ::testing::Bool(), + ::testing::Bool()), + EigenMatMulTest::Name); #ifdef INTEL_MKL class MKLMatMulTest : public CpuRuntimeTest, diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc index 1457582ac19c27e5c3150b4667e6af505345a6bd..3361a5973f5e8c91802b26d68477347b196d3cac 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc @@ -97,7 +97,7 @@ Status CpuTransferManager::TransferLiteralToInfeed( VLOG(2) << "Transferring literal to infeed with shape: " << ShapeUtil::HumanString(shape); - if (!ShapeUtil::IsTuple(shape)) { + if (!shape.IsTuple()) { int64 size = GetByteSizeRequirement(shape); return TransferBufferToInfeed(executor, size, literal.untyped_data()); } @@ -178,7 +178,7 @@ CpuTransferManager::TransferBufferToInfeedInternal(se::StreamExecutor* executor, Status 
CpuTransferManager::TransferLiteralFromOutfeed( se::StreamExecutor* executor, const Shape& literal_shape, MutableBorrowingLiteral literal) { - if (!ShapeUtil::IsTuple(literal_shape)) { + if (!literal_shape.IsTuple()) { int64 size = GetByteSizeRequirement(literal_shape); // Note: OSS build didn't like implicit conversion from // literal_shape.dimensions() to the array slice on 2017-07-10. diff --git a/tensorflow/compiler/xla/service/cpu/disassembler.cc b/tensorflow/compiler/xla/service/cpu/disassembler.cc index 3ae64142cd7e32d3aa8d50870efaf94698c06440..c3c6847b7b77e2fb0470630815de9f5d7a6c5b9c 100644 --- a/tensorflow/compiler/xla/service/cpu/disassembler.cc +++ b/tensorflow/compiler/xla/service/cpu/disassembler.cc @@ -77,17 +77,16 @@ StatusOr Disassembler::DisassembleObjectFile( } // Sort the symbols in increasing address order. - std::sort( - symbols.begin(), symbols.end(), - [](const llvm::object::SymbolRef& a, const llvm::object::SymbolRef& b) { - // getAddress returns a Expected object. Assert there is no error - // before extracting the address. - llvm::Expected a_address_or_error = a.getAddress(); - CHECK(a_address_or_error); - llvm::Expected b_address_or_error = b.getAddress(); - CHECK(b_address_or_error); - return a_address_or_error.get() < b_address_or_error.get(); - }); + absl::c_sort(symbols, [](const llvm::object::SymbolRef& a, + const llvm::object::SymbolRef& b) { + // getAddress returns a Expected object. Assert there is no error + // before extracting the address. + llvm::Expected a_address_or_error = a.getAddress(); + CHECK(a_address_or_error); + llvm::Expected b_address_or_error = b.getAddress(); + CHECK(b_address_or_error); + return a_address_or_error.get() < b_address_or_error.get(); + }); // Construct ArrayRef pointing to section contents. 
llvm::StringRef section_content_string; diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index 97f9b85a606e140fd7f3b1e3ecfb0dd5ba289f03..0fecbaf391bc3122646af30b508fc1a88b6641e9 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -26,7 +26,10 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" +#include "tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.h" #include "tensorflow/compiler/xla/service/cpu/vector_support_library.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" @@ -41,932 +44,165 @@ namespace xla { using llvm_ir::SetToFirstInsertPoint; namespace cpu { - namespace { -// Provides tiled access to an in-memory rank 2 array. -class MemoryTile { - public: - // Constructs a MemoryTile that can operate on tiles consisting of - // `tile_size_along_major_dim` vectors from the matrix `matrix`, starting at - // `major_dim_offset` in the major dimension. The tile size along the minor - // dimension is the vector size, and that is implicitly determined by `vsl`. 
- MemoryTile(VectorSupportLibrary* vsl, llvm::IRBuilder<>* b, - llvm::Value* matrix, int64 matrix_size_along_minor_dim, - llvm::Value* major_dim_offset, int64 tile_size_along_major_dim) - : vsl_(vsl), b_(b) { - pointers_.reserve(tile_size_along_major_dim); - for (int64 i = 0; i < tile_size_along_major_dim; i++) { - llvm::Value* total_offset = - b->CreateMul(b->getInt64(matrix_size_along_minor_dim), - b->CreateAdd(b->getInt64(i), major_dim_offset)); - pointers_.push_back(vsl_->ComputeOffsetPointer(matrix, total_offset)); - } - } - - // Load a tile consisting of `tile_size_along_major_dim` vectors from position - // {major: `major_dim_offset`, minor: `minor_dim_offset`}. - // - // Note: `major_dim_offset` is a parameter to the constructor. - std::vector LoadTile(llvm::Value* minor_dim_offset) const { - std::vector result; - result.reserve(pointers_.size()); - for (const auto& pointer : pointers_) { - result.push_back(vsl_->LoadVector(pointer, minor_dim_offset)); - } - return result; - } - - // Stores `tile` to position {major: `major_dim_offset`, minor: - // `minor_dim_offset`}. - // - // Note: `major_dim_offset` is a parameter to the constructor. - void StoreTile(absl::Span tile, - llvm::Value* minor_dim_offset) const { - CHECK_EQ(tile.size(), pointers_.size()); - for (int64 i = 0; i < pointers_.size(); i++) { - vsl_->StoreVector(tile[i], pointers_[i], minor_dim_offset); - } - } - - // Loads a tile of size [`tile_size_along_major_dim`, - // `tile_size_along_middle_dim`] from position {major: `major_dim_offset`, - // minor: `minor_dim_offset`} and then broadcasts each element into a vector - // of size vsl_.vector_size(). The (i,j)'th element of the return value is - // the (i,j)'th element in the tile broadcasted into an LLVM vector. - // - // Note: `major_dim_offset` is a parameter to the constructor. 
- std::vector> LoadBroadcastTile( - llvm::Value* minor_dim_offset, int64 tile_size_along_middle_dim) const { - std::vector> result; - result.resize(pointers_.size()); - for (int64 i = 0; i < pointers_.size(); i++) { - for (int64 j = 0; j < tile_size_along_middle_dim; j++) { - result[i].push_back(vsl_->LoadBroadcast( - pointers_[i], b_->CreateAdd(minor_dim_offset, b_->getInt64(j)))); - } - } - return result; - } - - private: - VectorSupportLibrary* vsl_; - llvm::IRBuilder<>* b_; - std::vector pointers_; -}; - -// The base class for the classes representing the GEMV emitter configurations. -// -// The IR emitted (modulo the LLVM values representing the input and output -// buffers) by the row major and column major GEMV emitters should be a function -// of their configuration. This is important because their configuration is -// used as a key to cache the generated IR. -class GemvConfig { - public: - // Mixin for convenience. - template - struct User { - public: - PrimitiveType scalar_type() const { - return derived().config().scalar_type(); - } - int64 tile_rows() const { return derived().config().tile_rows(); } - int64 tile_cols() const { return derived().config().tile_cols(); } - int64 m() const { return derived().config().m(); } - int64 k() const { return derived().config().k(); } - int64 has_addend() const { return derived().config().has_addend(); } - - private: - const T& derived() const { return *static_cast(this); } - }; +// Returns true if we should call into multi-threaded Eigen routines. 
+bool ShouldUseMultiThreadedEigen(const HloModuleConfig& config) { + return config.debug_options().xla_cpu_multi_thread_eigen(); +} - PrimitiveType scalar_type() const { return scalar_type_; } - int64 tile_rows() const { return tile_rows_; } - int64 tile_cols() const { return tile_cols_; } - int64 m() const { return m_; } - int64 k() const { return k_; } - bool has_addend() const { return has_addend_; } - - string GetCacheKey() const { - return absl::StrCat(name_, "_", PrimitiveType_Name(scalar_type()), "_", - tile_rows(), "_", tile_cols(), "_", m(), "_", k(), - has_addend() ? "_with_addend" : ""); +// Represents a dot operation. We use this in lieu of an `HloInstruction` +// because we want to be able to create this for the "inner" dot operation in a +// batch dot, for which there is no separate HLO instruction. +struct DotInfo { + Shape lhs_shape; + Shape rhs_shape; + Shape result_shape; + DotDimensionNumbers dim_nums; + + DotInfo() = default; + + explicit DotInfo(const HloInstruction& instr) { + CHECK_EQ(instr.opcode(), HloOpcode::kDot); + lhs_shape = instr.operand(0)->shape(); + rhs_shape = instr.operand(1)->shape(); + result_shape = instr.shape(); + dim_nums = instr.dot_dimension_numbers(); } - - protected: - explicit GemvConfig(string name, PrimitiveType scalar_type, int64 tile_rows, - int64 tile_cols, int64 m, int64 k, bool has_addend) - : name_(std::move(name)), - scalar_type_(scalar_type), - tile_rows_(tile_rows), - tile_cols_(tile_cols), - m_(m), - k_(k), - has_addend_(has_addend) {} - - private: - string name_; - PrimitiveType scalar_type_; - int64 tile_rows_; - int64 tile_cols_; - int64 m_; - int64 k_; - bool has_addend_; }; -// Computes a dot product between "[M,K]{0,1} lhs" with a [K,1] vector (the -// layout of the vector does not matter). This implementation uses a tiling -// scheme to improve performance. 
-// -// We logically separate the LHS matrix into four segments: -// -// +----------------------+---+ -// | | | -// | | | -// | A | B | -// | | | -// | | | -// | | | -// +----------------------+---+ -// | C | D | -// +----------------------+---+ -// -// where A is the largest submatrix of the LHS that can be evenly dividied into -// tiles. For each tile in A, assuming tile_rows_ == tile_cols_ == 4, we have: -// -// +---+---+---+---+ +--+--+--+--+ -// |M00|M10|M20|M30| |V0|V1|V2|V3| -// +---+---+---+---+ +--+--+--+--+ -// |M01|M11|M21|M31| and |V0|V1|V2|V3| -// +---+---+---+---+ +--+--+--+--+ -// |M02|M12|M22|M32| |V0|V1|V2|V3| -// +---+---+---+---+ +--+--+--+--+ -// |M03|M13|M23|M33| |V0|V1|V2|V3| -// +---+---+---+---+ +--+--+--+--+ -// -// (Legend: rows are horizontal and columns are vertical; and each column is one -// llvm::Value of a vector type) -// -// where: -// -// a. The left tile is from the column major left matrix. -// b. The right tile is an elementwise broadcast of a [V0, V1, V2, V3] -// vector loaded from the RHS vector. -// -// As we iterate through the column dimension, we compute the change to the -// result vector by an elementwise multiplication between the two tiles above -// followed by a reduction along the major dimension: -// -// +-----------------------------------+ -// | M00*V0 + M10*V1 + M20*V2 + M30*V3 | -// +-----------------------------------+ -// | M01*V0 + M11*V1 + M21*V2 + M31*V3 | -// Result[R:R+4] += +-----------------------------------+ -// | M02*V0 + M12*V1 + M22*V2 + M32*V3 | -// +-----------------------------------+ -// | M03*V0 + M13*V1 + M23*V2 + M33*V3 | -// +-----------------------------------+ -// -// Where R is the starting row for the tile. -// -// We have an inner epilogue loop to deal with the "C" submatrix and an outer -// epilogue loop to deal with the B,D submarix. 
-// -// TODO(sanjoy): We should investigate if using gather loads and scatter stores -// can be used here have the same inner loop for both column-major and row-major -// matrix-vector products. -class ColumnMajorMatrixVectorProductEmitter - : public GemvConfig::User { - public: - class Config : public GemvConfig { - public: - explicit Config(PrimitiveType scalar_type, int64 tile_rows, int64 tile_cols, - int64 m, int64 k, bool has_addend) - : GemvConfig(/*name=*/"col_major_gemv", scalar_type, - /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols, /*m=*/m, - /*k=*/k, /*has_addend=*/has_addend) {} - }; - - ColumnMajorMatrixVectorProductEmitter(const Config& config, llvm::Value* lhs, - llvm::Value* rhs, llvm::Value* addend, - llvm::Value* result, - llvm::IRBuilder<>* b) - : config_(config), - lhs_(lhs), - rhs_(rhs), - addend_(addend), - result_(result), - b_(b), - ksl_(b_), - vsl_(config.scalar_type(), /*vector_size=*/config.tile_rows(), b_, "") { - CHECK(tile_rows() > 0 && IsPowerOfTwo(static_cast(tile_rows()))); - CHECK(!has_addend() || addend != nullptr); - } - - void Emit(); - - const Config& config() const { return config_; } - - private: - void EmitOuterLoopBody(llvm::Value* column, int64 column_count, - bool is_first_column); - - MemoryTile GetLhsMemoryTile(llvm::Value* column_start, int64 column_count) { - return MemoryTile(&vsl_, b_, /*matrix=*/lhs_, - /*matrix_size_along_minor_dim=*/m(), - /*major_dim_offset=*/column_start, - /*tile_size_along_major_dim=*/column_count); - } - - // Load a tile of values from the RHS. For the RHS a "tile" is a contiguous - // sequence of `count` values, each one broadcasted to the vector width. 
- std::vector LoadRhsTile(llvm::Value* offset, int64 count) { - llvm::Value* base_pointer = vsl_.ComputeOffsetPointer(rhs_, offset); - std::vector result; - result.reserve(count); - for (int64 i = 0; i < count; i++) { - result.push_back(vsl_.LoadBroadcast(base_pointer, i)); - } - return result; - } - - void EmitInnerLoopTiled(MemoryTile* lhs_memory_tile, - const std::vector& rhs_tile, - int64 columns, bool is_first_column); - - void EmitInnerLoopEpilogue(llvm::Value* current_tile_col, int64 columns, - bool is_first_tiled_column); - - Config config_; - llvm::Value* lhs_; - llvm::Value* rhs_; - llvm::Value* addend_; - llvm::Value* result_; - llvm::IRBuilder<>* b_; - KernelSupportLibrary ksl_; - VectorSupportLibrary vsl_; +// Dictates how a dot operation is implemented. +enum class DotImplementationStrategy { + // The dot operation is lowered into LLVM IR that implements a naive nested + // loop that computes the result one element at a time. This is our + // "fallback"; we don't really want this to kick in for any non-trivial dot + // operation. + kNaiveLlvmIr, + + // The dot operation is lowered into LLVM IR that implements a tiled + // Matrix*Vector operation. This strategy also allows fusing in a bias add + // into the dot. The matrix can be row major or column major, both are + // supported. + kTiledLlvmIrGemv, + + // The dot operation is lowered into LLVM IR that implements a tiled + // Matrix*Matrix operation. No fusions are supported. The two inputs + // and the output have to be row major. + kTiledLlvmIrGemm, + + // The dot operation is lowered into a call into an Eigen routine. No fusions + // are supported today. The two inputs and the output have to be row major. + // However, we do allow transposing either the LHS or the RHS as part of the + // GEMM -- we expose this flexibility as flexibility in the contraction + // dimensions, but we can also see this as flexibility in the input layouts.
+ kEigen, }; -void ColumnMajorMatrixVectorProductEmitter::EmitOuterLoopBody( - llvm::Value* column, int64 column_count, bool is_first_column) { - MemoryTile lhs_memory_tile = GetLhsMemoryTile(/*column_start=*/column, - /*column_count=*/column_count); - - std::vector rhs_tile = - LoadRhsTile(column, /*count=*/column_count); - EmitInnerLoopTiled(&lhs_memory_tile, rhs_tile, - /*columns=*/column_count, is_first_column); - EmitInnerLoopEpilogue(column, /*columns=*/column_count, is_first_column); -} - -void ColumnMajorMatrixVectorProductEmitter::Emit() { - // See the comment on the class declaration for the algorithm used here. - int64 column_remainder = k() % tile_cols(); - int64 column_limit = k() - column_remainder; - - ksl_.ForReturnVoid("dot.outer.tiled", - /*start=*/0, /*end=*/column_limit, /*step=*/tile_cols(), - [&](llvm::Value* column, bool is_first_column) { - EmitOuterLoopBody(column, tile_cols(), is_first_column); - }); - - if (column_remainder != 0) { - EmitOuterLoopBody(b_->getInt64(column_limit), column_remainder, - column_limit == 0); - } -} - -void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopTiled( - MemoryTile* lhs_memory_tile, const std::vector& rhs_tile, - int64 columns, bool is_first_column) { - int64 row_limit = m() - (m() % tile_rows()); - - ksl_.ForReturnVoid( - "dot.inner.tiled", /*start=*/0, /*end=*/row_limit, - /*step=*/tile_rows(), [&](llvm::Value* row) { - std::vector lhs_tile = - lhs_memory_tile->LoadTile(/*minor_dim_offset=*/row); - llvm::Value* accumulator = - is_first_column ? (addend_ ? 
vsl_.LoadVector(addend_, row) - : vsl_.GetZeroVector()) - : vsl_.LoadVector(result_, row); - for (int i = 0; i < columns; i++) { - accumulator = vsl_.MulAdd(lhs_tile[i], rhs_tile[i], accumulator); - } - vsl_.StoreVector(accumulator, result_, row); - }); -} - -void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( - llvm::Value* current_tile_col, int64 columns, bool is_first_tiled_column) { - int64 row_start = m() - (m() % tile_rows()); - if (row_start == m()) { - return; - } - - llvm::Value* columns_llvm = b_->getInt64(columns); - - // for (col = current_tile_col; col < (columns + current_tile_col); col++) - // for (row = row_start, row < m_; row++) { - // result[row] += lhs[row, col] * rhs[col] - // // Also take into account that if col is 0 then result[row] is not - // // initialized. - // } - - ksl_.ForReturnVoid( - "dot.inner.epilg.outer", /*start=*/current_tile_col, - /*end=*/b_->CreateAdd(columns_llvm, current_tile_col), - /*step=*/1, /*peel_first_iteration=*/false, - [&](llvm::Value* col, llvm::Value* is_first_scalar_col) { - llvm::Value* rhs_element = vsl_.LoadScalar(rhs_, col); - llvm::Value* total_offset = b_->CreateMul(col, b_->getInt64(m())); - llvm::Value* lhs_base_pointer = - vsl_.ComputeOffsetPointer(lhs_, total_offset); - ksl_.ForReturnVoid( - "dot.inner.epilg.inner", /*start=*/row_start, /*end=*/m(), - /*step=*/1, [&](llvm::Value* scalar_row) { - llvm::Value* product = vsl_.Mul( - vsl_.LoadScalar(lhs_base_pointer, scalar_row), rhs_element); - llvm::Value* setting_result_first_time = b_->CreateAnd( - is_first_scalar_col, b_->getInt1(is_first_tiled_column)); - ksl_.IfReturnVoid( - setting_result_first_time, - /*true_block_generator=*/ - [&]() { - if (addend_) { - vsl_.StoreScalar( - vsl_.Add(vsl_.LoadScalar(addend_, scalar_row), - product), - result_, scalar_row); - } else { - vsl_.StoreScalar(product, result_, scalar_row); - } - }, - /*false_block_generator=*/ - [&]() { - vsl_.StoreScalar( - vsl_.Add(vsl_.LoadScalar(result_, scalar_row), 
product), - result_, scalar_row); - }); - }); - }); -} +// Returns the implementation strategy for a dot with the configuration +// `dot_info`. +DotImplementationStrategy GetDotImplementationStrategy( + const HloModuleConfig& config, const DotInfo& dot_info, + const TargetMachineFeatures& target_machine_features); -// Computes a dot product between "[M,K]{1,0} lhs" with a [K,1] vector (the -// layout of the vector does not matter). This implementation uses a tiling -// scheme to improve performance. -// -// We logically separate the LHS matrix into four segments: -// -// +----------------------+---+ -// | | | -// | | | -// | A | B | -// | | | -// | | | -// | | | -// +----------------------+---+ -// | C | D | -// +----------------------+---+ -// -// where A is the largest submatrix of the LHS that can be evenly dividied into -// tiles. For each tile in A, assuming tile_rows_ == tile_cols_ == 4, we have: -// -// +---+---+---+---+ -// |M00|M10|M20|M30| -// +---+---+---+---+ +--+--+--+--+ -// |M01|M11|M21|M31| and |V0|V1|V2|V3| -// +---+---+---+---+ +--+--+--+--+ -// |M02|M12|M22|M32| -// +---+---+---+---+ -// |M03|M13|M23|M33| -// +---+---+---+---+ -// -// (Legend: rows are horizontal and columns are vertical; and each row is one -// llvm::Value of a vector type) -// -// where: -// -// a. The left tile is loaded from the row major left matrix. -// b. The right vector is loaded from the RHS vector. -// -// We keep 4 vector accumulators accumulating the following four vector -// expressions as we iterate over the row dimension: -// -// +------+------+------+------+ -// |M0I*V0|M1I*V1|M2I*V2|M3I*V3| for I in [0,4) -// +------+------+------+------+ -// -// In the end we do a horizontal reduction over these 4 vector accumulators to -// get 4 values in the result vector. -// -// We have an inner epilogue loop to deal with the "B" sub-matrix and an outer -// epilogue loop to deal with the C,D submatrix. 
-class RowMajorMatrixVectorProductEmitter - : public GemvConfig::User { +// Helper class for emitting LLVM IR to perform the dot operation. +class DotOpEmitter { public: - class Config : public GemvConfig { - public: - explicit Config(PrimitiveType scalar_type, int64 tile_rows, int64 tile_cols, - int64 m, int64 k, bool has_addend) - : GemvConfig(/*name=*/"row_major_gemv", scalar_type, - /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols, /*m=*/m, - /*k=*/k, /*has_addend=*/has_addend) {} - }; - - RowMajorMatrixVectorProductEmitter(const Config& config, llvm::Value* lhs, - llvm::Value* rhs, llvm::Value* addend, - llvm::Value* result, llvm::IRBuilder<>* b) - : config_(config), - lhs_(lhs), - rhs_(rhs), - addend_(addend), - result_(result), - b_(b), - ksl_(b_), - vsl_(scalar_type(), /*vector_size=*/tile_cols(), b_, "") { - CHECK(tile_cols() > 0 && IsPowerOfTwo(static_cast(tile_cols()))); - CHECK(!has_addend() || addend != nullptr); - } - - void Emit(); - - const Config& config() const { return config_; } + explicit DotOpEmitter(DotInfo dot_info, string dot_hlo_name, + const llvm_ir::IrArray& target_array, + const llvm_ir::IrArray& lhs_array, + const llvm_ir::IrArray& rhs_array, + const llvm_ir::IrArray* addend_array, + llvm::Value* executable_run_options_value, + llvm::IRBuilder<>* b, + const HloModuleConfig& hlo_module_config, + const TargetMachineFeatures& target_machine_features); + + // Emits the IR to perform the dot operation. 
+ Status Emit(); private: - MemoryTile GetLhsMemoryTile(llvm::Value* row_start, int64 row_count) { - return MemoryTile(&vsl_, b_, /*matrix=*/lhs_, - /*matrix_size_along_minor_dim=*/k(), - /*major_dim_offset=*/row_start, - /*tile_size_along_major_dim=*/row_count); - } - - void EmitOuterLoopBody(llvm::Value* row, int64 row_count); - - void EmitInnerLoopTiled(MemoryTile* lhs_memory_tile, int64 rows, - std::vector* vector_accumulators); - - void EmitInnerLoopEpilogue(llvm::Value* current_tile_row, int64 rows, - std::vector* scalar_accumulators); - - Config config_; - llvm::Value* lhs_; - llvm::Value* rhs_; - llvm::Value* addend_; - llvm::Value* result_; - llvm::IRBuilder<>* b_; - KernelSupportLibrary ksl_; - VectorSupportLibrary vsl_; -}; - -void RowMajorMatrixVectorProductEmitter::EmitOuterLoopBody(llvm::Value* row, - int64 row_count) { - MemoryTile lhs_memory_tile = GetLhsMemoryTile(/*row_start=*/row, - /*row_count=*/row_count); - std::vector vector_accumulators; - std::vector scalar_accumulators; - for (int i = 0; i < row_count; i++) { - vector_accumulators.emplace_back(&vsl_, vsl_.GetZeroVector()); - scalar_accumulators.emplace_back(&vsl_, vsl_.GetZeroScalar()); - } - EmitInnerLoopTiled(&lhs_memory_tile, /*rows=*/row_count, - &vector_accumulators); - EmitInnerLoopEpilogue(/*current_tile_row=*/row, /*rows=*/row_count, - &scalar_accumulators); - - std::vector accumulator_values; - std::transform( - vector_accumulators.begin(), vector_accumulators.end(), - std::back_inserter(accumulator_values), - [](const VectorVariable& vector_var) { return vector_var.Get(); }); - - std::vector horizontal_sums; - if (row_count == vsl_.vector_size()) { - if (addend_) { - horizontal_sums = vsl_.ComputeHorizontalSums( - std::move(accumulator_values), vsl_.LoadVector(addend_, row)); - } else { - horizontal_sums = - vsl_.ComputeHorizontalSums(std::move(accumulator_values)); - } - } else { - horizontal_sums = vsl_.ComputeHorizontalSums(std::move(accumulator_values)); - } - - for (int i = 
0; i < row_count; i++) { - llvm::Value* result_value = - vsl_.Add(horizontal_sums[i], scalar_accumulators[i].Get()); - llvm::Value* offset = b_->CreateAdd(b_->getInt64(i), row); - if (addend_ && row_count != vsl_.vector_size()) { - result_value = vsl_.Add(vsl_.LoadScalar(addend_, offset), result_value); - } - vsl_.StoreScalar(result_value, result_, offset); - } -} + // Emits instructions to perform a scalar dot product (a multiply of the + // LHS and RHS) and store the results in the target. + Status EmitScalarDot(); -void RowMajorMatrixVectorProductEmitter::Emit() { - // See the comment on the class declaration for the algorithm used here. - int64 row_remainder = m() % tile_rows(); - int64 row_limit = m() - row_remainder; + // Emits a call to the CPU runtime to perform the matrix multiply. + Status EmitCallToRuntime(); - ksl_.ForReturnVoid( - "dot.outer.tiled", - /*start=*/0, /*end=*/row_limit, /*step=*/tile_rows(), - [&](llvm::Value* row) { EmitOuterLoopBody(row, tile_rows()); }); - - if (row_remainder != 0) { - EmitOuterLoopBody(b_->getInt64(row_limit), row_remainder); - } -} + // Represents the dimensions of a matrix-matrix multiply operation. + struct MatMultDims { + // The number of rows in the LHS. + int64 m; -void RowMajorMatrixVectorProductEmitter::EmitInnerLoopTiled( - MemoryTile* lhs_memory_tile, int64 rows, - std::vector* vector_accumulators) { - int64 column_limit = k() - (k() % tile_cols()); - - ksl_.ForReturnVoid("dot.inner.tiled", /*start=*/0, /*end=*/column_limit, - /*step=*/tile_cols(), [&](llvm::Value* col) { - std::vector lhs_tile = - lhs_memory_tile->LoadTile(/*minor_dim_offset=*/col); - llvm::Value* rhs_value = vsl_.LoadVector(rhs_, col); - for (int i = 0; i < rows; i++) { - llvm::Value* old_sum = (*vector_accumulators)[i].Get(); - (*vector_accumulators)[i].Set(vsl_.Add( - old_sum, vsl_.Mul(rhs_value, lhs_tile[i]))); - } - }); -} + // The number of columns in the LHS, which must also be equal to the + // number of rows in the RHS. 
+ int64 k; -void RowMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( - llvm::Value* current_tile_row, int64 rows, - std::vector* scalar_accumulators) { - int64 column_start = k() - (k() % tile_cols()); - if (column_start == k()) { - return; - } + // The number of columns on the RHS. + int64 n; - for (int r = 0; r < rows; r++) { - llvm::Value* total_offset = b_->CreateMul( - b_->CreateAdd(b_->getInt64(r), current_tile_row), b_->getInt64(k())); - llvm::Value* lhs_base_pointer = - vsl_.ComputeOffsetPointer(lhs_, total_offset); - ksl_.ForReturnVoid( - "dot.inner.epilg.inner", /*start=*/column_start, /*end=*/k(), - /*step=*/1, [&](llvm::Value* scalar_col) { - llvm::Value* product = - vsl_.Mul(vsl_.LoadScalar(lhs_base_pointer, scalar_col), - vsl_.LoadScalar(rhs_, scalar_col)); - llvm::Value* old_value = (*scalar_accumulators)[r].Get(); - (*scalar_accumulators)[r].Set(vsl_.Add(old_value, product)); - }); - } -} + // True if the LHS matrix is column major. + bool lhs_column_major; -// This class implements a tiled matrix multiplication algorithm, intended for -// multiplying small matrices that don't need cache tiling. -// -// In the future this can be used as the innermost GEBP loop in a GEMM kernel as -// described in "Goto, Kazushige, and Robert A. Geijn. "Anatomy of -// high-performance matrix multiplication." ACM Transactions on Mathematical -// Software (TOMS) 34.3 (2008): 12.". -// -// This only supports canonical dot operations (i.e. where the lhs contraction -// dimension is 1 and the rhs contraction dimension is 0) over row major -// matrices. -class TiledSmallGemmEmitter { - public: - // Describe the dimensions of the kernel. - class Dimensions { - public: - explicit Dimensions(int64 m, int64 k, int64 n) : m_(m), k_(k), n_(n) {} + // True if the LHS contraction dimension is not 1. + bool lhs_non_canonical; - int64 m() const { return m_; } - int64 k() const { return k_; } - int64 n() const { return n_; } + // True if the RHS matrix is column major. 
+ bool rhs_column_major; - string ToString() const { return absl::StrCat(m(), "x", k(), "x", n()); } + // True if the RHS contraction dimension is not 0. + bool rhs_non_canonical; - private: - const int64 m_; - const int64 k_; - const int64 n_; + // True if the result matrix is column major. + bool target_column_major; }; - // Represents the configuration of the emitter. The LLVM IR emitted by the - // emitter, modulo the LLVM values holding the input and output buffers, must - // be a function of the instance of `Config` passed to it. - // - // `dims` holds the matrix multiplication dimensions. - // - // `max_vectorization_width` is the maximum vector width (i.e. the width of - // the largest vector register we will use). This can be larger than the - // largest vector register supported by the machine -- LLVM will legalize - // these large vector widths into legally sized vectors. - // - // `max_vector_count` is the maximum number of vectors of size - // `max_vectorization_width` that we will attempt to process at once. - // - // `min_vectorization_width` is the smallest vector width the emitter will use - // -- below that it will devolve to using a scalar loop. - // - // The innermost reduction loop executes the matrix multiply in tiles of size - // [`tile_size_m`, `tile_size_k`] from the LHS and [`tile_size_k`, - // ] in the RHS. 
- class Config { - public: - explicit Config(PrimitiveType scalar_type, Dimensions dims, - int64 max_vectorization_width, int64 max_vector_count, - int64 min_vectorization_width, int64 tile_size_m, - int64 tile_size_k) - : scalar_type_(scalar_type), - dims_(dims), - max_vectorization_width_(max_vectorization_width), - max_vector_count_(max_vector_count), - min_vectorization_width_(min_vectorization_width), - tile_size_m_(tile_size_m), - tile_size_k_(tile_size_k) {} - - string GetCacheKey() const { - return absl::StrCat("gemm_", PrimitiveType_Name(scalar_type()), "_", - dims().ToString(), "_", max_vectorization_width(), - "_", min_vectorization_width(), "_", tile_size_m(), - "_", tile_size_k()); - } + // Get the MatMultDims instance for the dot product this DotOpEmitter + // represents. Precondition: the dot is of rank 2 (and thus its operands are + // of rank 2 as well). + MatMultDims GetMatMultDims() const; - PrimitiveType scalar_type() const { return scalar_type_; } - Dimensions dims() const { return dims_; } - int64 max_vectorization_width() const { return max_vectorization_width_; } - int64 max_vector_count() const { return max_vector_count_; } - int64 min_vectorization_width() const { return min_vectorization_width_; } - - int64 tile_size_m() const { return tile_size_m_; } - int64 tile_size_k() const { return tile_size_k_; } - - private: - PrimitiveType scalar_type_; - Dimensions dims_; - int64 max_vectorization_width_; - int64 max_vector_count_; - int64 min_vectorization_width_; - int64 tile_size_m_; - int64 tile_size_k_; - }; + // Lowers the dot operation as a tiled Matrix*Vector loop. + void EmitTiledLlvmIrGemv(); - // Creates an instance of TiledSmallGemmEmitter that matrix-multiplies - // `lhs` with `rhs` and stores the result in `result`. 
- explicit TiledSmallGemmEmitter(Config config, llvm::Value* lhs, - llvm::Value* rhs, llvm::Value* result, - llvm::IRBuilder<>* b) - : lhs_(lhs), - rhs_(rhs), - result_(result), - config_(config), - b_(b), - ksl_(b_) { - CHECK(max_vectorization_width() > 0 && - IsPowerOfTwo(static_cast(max_vectorization_width()))); - CHECK_GT(max_vector_count(), 0); - CHECK(min_vectorization_width() > 0 && - IsPowerOfTwo(static_cast(min_vectorization_width()))); - CHECK_GE(max_vectorization_width(), min_vectorization_width()); - CHECK_GT(tile_size_k(), 0); - } + // Lowers the dot operation as a tiled Matrix*Matrix loop. + void EmitTiledLlvmIrGemm(); - void Emit(); + // Lowers the dot operation as a naive nested loop that computes the result + // one element at a time. + void EmitNaiveLlvmIrGemm(); - private: - // The HandleResiduesOnX helpers split the iteration space for dimension X - // into a multiple of the tile size on dimension X and an epilogue. These - // helpers ultimately call into `EmitTiledGemm` for emitting the - // tiled GEMM kernel. - - void HandleResiduesOnN(); - void HandleResiduesOnK(VectorSupportLibrary* vsl, llvm::Value* n_start, - llvm::Value* n_end); - void HandleResiduesOnM(VectorSupportLibrary* vsl, int64 tile_size_k, - llvm::Value* k_start, llvm::Value* k_end, - llvm::Value* n_start, llvm::Value* n_end); - - // This emits a tiled GEMM kernel. For a detailed description see the comment - // on the implementation. 
- void EmitTiledGemm(VectorSupportLibrary* vsl, int64 tile_size_k, - llvm::Value* k_start, llvm::Value* k_end, - llvm::Value* n_start, llvm::Value* n_end, - int64 tile_size_m, llvm::Value* m_start, - llvm::Value* m_end); - - llvm::Value* GetInt64(int64 value) { return b_->getInt64(value); } - - Config config() const { return config_; } - Dimensions dims() const { return config().dims(); } - - int64 max_vectorization_width() const { - return config().max_vectorization_width(); + // When doing a tiled GEMV in LLVM IR, a "tile" consists of this many vector + // registers. + int64 GetGemvTilingFactor() const { + const int64 kDefaultTilingFactor = 8; + return options::LlvmIrGemvTilingFactor(hlo_module_config_) + .value_or(kDefaultTilingFactor); } - int64 max_vector_count() const { return config().max_vector_count(); } - int64 min_vectorization_width() const { - return config().min_vectorization_width(); - } - int64 tile_size_m() const { return config().tile_size_m(); } - int64 tile_size_k() const { return config().tile_size_k(); } - PrimitiveType scalar_type() const { return config().scalar_type(); } - llvm::Value* lhs_; - llvm::Value* rhs_; - llvm::Value* result_; - Config config_; + std::tuple GetGemmTileSize() const { + // Tuned for broadwell - Intel(R) Xeon(R) CPU E5-2690 v4 @ 2.60GHz + // + // TODO(b/80093688): Tune for other architectures and centralize this + // information in one place. 
+ const std::tuple kDefaultTileSize = + std::tuple(11, 9, 1); + return options::LlvmIrGemmTileSize(hlo_module_config_) + .value_or(kDefaultTileSize); + } + DotInfo dot_info_; + string dot_hlo_name_; + const llvm_ir::IrArray& target_array_; + const llvm_ir::IrArray& lhs_array_; + const llvm_ir::IrArray& rhs_array_; + const llvm_ir::IrArray* addend_array_; + llvm::Value* executable_run_options_value_; llvm::IRBuilder<>* b_; - KernelSupportLibrary ksl_; + const HloModuleConfig& hlo_module_config_; + const TargetMachineFeatures& target_machine_features_; }; - -void TiledSmallGemmEmitter::Emit() { HandleResiduesOnN(); } - -void TiledSmallGemmEmitter::HandleResiduesOnN() { - // We can only iterate the `n` dimension for an extent that is divisible by - // the vectorization width. So we emit an outer loop that first processes the - // largest extent in `n` that is divisible by max_vectorization_width, then - // the largest remaining extent that is divisible by max_vectorization_width / - // 2 etc. 
- - int64 current_vectorization_width = - max_vector_count() * max_vectorization_width(); - int64 current_vector_count = max_vector_count(); - - int64 n_start = 0; - while (n_start != dims().n() && - current_vectorization_width >= min_vectorization_width()) { - int64 n_end = dims().n() - (dims().n() % current_vectorization_width); - if (n_start != n_end) { - VectorSupportLibrary vsl(scalar_type(), current_vectorization_width, b_, - "gemm"); - HandleResiduesOnK(&vsl, GetInt64(n_start), GetInt64(n_end)); - n_start = n_end; - } - if (current_vector_count == 1) { - current_vectorization_width /= 2; - } else { - current_vector_count--; - current_vectorization_width = - current_vector_count * max_vectorization_width(); - } - } - - if (n_start != dims().n()) { - VectorSupportLibrary vsl(scalar_type(), 1, b_, "gemm"); - ksl_.ForReturnVoid("epi.n", n_start, dims().n(), 1, [&](llvm::Value* n_i) { - llvm::Value* n_i_next = b_->CreateAdd(n_i, b_->getInt64(1)); - HandleResiduesOnK(&vsl, n_i, n_i_next); - }); - } -} - -void TiledSmallGemmEmitter::HandleResiduesOnK(VectorSupportLibrary* vsl, - llvm::Value* n_start, - llvm::Value* n_end) { - int64 k_start = 0; - int64 k_end = dims().k() - (dims().k() % tile_size_k()); - if (k_end != k_start) { - HandleResiduesOnM(vsl, tile_size_k(), GetInt64(k_start), GetInt64(k_end), - n_start, n_end); - k_start = k_end; - } - - if (k_start != dims().k()) { - HandleResiduesOnM(vsl, dims().k() - k_start, GetInt64(k_start), - GetInt64(dims().k()), n_start, n_end); - } -} - -void TiledSmallGemmEmitter::HandleResiduesOnM( - VectorSupportLibrary* vsl, int64 tile_size_k, llvm::Value* k_start, - llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end) { - const int64 m_end = dims().m() - dims().m() % tile_size_m(); - EmitTiledGemm(vsl, tile_size_k, k_start, k_end, n_start, n_end, tile_size_m(), - GetInt64(0), GetInt64(m_end)); - - if (m_end != dims().m()) { - EmitTiledGemm(vsl, tile_size_k, k_start, k_end, n_start, n_end, - dims().m() - m_end, 
GetInt64(m_end), GetInt64(dims().m())); - } -} - -// The loop structure is: -// -// Iterate over dimension M as m: -// Iterate over dimension N as n: -// Iterate over dimension K as k: -// OutputTile[m,n] += Dot(LhsTile[m,k], RhsTile[k,n]) -// -// I.e. a just a tiled version of a "naive" GEMM. -// -// The tiling scheme is as follows: -// -// Let the LHS be: -// -// +----+----+----+ -// | a0 | b0 | c0 | . -// +----+----+----+ . -// | a1 | b1 | c1 | . -// +----+----+----+ -// .. .. -// -// and the RHS be: -// -// +----+----+----+----+ -// | p0 | p1 | p2 | p3 | . -// +----+----+----+----+ . -// | q0 | q1 | q2 | q3 | . -// +----+----+----+----+ -// | r0 | r1 | r2 | r3 | . -// +----+----+----+----+ . -// ...... ...... -// -// and let tile_size_m=2, tile_size_k=3 and the vector width (implicitly denoted -// by `vsl`) be 4. Then we want to matrix multiply this tile to get a [2,4] -// matrix that we can increment the result matrix by. -// -// First broadcast the rows row in LHS to 3 vectors of width 4, giving us a rank -// 3 array, L, of dimension [2,3,4]: -// -// L[0,_,_] * L[1,_,_] -// * -// +----+----+----+----+ * +----+----+----+----+ -// | a0 | a0 | a0 | a0 | * | a1 | a1 | a1 | a1 | -// +----+----+----+----+ * +----+----+----+----+ -// | b0 | b0 | b0 | b0 | * | b1 | b1 | b1 | b1 | -// +----+----+----+----+ * +----+----+----+----+ -// | c0 | c0 | c0 | c0 | * | c1 | c1 | c1 | c1 | -// +----+----+----+----+ * +----+----+----+----+ -// -// -// Then we FMA L[0,_,_] with the RHS to get the first row of the result and -// L[1,_,_] with the RHS to get the second row of the result. 
For example, -// L[0,_,_] is computed as: -// -// +----+----+----+----+ +----+----+----+----+ -// | a0 | a0 | a0 | a0 | * | p0 | p1 | p2 | p3 | + -// +----+----+----+----+ +----+----+----+----+ -// -// +----+----+----+----+ +----+----+----+----+ -// | b0 | b0 | b0 | b0 | * | q0 | q1 | q2 | q3 | + -// +----+----+----+----+ +----+----+----+----+ -// -// +----+----+----+----+ +----+----+----+----+ -// | c0 | c0 | c0 | c0 | * | r0 | r1 | r2 | r3 | -// +----+----+----+----+ +----+----+----+----+ -// -// to get: -// -// +-------------------+-------------------+-------------------+--------- -// | a0*p0+b0*q0+c0*r0 | a0*p1+b0*q1+c0*r1 | a0*p2+b0*q2+c0*r2 | ... -// +-------------------+-------------------+-------------------+--------- -void TiledSmallGemmEmitter::EmitTiledGemm( - VectorSupportLibrary* vsl, int64 tile_size_k, llvm::Value* k_start, - llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end, - int64 tile_size_m, llvm::Value* m_start, llvm::Value* m_end) { - ksl_.ForReturnVoid( - "dot.m", m_start, m_end, tile_size_m, [&](llvm::Value* m_i) { - MemoryTile result_memory_tile( - vsl, b_, /*matrix=*/result_, - /*matrix_size_along_minor_dim=*/dims().n(), - /*major_dim_offset=*/m_i, - /*tile_size_along_major_dim=*/tile_size_m); - MemoryTile lhs_memory_tile(vsl, b_, /*matrix=*/lhs_, - /*matrix_size_along_minor_dim=*/dims().k(), - /*major_dim_offset=*/m_i, - /*tile_size_along_major_dim=*/tile_size_m); - ksl_.ForReturnVoid( - "dot.n", n_start, n_end, vsl->vector_size(), [&](llvm::Value* n_i) { - TileVariable result_tile_var(vsl, - result_memory_tile.LoadTile(n_i)); - ksl_.ForReturnVoid( - "dot.k", k_start, k_end, tile_size_k, [&](llvm::Value* k_i) { - MemoryTile rhs_memory_tile(vsl, b_, rhs_, dims().n(), k_i, - tile_size_k); - std::vector> lhs_tile = - lhs_memory_tile.LoadBroadcastTile(k_i, tile_size_k); - std::vector rhs_tile = - rhs_memory_tile.LoadTile(n_i); - std::vector result_tile = - result_tile_var.Get(); - for (int64 r_m_i = 0; r_m_i < tile_size_m; r_m_i++) 
{ - for (int64 r_k_i = 0; r_k_i < tile_size_k; r_k_i++) { - result_tile[r_m_i] = - vsl->MulAdd(lhs_tile[r_m_i][r_k_i], rhs_tile[r_k_i], - result_tile[r_m_i]); - } - } - result_tile_var.Set(result_tile); - }); - - result_memory_tile.StoreTile(result_tile_var.Get(), n_i); - }); - }); -} - } // namespace -DotOpEmitter::DotOpEmitter(const HloInstruction& dot, +DotOpEmitter::DotOpEmitter(DotInfo dot_info, string dot_hlo_name, const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, @@ -975,7 +211,8 @@ DotOpEmitter::DotOpEmitter(const HloInstruction& dot, llvm::IRBuilder<>* b, const HloModuleConfig& hlo_module_config, const TargetMachineFeatures& target_machine_features) - : dot_(dot), + : dot_info_(std::move(dot_info)), + dot_hlo_name_(std::move(dot_hlo_name)), target_array_(target_array), lhs_array_(lhs_array), rhs_array_(rhs_array), @@ -985,58 +222,9 @@ DotOpEmitter::DotOpEmitter(const HloInstruction& dot, hlo_module_config_(hlo_module_config), target_machine_features_(target_machine_features) {} -/* static */ Status DotOpEmitter::EmitDotOperation( - const HloInstruction& dot, const llvm_ir::IrArray& target_array, - const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, - const llvm_ir::IrArray* addend_array, - llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b, - const HloModuleConfig& hlo_module_config, - const TargetMachineFeatures& target_machine_features) { - PrimitiveType type = target_array.GetShape().element_type(); - TF_RET_CHECK(F16 == type || F32 == type || F64 == type || C64 == type); - DotOpEmitter dot_emitter(dot, target_array, lhs_array, rhs_array, - addend_array, executable_run_options_value, b, - hlo_module_config, target_machine_features); - return dot_emitter.Emit(); -} - -bool DotOpEmitter::EmitSmallGemmIfProfitable( - const DotOpEmitter::MatMultDims& mat_mult_dims) { - if (ShouldUseMultiThreadedEigen()) { - return false; - } - - if (!EnableExperimentalLlvmIrGemm()) { - 
// TODO(sanjoy): We should make these numbers micro-arch specific. - bool small_gemm = mat_mult_dims.k <= 128 && - ((mat_mult_dims.m <= 32 && mat_mult_dims.n <= 128) || - (mat_mult_dims.m <= 128 && mat_mult_dims.n <= 32)); - if (!small_gemm) { - return false; - } - } - - if (mat_mult_dims.lhs_non_canonical || mat_mult_dims.rhs_non_canonical) { - return false; - } - - PrimitiveType primitive_type = dot_.shape().element_type(); - - switch (primitive_type) { - default: - return false; - - case F32: - case F64: - case S32: - case S64: - break; - } - - if (!(mat_mult_dims.lhs_column_major == mat_mult_dims.rhs_column_major && - mat_mult_dims.rhs_column_major == mat_mult_dims.target_column_major)) { - return false; - } +void DotOpEmitter::EmitTiledLlvmIrGemm() { + PrimitiveType primitive_type = dot_info_.result_shape.element_type(); + MatMultDims mat_mult_dims = GetMatMultDims(); llvm::Value* lhs = lhs_array_.GetBasePointer(); llvm::Value* rhs = rhs_array_.GetBasePointer(); @@ -1051,9 +239,8 @@ bool DotOpEmitter::EmitSmallGemmIfProfitable( } int64 size_bytes = m * n * ShapeUtil::ByteSizeOfPrimitiveType(primitive_type); - b_->CreateMemSet( - target, b_->getInt8(0), size_bytes, - target_machine_features_.minimum_alignment_for_allocation(size_bytes)); + b_->CreateMemSet(target, b_->getInt8(0), /*Size=*/size_bytes, + /*Align=*/1); int64 max_target_vector_width = target_machine_features_.vector_register_num_elements( @@ -1063,47 +250,28 @@ bool DotOpEmitter::EmitSmallGemmIfProfitable( std::tie(tile_size_m, tile_size_k, tile_size_n_in_vector_width) = GetGemmTileSize(); - TiledSmallGemmEmitter::Config config( - /*scalar_type=*/primitive_type, - TiledSmallGemmEmitter::Dimensions{/*m=*/m, /*k=*/k, /*n=*/n}, - /*max_vectorization_width=*/max_target_vector_width, - /*max_vector_count=*/tile_size_n_in_vector_width, - /*min_vectorization_width=*/std::min(4, max_target_vector_width), - /*tile_size_m=*/tile_size_m, /*tile_size_k=*/tile_size_k); - - VLOG(2) << "Emitting GEMM kernel in 
LLVM IR with config " - << config.GetCacheKey(); - const bool enable_fast_math = hlo_module_config_.debug_options().xla_cpu_enable_fast_math(); const bool optimize_for_size = options::OptimizeForSizeRequested(hlo_module_config_); - KernelSupportLibrary::EmitAndCallOutlinedKernel( + EmitSmallGemm( + /*scalar_type=*/primitive_type, + /*m=*/m, /*k=*/k, /*n=*/n, + /*max_vectorization_width=*/max_target_vector_width, + /*max_vector_count=*/tile_size_n_in_vector_width, + /*min_vectorization_width=*/std::min(4, max_target_vector_width), + /*tile_size_m=*/tile_size_m, /*tile_size_k=*/tile_size_k, /*lhs=*/lhs, + /*rhs=*/rhs, /*result=*/target, b_, /*enable_fast_math=*/enable_fast_math, - /*optimize_for_size=*/optimize_for_size, b_, config.GetCacheKey(), lhs, - rhs, target, - [this, config](llvm::Value* lhs, llvm::Value* rhs, llvm::Value* target) { - TiledSmallGemmEmitter small_gemm_emitter(config, /*lhs=*/lhs, - /*rhs=*/rhs, - /*result=*/target, b_); - small_gemm_emitter.Emit(); - }); - - return true; + /*optimize_for_size=*/optimize_for_size); } -bool DotOpEmitter::EmitLlvmIrDotIfProfitable() { - if (dot_.shape().dimensions_size() != 2) { - return false; - } - - PrimitiveType primitive_type = dot_.shape().element_type(); +void DotOpEmitter::EmitTiledLlvmIrGemv() { + PrimitiveType primitive_type = dot_info_.result_shape.element_type(); - if (!primitive_util::IsFloatingPointType(primitive_type) && - !primitive_util::IsIntegralType(primitive_type)) { - return false; - } + CHECK(primitive_util::IsFloatingPointType(primitive_type) || + primitive_util::IsIntegralType(primitive_type)); MatMultDims mat_mult_dims = GetMatMultDims(); bool is_column_major_matrix_vector = false; @@ -1144,9 +312,7 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() { } } - if (!is_column_major_matrix_vector && !is_row_major_matrix_vector) { - return EmitSmallGemmIfProfitable(mat_mult_dims); - } + CHECK(is_column_major_matrix_vector || is_row_major_matrix_vector); int64 tiling_factor = 
GetGemvTilingFactor(); CHECK_GT(tiling_factor, 0); @@ -1178,44 +344,27 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() { if (is_column_major_matrix_vector) { VLOG(2) << "Emitting column major matrix-vector multiply with m = " << m << " and k = " << k; - ColumnMajorMatrixVectorProductEmitter::Config config( + EmitColumnMajorGemv( /*scalar_type=*/primitive_type, /*tile_rows=*/vector_register_element_size, /*tile_cols=*/tiling_factor, - /*m=*/m, /*k=*/k, /*has_addend=*/addend_array_ != nullptr); - - KernelSupportLibrary::EmitAndCallOutlinedKernel( + /*m=*/m, /*k=*/k, /*lhs=*/lhs_op, /*rhs=*/rhs_op, + /*addend=*/addend_array_ ? addend_array_->GetBasePointer() : nullptr, + /*result=*/result_op, b_, /*enable_fast_math=*/enable_fast_math, - /*optimize_for_size=*/optimize_for_size, b_, config.GetCacheKey(), - lhs_op, rhs_op, - addend_array_ ? addend_array_->GetBasePointer() : nullptr, result_op, - [this, config](llvm::Value* lhs_op, llvm::Value* rhs_op, - llvm::Value* addend_op, llvm::Value* result_op) { - ColumnMajorMatrixVectorProductEmitter emitter( - config, lhs_op, rhs_op, addend_op, result_op, b_); - emitter.Emit(); - }); + /*optimize_for_size=*/optimize_for_size); } else { VLOG(2) << "Emitting row major matrix-vector multiply with m = " << m << " and k = " << k; - RowMajorMatrixVectorProductEmitter::Config config( + EmitRowMajorGemv( /*scalar_type=*/primitive_type, - /*tile_rows=*/tiling_factor, /*tile_cols=*/vector_register_element_size, - /*m=*/m, /*k=*/k, /*has_addend=*/addend_array_ != nullptr); - - KernelSupportLibrary::EmitAndCallOutlinedKernel( + /*tile_rows=*/tiling_factor, + /*tile_cols=*/vector_register_element_size, + /*m=*/m, /*k=*/k, /*lhs=*/lhs_op, /*rhs=*/rhs_op, + /*addend=*/addend_array_ ? addend_array_->GetBasePointer() : nullptr, + /*result=*/result_op, b_, /*enable_fast_math=*/enable_fast_math, - /*optimize_for_size=*/optimize_for_size, b_, config.GetCacheKey(), - lhs_op, rhs_op, - addend_array_ ? 
addend_array_->GetBasePointer() : nullptr, result_op, - [this, config](llvm::Value* lhs_op, llvm::Value* rhs_op, - llvm::Value* addend_op, llvm::Value* result_op) { - RowMajorMatrixVectorProductEmitter emitter(config, lhs_op, rhs_op, - addend_op, result_op, b_); - emitter.Emit(); - }); + /*optimize_for_size=*/optimize_for_size); } - - return true; } Status DotOpEmitter::Emit() { @@ -1241,11 +390,6 @@ Status DotOpEmitter::Emit() { // which performs the sum-of-products (the reduction loop) before storing // the result in the output buffer. - // This routine assumes that the dot operation is not in a parallelized - // enclosing computation. - CHECK( - dot_.parent()->root_instruction()->outer_dimension_partitions().empty()); - const Shape& lhs_shape = lhs_array_.GetShape(); const Shape& rhs_shape = rhs_array_.GetShape(); @@ -1256,27 +400,41 @@ Status DotOpEmitter::Emit() { return EmitScalarDot(); } - if (EmitLlvmIrDotIfProfitable()) { - return Status::OK(); + switch (GetDotImplementationStrategy(hlo_module_config_, dot_info_, + target_machine_features_)) { + case DotImplementationStrategy::kNaiveLlvmIr: + EmitNaiveLlvmIrGemm(); + return Status::OK(); + + case DotImplementationStrategy::kTiledLlvmIrGemv: + EmitTiledLlvmIrGemv(); + return Status::OK(); + + case DotImplementationStrategy::kTiledLlvmIrGemm: + EmitTiledLlvmIrGemm(); + return Status::OK(); + + case DotImplementationStrategy::kEigen: + return EmitCallToRuntime(); } +} +void DotOpEmitter::EmitNaiveLlvmIrGemm() { CHECK_EQ(addend_array_, nullptr); - if (PotentiallyImplementedAsEigenDot(dot_, target_machine_features_)) { - return EmitCallToRuntime(); - } + const Shape& lhs_shape = lhs_array_.GetShape(); + const Shape& rhs_shape = rhs_array_.GetShape(); + const DotDimensionNumbers& dim_nums = dot_info_.dim_nums; // Reduce along dimension 0 of the LHS and 1 of the RHS. Vectors are a special // case where the reduction dimension is 0 for both LHS and RHS. This results // in a vector dot product producing a scalar. 
- int64 lhs_reduction_dimension = - dot_.dot_dimension_numbers().lhs_contracting_dimensions(0); - int64 rhs_reduction_dimension = - dot_.dot_dimension_numbers().rhs_contracting_dimensions(0); + int64 lhs_reduction_dimension = dim_nums.lhs_contracting_dimensions(0); + int64 rhs_reduction_dimension = dim_nums.rhs_contracting_dimensions(0); // Verify the reduction dimension in the two operands are the same size. - TF_RET_CHECK(lhs_shape.dimensions(lhs_reduction_dimension) == - rhs_shape.dimensions(rhs_reduction_dimension)); + CHECK_EQ(lhs_shape.dimensions(lhs_reduction_dimension), + rhs_shape.dimensions(rhs_reduction_dimension)); bool lhs_reduction_along_minor_dimension = lhs_reduction_dimension == LayoutUtil::Minor(lhs_shape.layout(), 0); @@ -1286,7 +444,7 @@ Status DotOpEmitter::Emit() { // Create loop nests which loop through the LHS operand dimensions and the RHS // operand dimensions. The reduction dimension of the LHS and RHS are handled // in a separate innermost loop which performs the sum of products. - llvm_ir::ForLoopNest loop_nest(llvm_ir::IrName(&dot_), b_); + llvm_ir::ForLoopNest loop_nest(llvm_ir::IrName(dot_hlo_name_), b_); llvm_ir::IrArray::Index lhs_index = loop_nest.EmitOperandArrayLoopNest( lhs_array_, /*dimension_to_skip=*/lhs_reduction_dimension, "lhs"); llvm_ir::IrArray::Index rhs_index = loop_nest.EmitOperandArrayLoopNest( @@ -1391,8 +549,6 @@ Status DotOpEmitter::Emit() { // Set the IR builder insert point to the exit basic block of the outer most // loop. 
b_->SetInsertPoint(loop_nest.GetOuterLoopExitBasicBlock()); - - return Status::OK(); } Status DotOpEmitter::EmitScalarDot() { @@ -1406,16 +562,20 @@ Status DotOpEmitter::EmitScalarDot() { llvm::Value* rhs_value = rhs_array_.EmitReadArrayElement(/*index=*/element_index, b_); if (ShapeUtil::ElementIsComplex(lhs_array_.GetShape())) { -#define REAL(x) b_->CreateExtractValue(x, {0}) -#define IMAG(x) b_->CreateExtractValue(x, {1}) - llvm::Value* real = - b_->CreateFSub(b_->CreateFMul(REAL(lhs_value), REAL(rhs_value)), - b_->CreateFMul(IMAG(lhs_value), IMAG(rhs_value))); - llvm::Value* imag = - b_->CreateFAdd(b_->CreateFMul(REAL(lhs_value), IMAG(rhs_value)), - b_->CreateFMul(IMAG(lhs_value), REAL(rhs_value))); -#undef IMAG -#undef REAL + auto get_real = [&](llvm::Value* x) { + return b_->CreateExtractValue(x, {0}); + }; + + auto get_imag = [&](llvm::Value* x) { + return b_->CreateExtractValue(x, {1}); + }; + + llvm::Value* real = b_->CreateFSub( + b_->CreateFMul(get_real(lhs_value), get_real(rhs_value)), + b_->CreateFMul(get_imag(lhs_value), get_imag(rhs_value))); + llvm::Value* imag = b_->CreateFAdd( + b_->CreateFMul(get_real(lhs_value), get_imag(rhs_value)), + b_->CreateFMul(get_imag(lhs_value), get_real(rhs_value))); result = llvm::ConstantAggregateZero::get(lhs_array_.GetElementLlvmType()); result = b_->CreateInsertValue(result, real, {0}); result = b_->CreateInsertValue(result, imag, {1}); @@ -1435,7 +595,7 @@ Status DotOpEmitter::EmitCallToRuntime() { // The two transpose_... parameters are actually booleans, but we use int32 // to avoid target-dependent calling convention details. 
- bool multi_threaded = ShouldUseMultiThreadedEigen(); + bool multi_threaded = ShouldUseMultiThreadedEigen(hlo_module_config_); bool use_mkl_dnn = hlo_module_config_.debug_options().xla_cpu_use_mkl_dnn(); PrimitiveType type = target_array_.GetShape().element_type(); llvm::Type* float_type; @@ -1483,11 +643,13 @@ Status DotOpEmitter::EmitCallToRuntime() { llvm::Function* function = b_->GetInsertBlock()->getParent(); llvm::Module* module = function->getParent(); - llvm::Function* matmul_func = llvm::cast( - module->getOrInsertFunction(fn_name, matmul_type)); - matmul_func->setCallingConv(llvm::CallingConv::C); - matmul_func->setDoesNotThrow(); - matmul_func->setOnlyAccessesArgMemory(); + llvm::FunctionCallee matmul_func = + module->getOrInsertFunction(fn_name, matmul_type); + if (auto* fn = llvm::dyn_cast(matmul_func.getCallee())) { + fn->setCallingConv(llvm::CallingConv::C); + fn->setDoesNotThrow(); + fn->setOnlyAccessesArgMemory(); + } // The Eigen runtime function expects column-major layout. If the matrices are // row major, then use the following identity to compute the product: @@ -1528,11 +690,11 @@ Status DotOpEmitter::EmitCallToRuntime() { } DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const { - CHECK_EQ(dot_.shape().dimensions_size(), 2); + CHECK_EQ(dot_info_.result_shape.dimensions_size(), 2); const Shape& lhs_shape = lhs_array_.GetShape(); const Shape& rhs_shape = rhs_array_.GetShape(); - const DotDimensionNumbers& dim_nums = dot_.dot_dimension_numbers(); + const DotDimensionNumbers& dim_nums = dot_info_.dim_nums; return { /*m=*/lhs_shape.dimensions(1 - dim_nums.lhs_contracting_dimensions(0)), @@ -1546,74 +708,6 @@ DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const { LayoutUtil::Minor(target_array_.GetShape().layout(), 0) == 0}; } -// Return whether the given shape is rank 2. 
-static bool IsRank2(const Shape& shape) { return ShapeUtil::Rank(shape) == 2; } - -// In a gemm operation where output = lhs * rhs, check whether the given shapes -// are valid for the operation. -static bool AreValidGemmShapes( - const Shape& lhs_shape, const Shape& rhs_shape, const Shape& output_shape, - const TargetMachineFeatures& target_machine_features) { - // The inputs and the output must - // 1) be matrices with no padding, and - // 2) have an allowed element type. - PrimitiveType output_primitive_type = output_shape.element_type(); - if (!(output_primitive_type == F64 || output_primitive_type == F32 || - output_primitive_type == F16)) { - return false; - } - - if (!(IsRank2(lhs_shape) && IsRank2(rhs_shape) && IsRank2(output_shape))) { - return false; - } - - auto is_aligned = [&](const Shape& shape) { - return GetMinimumAlignmentForArray(shape, target_machine_features) >= - TargetMachineFeatures::kEigenExpectedTensorAlignment; - }; - - if (!is_aligned(lhs_shape) || !is_aligned(rhs_shape) || - !is_aligned(output_shape)) { - return false; - } - - return true; -} - -bool PotentiallyImplementedAsEigenDot( - const HloInstruction& hlo, - const TargetMachineFeatures& target_machine_features) { - // For certain types of Dot, we can call Eigen - if (hlo.opcode() == HloOpcode::kDot) { - const Shape& lhs_shape = hlo.operand(0)->shape(); - const Shape& rhs_shape = hlo.operand(1)->shape(); - - if (ShapeUtil::IsZeroElementArray(lhs_shape) || - ShapeUtil::IsZeroElementArray(rhs_shape)) { - return false; - } - - if (ProfitableToImplementDotInTiledLlvmIr(hlo)) { - return false; - } - - // If gemm can accept the operand shapes, use it rather than a custom - // kernel. - if (AreValidGemmShapes(lhs_shape, rhs_shape, hlo.shape(), - target_machine_features)) { - const DotDimensionNumbers& dim_numbers = hlo.dot_dimension_numbers(); - // The size of the reduction dimension should match. 
The shape inference - // guarantees this invariant, so the check here is for programming - // errors. - CHECK_EQ(lhs_shape.dimensions(dim_numbers.lhs_contracting_dimensions(0)), - rhs_shape.dimensions(dim_numbers.rhs_contracting_dimensions(0))); - return true; - } - } - - return false; -} - // For vector-matrix dot products, it is always profitable to make the Rhs // column major. absl::optional ProfitableToMakeDotOperandColumnMajor( @@ -1652,16 +746,319 @@ absl::optional ProfitableToMakeDotOperandColumnMajor( return {}; } -bool ProfitableToImplementDotInTiledLlvmIr(const HloInstruction& dot) { +namespace { +// Return whether the given shape is rank 2. +bool IsRank2(const Shape& shape) { return shape.rank() == 2; } + +bool IsSimpleLayout(const Layout& layout) { + return layout.tiles().empty() && layout.format() == DENSE; +} + +// In a gemm operation where output = lhs * rhs, check whether the given shapes +// are valid for the operation. +bool AreGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape, + const Shape& output_shape, + const TargetMachineFeatures& target_machine_features) { + CHECK(!lhs_shape.has_layout() || IsSimpleLayout(lhs_shape.layout())) + << lhs_shape.DebugString(); + CHECK(!rhs_shape.has_layout() || IsSimpleLayout(rhs_shape.layout())) + << rhs_shape.DebugString(); + CHECK(!output_shape.has_layout() || IsSimpleLayout(output_shape.layout())) + << output_shape.DebugString(); + + switch (output_shape.element_type()) { + case F64: + case F32: + case F16: + return IsRank2(lhs_shape) && IsRank2(rhs_shape) && IsRank2(output_shape); + default: + return false; + } +} + +bool IsAlignedGemm(const DotInfo& dot_info, + const TargetMachineFeatures& target_machine_features) { + if (ShapeUtil::IsZeroElementArray(dot_info.lhs_shape) || + ShapeUtil::IsZeroElementArray(dot_info.rhs_shape)) { + return false; + } + + return AreGemmShapes(dot_info.lhs_shape, dot_info.rhs_shape, + dot_info.result_shape, target_machine_features); +} + +bool CanEmitTiledLlvmIrGemm( + 
const HloModuleConfig& config, const DotInfo& dot_info, + const TargetMachineFeatures& target_machine_features) { + CHECK(IsAlignedGemm(dot_info, target_machine_features)); + + if (ShouldUseMultiThreadedEigen(config)) { + return false; + } + + int m = dot_info.result_shape.dimensions(0); + int k = dot_info.lhs_shape.dimensions( + dot_info.dim_nums.lhs_contracting_dimensions(0)); + int n = dot_info.result_shape.dimensions(1); + + if (!options::ForceEnableExperimentalLlvmIrGemm(config)) { + // TODO(sanjoy): We should make these numbers micro-arch specific. + bool small_gemm = + k <= 128 && ((m <= 32 && n <= 128) || (m <= 128 && n <= 32)); + if (!small_gemm) { + return false; + } + } + + bool lhs_non_canonical = dot_info.dim_nums.lhs_contracting_dimensions(0) == 0; + bool rhs_non_canonical = dot_info.dim_nums.rhs_contracting_dimensions(0) == 1; + + if (lhs_non_canonical || rhs_non_canonical) { + return false; + } + + if (dot_info.result_shape.element_type() == F16) { + // TODO(sanjoy): This is probably easy to fix, but I want to keep the CL + // adding this comment NFC. + return false; + } + + return true; +} + +DotImplementationStrategy GetDotImplementationStrategy( + const HloModuleConfig& config, const DotInfo& dot_info, + const TargetMachineFeatures& target_machine_features) { + PrimitiveType element_type = dot_info.result_shape.element_type(); // Any Matrix-Vector product of floating point or integral type, or // a transpose-dot fusion of the same can be lowered to a tiled LLVM // IR implementation. 
- const Shape& shape = dot.shape(); - return shape.dimensions_size() == 2 && - (shape.dimensions(0) == 1 || shape.dimensions(1) == 1) && - (primitive_util::IsFloatingPointType(shape.element_type()) || - primitive_util::IsIntegralType(shape.element_type())); + if (dot_info.result_shape.dimensions_size() == 2 && + (dot_info.result_shape.dimensions(0) == 1 || + dot_info.result_shape.dimensions(1) == 1) && + (primitive_util::IsFloatingPointType(element_type) || + primitive_util::IsIntegralType(element_type))) { + return DotImplementationStrategy::kTiledLlvmIrGemv; + } + + if (IsAlignedGemm(dot_info, target_machine_features)) { + return CanEmitTiledLlvmIrGemm(config, dot_info, target_machine_features) + ? DotImplementationStrategy::kTiledLlvmIrGemm + : DotImplementationStrategy::kEigen; + } + + return DotImplementationStrategy::kNaiveLlvmIr; +} + +Status EmitNonBatchDotOperation( + DotInfo dot_info, string hlo_name, const llvm_ir::IrArray& target_array, + const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, + const llvm_ir::IrArray* addend_array, + llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b, + const HloModuleConfig& hlo_module_config, + const TargetMachineFeatures& target_machine_features) { + PrimitiveType type = target_array.GetShape().element_type(); + TF_RET_CHECK(F16 == type || F32 == type || F64 == type || C64 == type || + C128 == type); + DotOpEmitter dot_emitter(std::move(dot_info), std::move(hlo_name), + target_array, lhs_array, rhs_array, addend_array, + executable_run_options_value, b, hlo_module_config, + target_machine_features); + return dot_emitter.Emit(); +} + +Shape DropFirstDim(const Shape& shape) { + absl::Span array_shape_dims(shape.dimensions()); + array_shape_dims.remove_prefix(1); + return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(), + array_shape_dims); +} + +Shape CollapseFirstNDims(const Shape& shape, int64 n) { + absl::Span input_shape_dims(shape.dimensions()); + int64 prefix_dim = + 
std::accumulate(input_shape_dims.begin(), input_shape_dims.begin() + n, + 1ll, std::multiplies()); + DimensionVector result_dims; + result_dims.push_back(prefix_dim); + std::copy(input_shape_dims.begin() + n, input_shape_dims.end(), + std::back_inserter(result_dims)); + return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(), + result_dims); +} + +llvm_ir::IrArray CollapseFirstNDims(llvm::IRBuilder<>* b, + const llvm_ir::IrArray& array, int64 n) { + llvm::Module* module = b->GetInsertBlock()->getParent()->getParent(); + const Shape& shape = array.GetShape(); + CHECK(shape.has_layout() && + LayoutUtil::IsMonotonicWithDim0Major(shape.layout())); + CHECK_GE(shape.dimensions_size(), n); + Shape new_shape = CollapseFirstNDims(shape, n); + llvm::Value* new_value = b->CreateBitCast( + array.GetBasePointer(), + llvm_ir::ShapeToIrType(new_shape, module)->getPointerTo()); + return llvm_ir::IrArray(new_value, std::move(new_shape)); +} + +Status ValidateDotDimensionNumbers(const DotDimensionNumbers& dim_numbers) { + // Checks some invariants that do not hold in general, but DotDecomposer + // should have established for us. This is just a debugging aid. + TF_RET_CHECK(dim_numbers.lhs_contracting_dimensions_size() == 1); + std::vector batch_dim_numbers(dim_numbers.lhs_batch_dimensions_size()); + absl::c_iota(batch_dim_numbers, 0); + TF_RET_CHECK( + absl::c_equal(batch_dim_numbers, dim_numbers.lhs_batch_dimensions())); + TF_RET_CHECK( + absl::c_equal(batch_dim_numbers, dim_numbers.rhs_batch_dimensions())); + return Status::OK(); +} + +// Slice out the inner array at batch index `batch_index` from `outer_array`. 
+llvm_ir::IrArray SliceOutInnerArray(llvm_ir::IrArray outer_array, + llvm::Value* batch_index, + llvm::IRBuilder<>* b) { + llvm::Module* module = b->GetInsertBlock()->getParent()->getParent(); + + Shape inner_shape = DropFirstDim(outer_array.GetShape()); + llvm_ir::IrArray::Index slice_index(b->getInt64Ty()); + slice_index.push_back(batch_index); + slice_index.InsertAt( + /*index=*/1, outer_array.GetShape().dimensions_size() - 1, + b->getInt64(0)); + llvm::Value* slice_ptr = outer_array.EmitArrayElementAddress(slice_index, b); + llvm::Type* slice_ptr_type = + llvm_ir::ShapeToIrType(inner_shape, module)->getPointerTo(); + return llvm_ir::IrArray(b->CreateBitCast(slice_ptr, slice_ptr_type), + std::move(inner_shape)); +} + +Status EmitBatchDotOperation( + const HloInstruction& dot, const llvm_ir::IrArray& target_array, + const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, + llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b, + const HloModuleConfig& hlo_module_config, + const TargetMachineFeatures& target_machine_features) { + TF_RETURN_IF_ERROR(ValidateDotDimensionNumbers(dot.dot_dimension_numbers())); + + // Lower a batch dot into a sequence of non-batch dot operations. + + int64 num_batch_dims = + dot.dot_dimension_numbers().lhs_batch_dimensions_size(); + + // First reshape the inputs to make sure we only have one batch dimension. + // This is a no-op bitcast because the operands have to be in row-major layout + // (enforced in CpuLayoutAssignment), and the batch dimensions are the leading + // dimensions (established by DotDecomposer and checked by + // ValidateDotDimensionNumbers above). 
+ llvm_ir::IrArray lhs_array_reshaped = + CollapseFirstNDims(b, lhs_array, num_batch_dims); + llvm_ir::IrArray rhs_array_reshaped = + CollapseFirstNDims(b, rhs_array, num_batch_dims); + llvm_ir::IrArray target_array_reshaped = + CollapseFirstNDims(b, target_array, num_batch_dims); + + int64 batch_count = lhs_array_reshaped.GetShape().dimensions(0); + + KernelSupportLibrary ksl(b); + + return ksl.ForWithStatus( + "bdot", /*start=*/0, /*end=*/batch_count, /*step=*/1, + [&](llvm::Value* indvar) { + DotDimensionNumbers adjusted_dim_numbers = dot.dot_dimension_numbers(); + adjusted_dim_numbers.clear_lhs_batch_dimensions(); + adjusted_dim_numbers.clear_rhs_batch_dimensions(); + + // Create a DotInfo representing the "inner" non-batch dot operation. + DotInfo dot_info; + dot_info.lhs_shape = DropFirstDim(lhs_array_reshaped.GetShape()); + dot_info.rhs_shape = DropFirstDim(rhs_array_reshaped.GetShape()); + dot_info.result_shape = DropFirstDim(target_array_reshaped.GetShape()); + dot_info.dim_nums = dot.dot_dimension_numbers(); + dot_info.dim_nums.clear_lhs_batch_dimensions(); + dot_info.dim_nums.clear_rhs_batch_dimensions(); + + dot_info.dim_nums.set_lhs_contracting_dimensions( + 0, + dot_info.dim_nums.lhs_contracting_dimensions(0) - num_batch_dims); + dot_info.dim_nums.set_rhs_contracting_dimensions( + 0, + dot_info.dim_nums.rhs_contracting_dimensions(0) - num_batch_dims); + + llvm_ir::IrArray lhs_slice = + SliceOutInnerArray(lhs_array_reshaped, /*batch_index=*/indvar, b); + llvm_ir::IrArray rhs_slice = + SliceOutInnerArray(rhs_array_reshaped, /*batch_index=*/indvar, b); + llvm_ir::IrArray target_slice = SliceOutInnerArray( + target_array_reshaped, /*batch_index=*/indvar, b); + + // Emit the inner non-batch dot operation. 
+ return EmitNonBatchDotOperation( + dot_info, dot.name(), target_slice, lhs_slice, rhs_slice, nullptr, + executable_run_options_value, b, hlo_module_config, + target_machine_features); + }); +} + +bool IsBatchDot(const HloInstruction& instr) { + if (auto* dot_instr = DynCast(&instr)) { + return dot_instr->dot_dimension_numbers().lhs_batch_dimensions_size() > 0; + } + + return false; +} +} // namespace + +bool DotImplementationCanHandleTranspose( + const HloInstruction& dot_instr, + const TargetMachineFeatures& target_machine_features) { + DotImplementationStrategy impl_strategy = + GetDotImplementationStrategy(dot_instr.parent()->parent()->config(), + DotInfo(dot_instr), target_machine_features); + + // TODO(sanjoy): This is not quite right, it should be `impl_strategy == + // kEigen || impl_strategy == kTiledLlvmIrGemv || impl_strategy == + // kNaiveLlvmIr` but I'll fix this in a later CL in the interest of keeping + // the CL adding this comment NFC. + return impl_strategy == DotImplementationStrategy::kTiledLlvmIrGemm || + impl_strategy == DotImplementationStrategy::kEigen; } +bool DotOperandsAndResultMustHaveRowMajorLayout( + const HloInstruction& dot_instr, + const TargetMachineFeatures& target_machine_features) { + DotImplementationStrategy impl_strategy = + GetDotImplementationStrategy(dot_instr.parent()->parent()->config(), + DotInfo(dot_instr), target_machine_features); + + return impl_strategy == DotImplementationStrategy::kTiledLlvmIrGemm || + impl_strategy == DotImplementationStrategy::kEigen; +} + +Status EmitDotOperation(const HloInstruction& dot, + const llvm_ir::IrArray& target_array, + const llvm_ir::IrArray& lhs_array, + const llvm_ir::IrArray& rhs_array, + const llvm_ir::IrArray* addend_array, + llvm::Value* executable_run_options_value, + llvm::IRBuilder<>* b, + const HloModuleConfig& hlo_module_config, + const TargetMachineFeatures& target_machine_features) { + // This routine assumes that the dot operation is not in a parallelized + // 
enclosing computation. + CHECK(dot.parent()->root_instruction()->outer_dimension_partitions().empty()); + + if (IsBatchDot(dot)) { + TF_RET_CHECK(addend_array == nullptr); + return EmitBatchDotOperation(dot, target_array, lhs_array, rhs_array, + executable_run_options_value, b, + hlo_module_config, target_machine_features); + } + + return EmitNonBatchDotOperation(DotInfo(dot), dot.name(), target_array, + lhs_array, rhs_array, addend_array, + executable_run_options_value, b, + hlo_module_config, target_machine_features); +} } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h index 4c2041b556aa8bf8fe8fb8e0674c0f4f04f0acae..105bd3005c86d87443b2528eba7b0106ad70590e 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h @@ -30,9 +30,16 @@ limitations under the License. namespace xla { namespace cpu { +// Returns true if the two operands and the output of `dot_instr` must have row +// major layout. +bool DotOperandsAndResultMustHaveRowMajorLayout( + const HloInstruction& dot_instr, + const TargetMachineFeatures& target_machine_features); -bool PotentiallyImplementedAsEigenDot( - const HloInstruction& hlo, +// Returns true if our lowering strategy for `dot_instr` can fold in transposes +// to either of the inputs. +bool DotImplementationCanHandleTranspose( + const HloInstruction& dot_instr, const TargetMachineFeatures& target_machine_features); // Returns the index for an operand to `hlo` that should ideally be column @@ -41,129 +48,24 @@ bool PotentiallyImplementedAsEigenDot( absl::optional ProfitableToMakeDotOperandColumnMajor( const HloInstruction& hlo); -// Returns true to indicate that we can generate a tiled LLVM IR implementation -// for |dot|. -bool ProfitableToImplementDotInTiledLlvmIr(const HloInstruction& dot); - -// Helper class for emitting LLVM IR to perform the dot operation.
-class DotOpEmitter { - public: - // Emit LLVM IR to perform the dot operation on lhs_array and rhs_array and - // place the result in target_array. IR is emitted at current insert point of - // the builder. Upon completion of the method, the insert point is set to the - // end of all instructions emitted for this operation. - // - // If `addend_array` is not nullptr then it must be an array of the same - // dimensions as the result, and the result is computed as `addend_array` + - // dot(`lhs_array`, `rhs_array`). A non-null `addend_array` is only supported - // for Matrix-vector products. - static Status EmitDotOperation( - const HloInstruction& dot, const llvm_ir::IrArray& target_array, - const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, - const llvm_ir::IrArray* addend_array, - llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b, - const HloModuleConfig& hlo_module_config, - const TargetMachineFeatures& target_machine_features); - - private: - DotOpEmitter(const HloInstruction& dot, const llvm_ir::IrArray& target_array, - const llvm_ir::IrArray& lhs_array, - const llvm_ir::IrArray& rhs_array, - const llvm_ir::IrArray* addend_array, - llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b, - const HloModuleConfig& hlo_module_config, - const TargetMachineFeatures& target_machine_features); - - // Emits the IR to perform the dot operation. - Status Emit(); - - // Emits instructions to perform a scalar dot product (a multiply of the - // LHS and RHS) and store the results in the target. - Status EmitScalarDot(); - - // Emit an LLVM IR implementation of the dot operation if we can. Returns - // true if an LLVM IR implementation was emitted. - bool EmitLlvmIrDotIfProfitable(); - - // Emits a call to the CPU runtime to perform the matrix multiply. - Status EmitCallToRuntime(); - - // Represents the dimensions of a matrix-matrix multiply operation. - struct MatMultDims { - // The number of rows in the LHS. 
- int64 m; - - // The number of columns in the LHS, which is also must be equal to the - // number of rows in the RHS. - int64 k; - - // The number of columns on the RHS. - int64 n; - - // True if the LHS matrix is column major. - bool lhs_column_major; - - // True if the LHS contraction dimension is not 1. - bool lhs_non_canonical; - - // True if the RHS matrix is column major. - bool rhs_column_major; - - // True if the RHS contraction dimension is not 0. - bool rhs_non_canonical; - - // True if the result matrix is column major. - bool target_column_major; - }; - - // Get the MatMultDims instance for the dot product this DotOpEmitter - // represents. Precondition: the dot is of rank 2 (and thus its operands are - // of rank 2 as well). - MatMultDims GetMatMultDims() const; - - bool EmitSmallGemmIfProfitable(const MatMultDims& mat_mult_dims); - - // When doing a tiled GEMV in LLVM IR, a "tile" consists of this many vector - // registers. - int64 GetGemvTilingFactor() const { - const int64 kDefaultTilingFactor = 8; - return options::LlvmIrGemvTilingFactor(hlo_module_config_) - .value_or(kDefaultTilingFactor); - } - - std::tuple GetGemmTileSize() const { - // Tuned for broadwell - Intel(R) Xeon(R) CPU E5-2690 v4 @ 2.60GHz - // - // TODO(b/80093688): Tune for other architectures and centralize this - // information in one place. - const std::tuple kDefaultTileSize = - std::tuple(11, 9, 1); - return options::LlvmIrGemmTileSize(hlo_module_config_) - .value_or(kDefaultTileSize); - } - - // Returns true if we should use an experimental implementation of GEMM - // (general matrix matrix multiplication) if possible. - bool EnableExperimentalLlvmIrGemm() const { - return options::EnableExperimentalLlvmIrGemm(hlo_module_config_); - } - - // Returns true if we should call into multi-threaded Eigen routines. 
- bool ShouldUseMultiThreadedEigen() { - return hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen(); - } - - const HloInstruction& dot_; - const llvm_ir::IrArray& target_array_; - const llvm_ir::IrArray& lhs_array_; - const llvm_ir::IrArray& rhs_array_; - const llvm_ir::IrArray* addend_array_; - llvm::Value* executable_run_options_value_; - llvm::IRBuilder<>* b_; - const HloModuleConfig& hlo_module_config_; - const TargetMachineFeatures& target_machine_features_; -}; - +// Emit LLVM IR to perform the dot operation on lhs_array and rhs_array and +// place the result in target_array. IR is emitted at current insert point of +// the builder. Upon completion of the method, the insert point is set to the +// end of all instructions emitted for this operation. +// +// If `addend_array` is not nullptr then it must be an array of the same +// dimensions as the result, and the result is computed as `addend_array` + +// dot(`lhs_array`, `rhs_array`). A non-null `addend_array` is only supported +// for Matrix-vector products. +Status EmitDotOperation(const HloInstruction& dot, + const llvm_ir::IrArray& target_array, + const llvm_ir::IrArray& lhs_array, + const llvm_ir::IrArray& rhs_array, + const llvm_ir::IrArray* addend_array, + llvm::Value* executable_run_options_value, + llvm::IRBuilder<>* b, + const HloModuleConfig& hlo_module_config, + const TargetMachineFeatures& target_machine_features); } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter_internal.h b/tensorflow/compiler/xla/service/cpu/dot_op_emitter_internal.h new file mode 100644 index 0000000000000000000000000000000000000000..cc28918ed60a8086135846e2b9b1b9d75ec31ef6 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter_internal.h @@ -0,0 +1,88 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_DOT_OP_EMITTER_INTERNAL_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_DOT_OP_EMITTER_INTERNAL_H_ + +#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" + +// ----------------------------------------------------------------------------- +// INTERNAL HEADER. +// +// This file exposes internal implementation details from dot_op_emitter.cc for +// unit tests. Please do not depend on this! +// +// ----------------------------------------------------------------------------- + +namespace xla { +namespace cpu { +namespace internal { + +// Represents a dot operation. We use this in lieu of an `HloInstruction` +// because we want to be able to create this for the "inner" dot operation in a +// batch dot, for which there is no separate HLO instruction. +struct DotInfo { + Shape lhs_shape; + Shape rhs_shape; + Shape result_shape; + DotDimensionNumbers dim_nums; + + explicit DotInfo(const HloInstruction& instr) { + CHECK_EQ(instr.opcode(), HloOpcode::kDot); + lhs_shape = instr.operand(0)->shape(); + rhs_shape = instr.operand(1)->shape(); + result_shape = instr.shape(); + dim_nums = instr.dot_dimension_numbers(); + } +}; + +// Dictates how a dot operation is implemented. 
+enum class DotImplementationStrategy { + // The dot operation is lowered into LLVM IR that implements a naive nested + // loop that computes the result one element at a time. This is our + // "fallback"; we don't really want this to kick in for any non-trivial dot + // operation. + kNaiveLlvmIr, + + // The dot operation is lowered into LLVM IR that implements a tiled + // Matrix*Vector operation. This strategy also allows fusing in a bias add + // into the dot. The matrix can be row major or column major, both are + // supported. + kTiledLlvmIrGemv, + + // The dot operation is lowered into LLVM IR that implements a tiled + // Matrix*Matrix operation. No fusions are supported. The two inputs + // and the output have to be row major. + kTiledLlvmIrGemm, + + // The dot operation is lowered into a call into an Eigen routine. No fusions + // are supported today. The two inputs and the output have to be row major. + // However, we do allow transposing either the LHS or the RHS as part of the + // GEMM -- we expose this flexibility as flexibility in the contraction + // dimensions, but we can also see this as flexibility in the input layouts. + kEigen, +}; + +// Returns the implementation strategy for a dot with the configuration +// `dot_info`.
+DotImplementationStrategy GetDotImplementationStrategy( + const HloModuleConfig& config, const DotInfo& dot_info, + const TargetMachineFeatures& target_machine_features); +} // namespace internal +} // namespace cpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_DOT_OP_EMITTER_INTERNAL_H_ diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc index c8312d80bd5012e5bcb42a410db18a7fa77a2eb6..0028fbaed895becad8da496aa8acdf7dc173a2a0 100644 --- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc @@ -51,10 +51,11 @@ StatusOr CpuElementalIrEmitter::EmitAtan2(PrimitiveType prim_type, return Unimplemented("atan2"); } // Create a function declaration. - llvm::Function* function = - llvm::cast(module_->getOrInsertFunction( - llvm_ir::AsStringRef(function_name), lhs->getType(), lhs->getType(), - rhs->getType())); + llvm::Function* function = llvm::dyn_cast( + module_ + ->getOrInsertFunction(llvm_ir::AsStringRef(function_name), + lhs->getType(), lhs->getType(), rhs->getType()) + .getCallee()); function->setCallingConv(llvm::CallingConv::C); function->setDoesNotThrow(); function->setDoesNotAccessMemory(); @@ -85,9 +86,11 @@ StatusOr CpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type, return Unimplemented("tanh"); } // Create a function declaration. 
- llvm::Function* function = llvm::cast( - module_->getOrInsertFunction(llvm_ir::AsStringRef(function_name), - value->getType(), value->getType())); + llvm::Function* function = llvm::dyn_cast( + module_ + ->getOrInsertFunction(llvm_ir::AsStringRef(function_name), + value->getType(), value->getType()) + .getCallee()); function->setCallingConv(llvm::CallingConv::C); function->setDoesNotThrow(); function->setDoesNotAccessMemory(); diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc index 1a8bedfe6afb4f096ddd4703c312b84d521a7ba5..a8b139aec9e96b6bb580baf74789df7c998cebf8 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc @@ -26,7 +26,7 @@ namespace cpu { int64 GetMinimumAlignmentForArray( const Shape& shape, const TargetMachineFeatures& target_machine_features) { - CHECK(ShapeUtil::IsArray(shape)); + CHECK(shape.IsArray()); CHECK(!LayoutUtil::HasLayout(shape) || LayoutUtil::IsDense(shape.layout())); // We don't require a layout to be set on `shape`. This only works on CPU diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 4032c2da2f33ee61da8771ae6225a14172cbe6e8..efdda8599a1a66a0b2e43d17cfb35e3514e905b0 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -24,11 +24,9 @@ limitations under the License. #include #include +// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" -#include "tensorflow/core/lib/math/math_util.h" -#include "tensorflow/core/platform/logging.h" -// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/types/span.h" @@ -70,6 +68,8 @@ limitations under the License. 
#include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/core/lib/core/bits.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/math/math_util.h" +#include "tensorflow/core/platform/logging.h" namespace xla { @@ -111,10 +111,9 @@ IrEmitter::IrEmitter( StatusOr IrEmitter::EmitComputation( HloComputation* computation, const string& function_name_prefix, bool is_top_level_computation, - const std::vector* instruction_order) { + absl::Span instruction_order) { string function_name = name_uniquer_.GetUniqueName(function_name_prefix); - VLOG(2) << "Emitting IR for CPU function [" << function_name_prefix - << "]; ordered? " << (instruction_order != nullptr); + VLOG(2) << "Emitting IR for CPU function [" << function_name_prefix << "]"; is_top_level_computation_ = is_top_level_computation; num_dynamic_loop_bounds_ = 0; if (!computation->root_instruction()->outer_dimension_partitions().empty()) { @@ -141,11 +140,7 @@ StatusOr IrEmitter::EmitComputation( bool use_rdtscp = arch_type_ == llvm::Triple::ArchType::x86 || arch_type_ == llvm::Triple::ArchType::x86_64; profiling_state_ = ProfilingState(use_rdtscp); - if (instruction_order == nullptr) { - TF_RETURN_IF_ERROR(computation->Accept(this)); - } else { - TF_RETURN_IF_ERROR(computation->AcceptOrdered(this, *instruction_order)); - } + TF_RETURN_IF_ERROR(computation->AcceptOrdered(this, instruction_order)); llvm::Function* ir_function = compute_function_->function(); InsertOrDie(&emitted_functions_, computation, ir_function); // Delete 'compute_function', finalizing 'ir_function' and restoring caller @@ -228,11 +223,11 @@ Status IrEmitter::HandleConstant(HloInstruction* constant) { } Status IrEmitter::HandleCopy(HloInstruction* copy) { - if (ShapeUtil::IsTuple(copy->shape())) { + if (copy->shape().IsTuple()) { // kCopy shallow copies a tuple so just memcpy the top-level buffer. 
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(copy)); return EmitMemcpy(*(copy->operand(0)), *copy); - } else if (ShapeUtil::IsArray(copy->shape())) { + } else if (copy->shape().IsArray()) { // Use the elemental emitter for array shapes. return DefaultAction(copy); } @@ -244,10 +239,12 @@ Status IrEmitter::HandleCopy(HloInstruction* copy) { int IrEmitter::MinimumAlignmentForPrimitiveType(PrimitiveType primitive_type) { int64 byte_size = ShapeUtil::ByteSizeOfPrimitiveType(primitive_type); DCHECK_GE(byte_size, 0); - // Largest scalar is a complex64 so we don't need to worry about the + // Largest scalar is a complex128 so we don't need to worry about the // int64->int truncation here. - DCHECK_LE(byte_size, 8); - return byte_size; + DCHECK_LE(byte_size, 16); + + // Allocations may be 8-byte aligned if part of a small block. + return std::min(8LL, byte_size); } int64 IrEmitter::ByteSizeOf(const Shape& shape) const { @@ -321,7 +318,7 @@ Status IrEmitter::HandleTupleSelect(HloInstruction* tuple_select) { auto on_false = tuple_select->operand(2); TF_RET_CHECK(pred->shape().element_type() == PRED); TF_RET_CHECK(ShapeUtil::IsScalar(pred->shape())); - TF_RET_CHECK(ShapeUtil::IsTuple(tuple_select->shape())); + TF_RET_CHECK(tuple_select->shape().IsTuple()); TF_RETURN_IF_ERROR(EmitTargetAddressForOp(tuple_select)); llvm_ir::EmitTupleSelect(GetIrArrayFor(tuple_select), GetIrArrayFor(pred), GetEmittedValueFor(on_true), @@ -351,7 +348,7 @@ Status IrEmitter::HandleInfeed(HloInstruction* instruction) { llvm_ir::EmitTuple(GetIrArrayFor(infeed), {data_address, token_address}, &b_, module_); - if (ShapeUtil::IsTuple(data_shape)) { + if (data_shape.IsTuple()) { TF_RET_CHECK(!ShapeUtil::IsNestedTuple(data_shape)); // For a tuple, we first copy each of the internal elements to @@ -415,11 +412,18 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape, llvm::Function* acquire_func; if (kind == XfeedKind::kInfeed) { - acquire_func = llvm::cast(module_->getOrInsertFunction( - 
runtime::kAcquireInfeedBufferForDequeueSymbolName, acquire_type)); + acquire_func = llvm::dyn_cast( + module_ + ->getOrInsertFunction( + runtime::kAcquireInfeedBufferForDequeueSymbolName, acquire_type) + .getCallee()); } else { - acquire_func = llvm::cast(module_->getOrInsertFunction( - runtime::kAcquireOutfeedBufferForPopulationSymbolName, acquire_type)); + acquire_func = llvm::dyn_cast( + module_ + ->getOrInsertFunction( + runtime::kAcquireOutfeedBufferForPopulationSymbolName, + acquire_type) + .getCallee()); } acquire_func->setCallingConv(llvm::CallingConv::C); @@ -432,11 +436,19 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape, llvm::Function* release_func; if (kind == XfeedKind::kInfeed) { - release_func = llvm::cast(module_->getOrInsertFunction( - runtime::kReleaseInfeedBufferAfterDequeueSymbolName, release_type)); + release_func = llvm::dyn_cast( + module_ + ->getOrInsertFunction( + runtime::kReleaseInfeedBufferAfterDequeueSymbolName, + release_type) + .getCallee()); } else { - release_func = llvm::cast(module_->getOrInsertFunction( - runtime::kReleaseOutfeedBufferAfterPopulationSymbolName, release_type)); + release_func = llvm::dyn_cast( + module_ + ->getOrInsertFunction( + runtime::kReleaseOutfeedBufferAfterPopulationSymbolName, + release_type) + .getCallee()); } release_func->setCallingConv(llvm::CallingConv::C); @@ -475,7 +487,7 @@ Status IrEmitter::HandleOutfeed(HloInstruction* outfeed) { const Shape& operand_shape = operand->shape(); llvm::Value* value = GetEmittedValueFor(operand); - if (!ShapeUtil::IsTuple(operand_shape)) { + if (!operand_shape.IsTuple()) { return EmitXfeedTransfer(XfeedKind::kOutfeed, operand_shape, value); } @@ -498,6 +510,26 @@ Status IrEmitter::HandleSort(HloInstruction* hlo) { const HloSortInstruction* sort = Cast(hlo); TF_RETURN_IF_ERROR(EmitTargetAddressForOp(sort)); Shape keys_shape = sort->keys()->shape(); + PrimitiveType keys_type = keys_shape.element_type(); + switch (keys_type) { + case PRED: + 
case S8: + case U8: + case S16: + case U16: + case F16: + case S32: + case U32: + case F32: + case S64: + case U64: + case F64: + break; + default: + return Unimplemented( + "Element type %s not supported in the Sort op on CPU.", + PrimitiveType_Name(keys_type)); + } std::vector destination_addresses(sort->operand_count()); for (int64 i = 0; i < sort->operand_count(); ++i) { ShapeIndex shape_index = @@ -540,110 +572,108 @@ Status IrEmitter::HandleSort(HloInstruction* hlo) { higher_dimensions *= normalized_keys_shape.dimensions(i); } int64 lower_dimensions = 1; - for (int64 i = ShapeUtil::Rank(normalized_keys_shape) - 1; + for (int64 i = normalized_keys_shape.rank() - 1; i > physical_dimension_to_sort; --i) { lower_dimensions *= normalized_keys_shape.dimensions(i); } - PrimitiveType keys_type = keys_shape.element_type(); - const char* fn_name = nullptr; - llvm::Type* keys_native_type = nullptr; - switch (keys_type) { - case PRED: - fn_name = runtime::kKeyValueSortPREDSymbolName; - keys_native_type = b_.getInt8PtrTy(); - break; - case S8: - fn_name = runtime::kKeyValueSortS8SymbolName; - keys_native_type = b_.getInt8PtrTy(); - break; - case U8: - fn_name = runtime::kKeyValueSortU8SymbolName; - keys_native_type = b_.getInt8PtrTy(); - break; - case S16: - fn_name = runtime::kKeyValueSortS16SymbolName; - keys_native_type = b_.getInt16Ty()->getPointerTo(); - break; - case U16: - fn_name = runtime::kKeyValueSortU16SymbolName; - keys_native_type = b_.getInt16Ty()->getPointerTo(); - break; - case F16: - fn_name = runtime::kKeyValueSortF16SymbolName; - keys_native_type = b_.getHalfTy()->getPointerTo(); - break; - case S32: - fn_name = runtime::kKeyValueSortS32SymbolName; - keys_native_type = b_.getInt32Ty()->getPointerTo(); - break; - case U32: - fn_name = runtime::kKeyValueSortU32SymbolName; - keys_native_type = b_.getInt32Ty()->getPointerTo(); - break; - case F32: - fn_name = runtime::kKeyValueSortF32SymbolName; - keys_native_type = b_.getFloatTy()->getPointerTo(); - 
break; - case S64: - fn_name = runtime::kKeyValueSortS64SymbolName; - keys_native_type = b_.getInt64Ty()->getPointerTo(); - break; - case U64: - fn_name = runtime::kKeyValueSortU64SymbolName; - keys_native_type = b_.getInt64Ty()->getPointerTo(); - break; - case F64: - fn_name = runtime::kKeyValueSortF64SymbolName; - keys_native_type = b_.getDoubleTy()->getPointerTo(); - break; - default: - return Unimplemented( - "Element type %s not supported in the Sort op on CPU.", - PrimitiveType_Name(keys_type)); + llvm::FunctionType* less_than_type = llvm::FunctionType::get( + b_.getInt1Ty(), {b_.getInt8PtrTy(), b_.getInt8PtrTy()}, + /*isVarArg=*/false); + auto less_than_function = llvm_ir::CreateFunction( + less_than_type, llvm::GlobalValue::InternalLinkage, + /*enable_fast_math=*/false, + /*optimize_for_size=*/true, absl::StrCat(IrName(sort), "_comparator"), + module_); + // Emit the code for the less_than function. + { + llvm::IRBuilder<>::InsertPointGuard guard(b_); + + auto* entry_bb = + llvm::BasicBlock::Create(b_.getContext(), "entry", less_than_function); + + b_.SetInsertPoint(entry_bb); + auto keys_ir_type = llvm_ir::PrimitiveTypeToIrType(keys_type, module_); + CHECK_EQ(less_than_function->arg_size(), 2); + llvm::Value* keys_lhs_ptr = less_than_function->arg_begin(); + keys_lhs_ptr = PointerCast(keys_lhs_ptr, keys_ir_type->getPointerTo()); + llvm::Value* keys_rhs_ptr = less_than_function->arg_begin() + 1; + keys_rhs_ptr = PointerCast(keys_rhs_ptr, keys_ir_type->getPointerTo()); + + // TODO(b/122298745): Replace the custom compare logic with a call to the + // computation specified for the Sort op. + llvm::Value* keys_lhs = Load(keys_ir_type, keys_lhs_ptr); + llvm::Value* keys_rhs = Load(keys_ir_type, keys_rhs_ptr); + bool is_signed_comparison = true; + if (primitive_util::IsFloatingPointType(keys_type)) { + // We would like a total order of floating point numbers so that the + // sort has a predictable behavior in the presence of NaNs. 
Rather + // than using floating point comparison, we use the following trick: + // If f is a float, and + // x = bit_cast(f); + // y = x < 0 ? 0x7FFFFFFF - x : x; + // then y is ordered as an int32 such that finite values have the + // obvious order, -0 is ordered before 0, and -NaN and NaN appear at + // the beginning and end of the ordering. + auto k = b_.getInt(llvm::APInt::getSignedMaxValue( + keys_lhs->getType()->getPrimitiveSizeInBits())); + auto comparison_type = k->getType(); + auto zero = llvm::ConstantInt::get(comparison_type, 0); + auto maybe_flip = [&](llvm::Value* v) { + return b_.CreateSelect(b_.CreateICmp(llvm::ICmpInst::ICMP_SLT, v, zero), + b_.CreateSub(k, v), v); + }; + keys_lhs = b_.CreateBitCast(keys_lhs, comparison_type); + keys_rhs = b_.CreateBitCast(keys_rhs, comparison_type); + keys_lhs = maybe_flip(keys_lhs); + keys_rhs = maybe_flip(keys_rhs); + } else if (!primitive_util::IsSignedIntegralType(keys_type)) { + is_signed_comparison = false; + } + llvm::Value* result = + b_.CreateICmp(is_signed_comparison ? 
llvm::ICmpInst::ICMP_SLT + : llvm::ICmpInst::ICMP_ULT, + keys_lhs, keys_rhs); + llvm::ReturnInst::Create(b_.getContext(), + /*retVal=*/result, entry_bb); } llvm::FunctionType* key_value_sort_type = llvm::FunctionType::get( b_.getVoidTy(), - {keys_native_type, b_.getInt64Ty(), b_.getInt64Ty(), b_.getInt64Ty(), + {b_.getInt64Ty(), b_.getInt64Ty(), b_.getInt64Ty(), b_.getInt8PtrTy()->getPointerTo(), b_.getInt32Ty(), - b_.getInt32Ty()->getPointerTo()}, + b_.getInt32Ty()->getPointerTo(), less_than_function->getType()}, /*isVarArg=*/false); - auto* key_value_sort_func = llvm::cast( - module_->getOrInsertFunction(fn_name, key_value_sort_type)); + auto* key_value_sort_func = llvm::dyn_cast( + module_ + ->getOrInsertFunction(runtime::kKeyValueSortSymbolName, + key_value_sort_type) + .getCallee()); key_value_sort_func->setCallingConv(llvm::CallingConv::C); key_value_sort_func->setDoesNotThrow(); - llvm::Value* values; - llvm::Value* sizes; - if (sort->values_count() == 0) { - values = llvm::Constant::getNullValue(b_.getInt8PtrTy()->getPointerTo()); - sizes = llvm::Constant::getNullValue(b_.getInt32Ty()->getPointerTo()); - } else { - values = llvm_ir::EmitAllocaAtFunctionEntryWithCount( - b_.getInt8PtrTy(), b_.getInt32(sort->values_count()), - "cc_values_alloca", &b_); - sizes = llvm_ir::EmitAllocaAtFunctionEntryWithCount( - b_.getInt32Ty(), b_.getInt32(sort->values_count()), "cc_sizes_alloca", - &b_); - for (int64 i = 0; i < sort->values_count(); ++i) { - llvm::Value* value_as_i8ptr = - PointerCast(destination_addresses[i + 1], b_.getInt8PtrTy()); - llvm::Value* slot_in_values_alloca = - ConstInBoundsGEP1_32(b_.getInt8PtrTy(), values, i); - Store(value_as_i8ptr, slot_in_values_alloca); - llvm::Value* slot_in_sizes_alloca = - ConstInBoundsGEP1_32(b_.getInt32Ty(), sizes, i); - llvm::Value* size = b_.getInt32(ShapeUtil::ByteSizeOfPrimitiveType( - sort->operand(i + 1)->shape().element_type())); - Store(size, slot_in_sizes_alloca); - } + llvm::Value* values = 
llvm_ir::EmitAllocaAtFunctionEntryWithCount( + b_.getInt8PtrTy(), b_.getInt32(sort->operand_count()), "cc_values_alloca", + &b_); + llvm::Value* sizes = llvm_ir::EmitAllocaAtFunctionEntryWithCount( + b_.getInt32Ty(), b_.getInt32(sort->operand_count()), "cc_sizes_alloca", + &b_); + for (int64 i = 0; i < sort->operand_count(); ++i) { + llvm::Value* value_as_i8ptr = + PointerCast(destination_addresses[i], b_.getInt8PtrTy()); + llvm::Value* slot_in_values_alloca = + ConstInBoundsGEP1_32(b_.getInt8PtrTy(), values, i); + Store(value_as_i8ptr, slot_in_values_alloca); + llvm::Value* slot_in_sizes_alloca = + ConstInBoundsGEP1_32(b_.getInt32Ty(), sizes, i); + llvm::Value* size = b_.getInt32(ShapeUtil::ByteSizeOfPrimitiveType( + sort->operand(i)->shape().element_type())); + Store(size, slot_in_sizes_alloca); } Call(key_value_sort_func, - {PointerCast(destination_addresses[0], keys_native_type), - b_.getInt64(higher_dimensions), b_.getInt64(sort_dimension_elements), + {b_.getInt64(higher_dimensions), b_.getInt64(sort_dimension_elements), b_.getInt64(lower_dimensions), values, - b_.getInt32(sort->values_count()), sizes}); + b_.getInt32(sort->operand_count()), sizes, less_than_function}); if (sort->values_count() > 0) { llvm_ir::EmitTuple(GetIrArrayFor(sort), destination_addresses, &b_, @@ -752,11 +782,6 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForReduceWindow( } Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) { - TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( - /*instruction=*/*reduce_window, - /*operands=*/{reduce_window->operand(0)}, - /*supported_types=*/{F32, BF16, S32, F16})); - // Pseudo code for reduce window: // // for (coordinates O in the output) @@ -784,8 +809,8 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { const auto init_value = select_and_scatter->operand(2); const Window& window = select_and_scatter->window(); PrimitiveType operand_element_type = operand->shape().element_type(); - const int64 rank = 
ShapeUtil::Rank(operand->shape()); - CHECK_EQ(rank, ShapeUtil::Rank(source->shape())); + const int64 rank = operand->shape().rank(); + CHECK_EQ(rank, source->shape().rank()); CHECK_EQ(rank, window.dimensions_size()); // TODO(b/31410564): Implement dilation for select-and-scatter. @@ -947,12 +972,8 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { auto rhs = dot->operand(1); TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( /*instruction=*/*dot, /*operands=*/{lhs, rhs}, - /*supported_types=*/{F16, F32, F64, C64})); + /*supported_types=*/{F16, F32, F64, C64, C128})); const DotDimensionNumbers& dnums = dot->dot_dimension_numbers(); - if (dnums.lhs_batch_dimensions_size() > 0 || - dnums.rhs_batch_dimensions_size() > 0) { - return Unimplemented("Dot with batch dimensions not implemented."); - } if (dnums.lhs_contracting_dimensions_size() != 1) { // This is disallowed by ShapeInference today. @@ -975,10 +996,10 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { << llvm_ir::DumpToString(*target_array.GetBasePointer()); // Dot operation is complicated so we delegate to a helper class. - return DotOpEmitter::EmitDotOperation( - *dot, target_array, lhs_array, rhs_array, /*addend_array=*/nullptr, - GetExecutableRunOptionsArgument(), &b_, hlo_module_config_, - target_machine_features_); + return EmitDotOperation(*dot, target_array, lhs_array, rhs_array, + /*addend_array=*/nullptr, + GetExecutableRunOptionsArgument(), &b_, + hlo_module_config_, target_machine_features_); } StatusOr IrEmitter::EmitTargetElementLoopBodyForConvolution( @@ -1123,7 +1144,7 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { auto rhs = convolution->operand(1); TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( /*instruction=*/*convolution, /*operands=*/{lhs, rhs}, - /*supported_types=*/{F16, F32, C64})); + /*supported_types=*/{F16, F32, C64, C128})); // TODO(tonywy): Add PotentiallyImplementedAsMKLCovolution to support // different data layouts. 
@@ -1236,8 +1257,8 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { LOG(WARNING) << "Using Eigen instead of MKL-DNN for single-threaded " "conv2d function."; } - llvm::Function* conv_func = llvm::cast( - module_->getOrInsertFunction(fn_name, conv_type)); + llvm::Function* conv_func = llvm::dyn_cast( + module_->getOrInsertFunction(fn_name, conv_type).getCallee()); conv_func->setCallingConv(llvm::CallingConv::C); conv_func->setDoesNotThrow(); conv_func->setOnlyAccessesArgMemory(); @@ -1320,8 +1341,8 @@ Status IrEmitter::HandleFft(HloInstruction* fft) { ? runtime::kEigenFftSymbolName : runtime::kEigenSingleThreadedFftSymbolName; - llvm::Function* fft_func = llvm::cast( - module_->getOrInsertFunction(fn_name, fft_type)); + llvm::Function* fft_func = llvm::dyn_cast( + module_->getOrInsertFunction(fn_name, fft_type).getCallee()); fft_func->setCallingConv(llvm::CallingConv::C); fft_func->setDoesNotThrow(); fft_func->setOnlyAccessesInaccessibleMemOrArgMem(); @@ -1338,11 +1359,11 @@ Status IrEmitter::HandleFft(HloInstruction* fft) { return Status::OK(); } -Status IrEmitter::HandleCrossReplicaSum(HloInstruction* crs) { +Status IrEmitter::HandleAllReduce(HloInstruction* crs) { if (hlo_module_config_.replica_count() != 1) { // TODO(b/33011107): Support nontrivial cross replica sum on CPU. 
return Unimplemented( - "CrossReplicaSum with >1 replica is not implemented on CPU."); + "AllReduce with >1 replica is not implemented on CPU."); } // When there is a single replica, a cross replica sum is the identity @@ -1367,8 +1388,8 @@ Status IrEmitter::HandleCrossReplicaSum(HloInstruction* crs) { assignment_.GetUniqueSlice(crs, {i})); const Shape& operand_shape = crs->operand(i)->shape(); - CHECK(ShapeUtil::IsArray(operand_shape)) - << "Operands to cross-replica-sum must be arrays: " << crs->ToString(); + CHECK(operand_shape.IsArray()) + << "Operands to all-reduce must be arrays: " << crs->ToString(); operand_ptrs.push_back(EmitBufferPointer(out_slice, operand_shape)); // TODO(b/63762267): Be more aggressive about specifying alignment. @@ -1404,7 +1425,7 @@ static bool ReductionPreservesLayout(const HloInstruction& reduce) { int64 delta = 0; for (int64 i = 0; i < operand_shape.dimensions_size(); i++) { - if (reduced_dims.count(i)) { + if (reduced_dims.contains(i)) { delta++; } else { InsertOrDie(&unreduced_dim_map, i, i - delta); @@ -1417,7 +1438,7 @@ static bool ReductionPreservesLayout(const HloInstruction& reduce) { for (int64 operand_dim_idx = 0; operand_dim_idx < operand_shape.dimensions_size(); operand_dim_idx++) { int64 operand_dim = operand_shape.layout().minor_to_major(operand_dim_idx); - if (!reduced_dims.count(operand_dim)) { + if (!reduced_dims.contains(operand_dim)) { if (FindOrDie(unreduced_dim_map, operand_dim) != result_shape.layout().minor_to_major(result_dim_idx++)) { return false; @@ -1714,10 +1735,8 @@ StatusOr IrEmitter::EmitVectorizedReduce( vectorization_factor_in_bytes / ShapeUtil::ByteSizeOfPrimitiveType(reduce->shape().element_type()); - bool is_reduction_over_minor_dimension = - std::find(dimensions.begin(), dimensions.end(), - LayoutUtil::Minor(arg->shape().layout(), 0)) != - dimensions.end(); + bool is_reduction_over_minor_dimension = absl::c_linear_search( + dimensions, LayoutUtil::Minor(arg->shape().layout(), 0)); unsigned 
element_alignment = tensorflow::MathUtil::GCD( ShapeUtil::ByteSizeOfPrimitiveType(reduce->shape().element_type()), @@ -1729,7 +1748,7 @@ StatusOr IrEmitter::EmitVectorizedReduce( return false; } - CHECK(!ShapeUtil::IsTuple(reduce->shape())); + CHECK(!reduce->shape().IsTuple()); TF_RETURN_IF_ERROR(EmitTargetAddressForOp(reduce)); // We know we're not reducing over the most minor dimension, which means we @@ -1896,7 +1915,7 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForReduce( Status IrEmitter::HandleReduce(HloInstruction* reduce) { // TODO(b/112040122): Support variadic reduce. - if (!ShapeUtil::IsArray(reduce->shape())) { + if (!reduce->shape().IsArray()) { return Unimplemented("Variadic reduce is not supported on CPU"); } auto arg = reduce->mutable_operand(0); @@ -1995,7 +2014,7 @@ Status IrEmitter::HandleSlice(HloInstruction* slice) { // The memcpy will copy elements that are logically this shape (allowed to be // scalar). const Shape logical_element_shape = ShapeUtil::FilterDimensions( - [&inner_dims](int64 dim) -> bool { return inner_dims.count(dim); }, + [&inner_dims](int64 dim) { return inner_dims.contains(dim); }, operand->shape()); const int64 primitive_elements_per_logical_element = @@ -2210,10 +2229,10 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) { llvm_ir::IrArray addend_array( GetIrArrayFor(fusion->operand(addend_param_number))); - TF_RETURN_IF_ERROR(DotOpEmitter::EmitDotOperation( - *dot, target_array, lhs_array, rhs_array, &addend_array, - GetExecutableRunOptionsArgument(), &b_, hlo_module_config_, - target_machine_features_)); + TF_RETURN_IF_ERROR( + EmitDotOperation(*dot, target_array, lhs_array, rhs_array, + &addend_array, GetExecutableRunOptionsArgument(), &b_, + hlo_module_config_, target_machine_features_)); return Status::OK(); } else { return Unimplemented("Fusion kind not implemented on CPU"); @@ -2262,15 +2281,32 @@ Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) { InBoundsGEP(operands_alloca, 
{b_.getInt64(i)}); Store(operand_as_i8ptr, slot_in_operands_alloca); } - auto* custom_call_ir_function = - llvm::cast(module_->getOrInsertFunction( - AsStringRef(custom_call_target), - llvm::FunctionType::get( - /*Result=*/b_.getVoidTy(), - /*Params=*/{i8_ptr_type, operands_alloca->getType()}, - /*isVarArg=*/false))); + auto* custom_call_ir_function = llvm::dyn_cast( + module_ + ->getOrInsertFunction( + AsStringRef(custom_call_target), + llvm::FunctionType::get( + /*Result=*/b_.getVoidTy(), + /*Params=*/{i8_ptr_type, operands_alloca->getType()}, + /*isVarArg=*/false)) + .getCallee()); TF_RETURN_IF_ERROR(EmitTargetAddressForOp(custom_call)); + // Write the tuple table if the output is a tuple. + if (custom_call->shape().IsTuple()) { + std::vector base_ptrs; + for (int i = 0; i < ShapeUtil::TupleElementCount(custom_call->shape()); + ++i) { + const Shape& elem_shape = + ShapeUtil::GetTupleElementShape(custom_call->shape(), i); + TF_RET_CHECK(!elem_shape.IsTuple()) << "Nested tuples not implemented"; + TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice, + assignment_.GetUniqueSlice(custom_call, {i})); + llvm::Value* addr = EmitBufferPointer(slice, elem_shape); + base_ptrs.push_back(addr); + } + llvm_ir::EmitTuple(GetIrArrayFor(custom_call), base_ptrs, &b_, module_); + } auto* output_address_arg = PointerCast(GetEmittedValueFor(custom_call), i8_ptr_type); @@ -2391,8 +2427,7 @@ StatusOr IrEmitter::EmitFastConcatenate( int64 concat_dim = concatenate->dimensions(0); const Layout& output_layout = output_shape.layout(); auto output_min2maj = LayoutUtil::MinorToMajor(output_layout); - auto concat_dim_layout_itr = - std::find(output_min2maj.begin(), output_min2maj.end(), concat_dim); + auto concat_dim_layout_itr = absl::c_find(output_min2maj, concat_dim); std::vector inner_dims(output_min2maj.begin(), concat_dim_layout_itr); std::vector outer_dims(std::next(concat_dim_layout_itr), @@ -2792,7 +2827,7 @@ llvm::Value* IrEmitter::EmitThreadLocalBufferPointer( 
llvm_ir::EmitBufferIndexingGEP(params, param_number, &b_); llvm::LoadInst* param_address_untyped = Load(param_address_offset); - if (!ShapeUtil::IsOpaque(target_shape)) { + if (!target_shape.IsOpaque()) { AttachAlignmentMetadataForLoad(param_address_untyped, target_shape); AttachDereferenceableMetadataForLoad(param_address_untyped, target_shape); @@ -2851,7 +2886,9 @@ llvm::Value* IrEmitter::EmitBufferPointer(const BufferAllocation::Slice& slice, if (slice.allocation()->is_thread_local()) { return EmitThreadLocalBufferPointer(slice, target_shape); } else if (slice.allocation()->is_constant()) { - return FindOrDie(constant_buffer_to_global_, slice.allocation()->index()); + return BitCast( + FindOrDie(constant_buffer_to_global_, slice.allocation()->index()), + IrShapeType(target_shape)->getPointerTo()); } else { return EmitGlobalBufferPointer(slice, target_shape); } @@ -2944,8 +2981,7 @@ Status IrEmitter::ElementTypesSameAndSupported( TF_RET_CHECK(!operands.empty()); PrimitiveType primitive_type = operands[0]->shape().element_type(); - if (std::find(supported_types.begin(), supported_types.end(), - primitive_type) == supported_types.end()) { + if (!absl::c_linear_search(supported_types, primitive_type)) { return Unimplemented("unsupported operand type %s in op %s", PrimitiveType_Name(primitive_type), HloOpcodeString(instruction.opcode())); diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index 559a8162a2d53f28ea6817653503c216af90a610..974dd7cd3f2254bfbc86fffae02c06c481af8902 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -101,7 +101,7 @@ class IrEmitter : public DfsHloVisitorWithDefault, StatusOr EmitComputation( HloComputation* computation, const string& function_name_prefix, bool is_top_level_computation, - const std::vector* instruction_order); + absl::Span instruction_order); llvm::IRBuilder<>* b() { return &b_; } @@ -134,7 
+134,7 @@ class IrEmitter : public DfsHloVisitorWithDefault, Status HandleDot(HloInstruction* dot) override; Status HandleConvolution(HloInstruction* convolution) override; Status HandleFft(HloInstruction* fft) override; - Status HandleCrossReplicaSum(HloInstruction* crs) override; + Status HandleAllReduce(HloInstruction* crs) override; Status HandleInfeed(HloInstruction* infeed) override; Status HandleOutfeed(HloInstruction* outfeed) override; Status HandleSort(HloInstruction* sort) override; @@ -250,14 +250,6 @@ class IrEmitter : public DfsHloVisitorWithDefault, llvm::Value* EmitBufferPointer(const BufferAllocation::Slice& slice, const Shape& target_shape); - // Emits a function into the current module. This can be used for - // computations embedded inside other computations, such as the - // function that a map operation applies. - StatusOr EmitFunction( - HloComputation* function, // The function to emit. - absl::string_view - function_name_suffix); // Used for LLVM IR register names. - // Emits a call to a thread local function (e.g. to the computation nested // within a reduce or a map). Thread local callees (by definition) only write // to and read from thread local allocations. @@ -448,7 +440,7 @@ class IrEmitter : public DfsHloVisitorWithDefault, computation_to_profile_idx_; // Maps HLOs to Values emitted for them. 
- std::unordered_map emitted_value_; + absl::flat_hash_map emitted_value_; llvm_ir::AliasAnalysis alias_analysis_; diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.cc b/tensorflow/compiler/xla/service/cpu/ir_function.cc index adfb8392bf6fa356f0a5cdab3ff74036eca8918e..84a5b058cfb11c899eb6ae03478ed550b84dc819 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_function.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_function.cc @@ -266,9 +266,11 @@ Status EmitCallToParallelForkJoin( /*Params=*/compute_function_params, /*isVarArg=*/false); - llvm::Function* fork_join_func = - llvm::cast(module->getOrInsertFunction( - runtime::kParallelForkJoinSymbolName, fork_join_type)); + llvm::Function* fork_join_func = llvm::dyn_cast( + module + ->getOrInsertFunction(runtime::kParallelForkJoinSymbolName, + fork_join_type) + .getCallee()); fork_join_func->setCallingConv(llvm::CallingConv::C); fork_join_func->setDoesNotThrow(); diff --git a/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc b/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc index f9722ffadac801521ddcbb568dd4435fd02e951b..93ef51754d21ad3ff4e24298c89649ef4c2742fb 100644 --- a/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc @@ -36,57 +36,88 @@ const char* const kLogV4F32SymbolName = "__xla_cpu_runtime_LogV4F32AVX"; const char* const kLogV8F32SymbolName = "__xla_cpu_runtime_LogV8F32AVX"; namespace { -llvm::Function* EmitVectorF32TanhIfNeeded(llvm::Module* module, - llvm::StringRef function_name, - int vector_width, - bool enable_fast_math) { - llvm::Function* vector_tanh_function = module->getFunction(function_name); - if (vector_tanh_function == nullptr) { + +// Replaces calls to the function `fn_name` with the code generated by +// fn_body_generator. 
+// +// We assume that fn_name accepts either a scalar f32 or a vector of +// vector_width f32s, and that fn_body_generator generates a function body with +// the same inputs/outputs as fn_name. +void RewriteCalls( + llvm::Module* module, const char* fn_name, + std::function* b, llvm::Value* input, + int32 vector_width)> + fn_body_generator, + int32 vector_width, bool enable_fast_math) { + llvm::Function* fn = module->getFunction(fn_name); + if (fn == nullptr) { // If the function declaration is not present in the module, there can't be // any calls to resolve. Don't emit the function in this case. - return nullptr; + return; } - llvm::LLVMContext* context = &module->getContext(); + // Our task is to generate a function body for `fn`, but we can't generate a + // function body for an LLVM intrinsic. So if fn is an intrinsic, replace it + // with a new function. + if (fn->isIntrinsic()) { + llvm::Function* new_fn = llvm::Function::Create( + fn->getFunctionType(), llvm::GlobalValue::InternalLinkage, + llvm::Twine("xla_impl.") + fn_name, module); + fn->replaceAllUsesWith(new_fn); + fn->eraseFromParent(); + fn = new_fn; + } - llvm::BasicBlock* vector_tanh_body = - llvm::BasicBlock::Create(*context, "body", vector_tanh_function); + llvm::LLVMContext* context = &module->getContext(); - llvm::IRBuilder<> b(vector_tanh_body); + llvm::BasicBlock* fn_body = llvm::BasicBlock::Create(*context, "body", fn); + llvm::IRBuilder<> b(fn_body); llvm::FastMathFlags fast_math_flags; fast_math_flags.setFast(enable_fast_math); b.setFastMathFlags(fast_math_flags); - llvm::Value* input = &*vector_tanh_function->arg_begin(); - CHECK_EQ(vector_width, input->getType()->getVectorNumElements()); - b.CreateRet(llvm_ir::EmitFastTanh(&b, input)); - - DCHECK(!llvm::verifyFunction(*vector_tanh_function)); - return vector_tanh_function; -} + llvm::Value* input = &*fn->arg_begin(); -llvm::Function* EmitVectorF32ExpIfNeeded(llvm::Module* module, - llvm::StringRef function_name, - int vector_width, - 
bool enable_fast_math) { - llvm::Function* vector_exp_function = module->getFunction(function_name); - if (vector_exp_function == nullptr) { - // If the function declaration is not present in the module, there can't be - // any calls to resolve. Don't emit the function in this case. - return nullptr; + // Upcast to vector type if input is a scalar. + if (vector_width == 1) { + llvm::Type* v1_type = llvm::VectorType::get(input->getType(), 1); + input = b.CreateInsertElement(llvm::UndefValue::get(v1_type), input, + uint64_t{0}); } - llvm::LLVMContext* context = &module->getContext(); + // Generate the vectorized code. + CHECK_EQ(vector_width, input->getType()->getVectorNumElements()); + llvm::Value* result = fn_body_generator(&b, input, vector_width); + + // Downcast result to scalar type if necessary. + if (vector_width == 1) { + result = b.CreateExtractElement(result, uint64_t{0}); + } + b.CreateRet(result); + DCHECK(!llvm::verifyFunction(*fn)); - llvm::BasicBlock* vector_exp_body = - llvm::BasicBlock::Create(*context, "body", vector_exp_function); + // Force-inline `fn` into all of its callers and then delete `fn`. + // + // TODO(b/73081976): Should we avoid inlining these in some cases? 
+ std::vector calls_to_inline; + for (auto* user : fn->users()) { + calls_to_inline.push_back(llvm::cast(user)); + } + for (auto* call_to_inline : calls_to_inline) { + llvm::InlineFunctionInfo inline_function_info; + CHECK(llvm::InlineFunction(call_to_inline, inline_function_info)); + } + fn->eraseFromParent(); +} - llvm::IRBuilder<> b(vector_exp_body); - llvm::FastMathFlags fast_math_flags; - fast_math_flags.setFast(); - b.setFastMathFlags(fast_math_flags); +llvm::Value* GenerateVF32Tanh(llvm::IRBuilder<>* b, llvm::Value* input, + int32 /*vector_width*/) { + return llvm_ir::EmitFastTanh(b, input); +} - VectorSupportLibrary vsl(F32, vector_width, &b, "exp_f32"); +llvm::Value* GenerateVF32Exp(llvm::IRBuilder<>* b, llvm::Value* input, + int32 vector_width) { + VectorSupportLibrary vsl(F32, vector_width, b, "exp_f32"); // This implements the same polynomial approximation as implemented in Eigen3. @@ -107,7 +138,6 @@ llvm::Function* EmitVectorF32ExpIfNeeded(llvm::Module* module, const llvm::APFloat cephes_exp_p4 = GetIeeeF32(1.6666665459E-1); const llvm::APFloat cephes_exp_p5 = GetIeeeF32(5.0000001201E-1); - llvm::Value* input = &*vector_exp_function->arg_begin(); llvm::Value* input_clamped = vsl.Clamp(input, /*low=*/exp_lo, /*high=*/exp_hi); llvm::Value* fx = vsl.Floor(vsl.MulAdd(input_clamped, cephes_LOG2EF, half)); @@ -128,49 +158,24 @@ llvm::Function* EmitVectorF32ExpIfNeeded(llvm::Module* module, // VectorSupportLibrary (intentionally) can't juggle more than one type at a // time so drop down to IRBuilder for this bit. 
llvm::Value* vector_constant_0x7f = - b.CreateVectorSplat(vector_width, b.getInt32(0x7f)); + b->CreateVectorSplat(vector_width, b->getInt32(0x7f)); llvm::Value* vector_constant_23 = - b.CreateVectorSplat(vector_width, b.getInt32(23)); + b->CreateVectorSplat(vector_width, b->getInt32(23)); llvm::Type* i32_vector_type = - llvm::VectorType::get(b.getInt32Ty(), vector_width); + llvm::VectorType::get(b->getInt32Ty(), vector_width); // fx is clamped so we don't have to worry about it being out of range for // i32. - llvm::Value* emm0 = b.CreateFPToSI(fx, i32_vector_type); - emm0 = b.CreateAdd(emm0, vector_constant_0x7f); - emm0 = b.CreateShl(emm0, vector_constant_23); - llvm::Value* emm0_f32 = b.CreateBitCast(emm0, vsl.vector_type()); - - llvm::Value* result = vsl.Max(vsl.Mul(y, emm0_f32), input); + llvm::Value* emm0 = b->CreateFPToSI(fx, i32_vector_type); + emm0 = b->CreateAdd(emm0, vector_constant_0x7f); + emm0 = b->CreateShl(emm0, vector_constant_23); + llvm::Value* emm0_f32 = b->CreateBitCast(emm0, vsl.vector_type()); - b.CreateRet(result); - - DCHECK(!llvm::verifyFunction(*vector_exp_function)); - return vector_exp_function; + return vsl.Max(vsl.Mul(y, emm0_f32), input); } -llvm::Function* EmitVectorF32LogIfNeeded(llvm::Module* module, - llvm::StringRef function_name, - int vector_width, - bool enable_fast_math) { - llvm::Function* vector_log_function = module->getFunction(function_name); - if (vector_log_function == nullptr) { - // If the function declaration is not present in the module, there can't be - // any calls to resolve. Don't emit the function in this case. 
- return nullptr; - } - - llvm::LLVMContext* context = &module->getContext(); - - llvm::BasicBlock* vector_log_body = - llvm::BasicBlock::Create(*context, "body", vector_log_function); - - llvm::IRBuilder<> b(vector_log_body); - llvm::FastMathFlags fast_math_flags; - fast_math_flags.setFast(); - b.setFastMathFlags(fast_math_flags); - - llvm::Value* input = &*vector_log_function->arg_begin(); - VectorSupportLibrary vsl(F32, vector_width, &b, "log_f32"); +llvm::Value* GenerateVF32Log(llvm::IRBuilder<>* b, llvm::Value* input, + int32 vector_width) { + VectorSupportLibrary vsl(F32, vector_width, b, "log_f32"); const llvm::APFloat half = GetIeeeF32(0.5); const llvm::APFloat one = GetIeeeF32(1.0); @@ -193,129 +198,107 @@ llvm::Function* EmitVectorF32LogIfNeeded(llvm::Module* module, // The smallest non denormalized float number. const llvm::APFloat min_norm_pos = GetIeeeF32FromBitwiseRep(0x00800000); const llvm::APFloat minus_inf = GetIeeeF32FromBitwiseRep(0xff800000); + const llvm::APFloat pos_inf = GetIeeeF32FromBitwiseRep(0x7f800000); const llvm::APFloat inv_mant_mask = GetIeeeF32FromBitwiseRep(~0x7f800000); // invalid_mask is set if x is negative or NaN (and therefore output // must be NaN). llvm::Value* invalid_mask = vsl.FCmpULEMask(input, vsl.GetZeroVector()); - llvm::Value* iszero_mask = vsl.FCmpEQMask(input, vsl.GetZeroVector()); + llvm::Value* is_zero_mask = vsl.FCmpEQMask(input, vsl.GetZeroVector()); + llvm::Value* is_pos_inf_mask = vsl.FCmpEQMask(input, pos_inf); // Cut off denormalized stuff. - input = vsl.Max(min_norm_pos, input); + llvm::Value* tmp0 = vsl.Max(min_norm_pos, input); // VectorSupportLibrary (intentionally) can't juggle more than one type at a // time so drop down to IRBuilder for this bit. 
llvm::Value* vector_constant_0x7f = - b.CreateVectorSplat(vector_width, b.getInt32(0x7f)); + b->CreateVectorSplat(vector_width, b->getInt32(0x7f)); llvm::Value* vector_constant_23 = - b.CreateVectorSplat(vector_width, b.getInt32(23)); + b->CreateVectorSplat(vector_width, b->getInt32(23)); llvm::Type* i32_vector_type = - llvm::VectorType::get(b.getInt32Ty(), vector_width); + llvm::VectorType::get(b->getInt32Ty(), vector_width); - llvm::Value* emm0 = - b.CreateLShr(b.CreateBitCast(input, i32_vector_type), vector_constant_23); + llvm::Value* emm0 = b->CreateLShr(b->CreateBitCast(tmp0, i32_vector_type), + vector_constant_23); // Keep only the fractional part. - input = vsl.FloatAnd(input, inv_mant_mask); - input = vsl.FloatOr(input, half); + tmp0 = vsl.FloatAnd(tmp0, inv_mant_mask); + tmp0 = vsl.FloatOr(tmp0, half); - emm0 = b.CreateSub(emm0, vector_constant_0x7f); - llvm::Value* e = vsl.Add(one, b.CreateSIToFP(emm0, vsl.vector_type())); + emm0 = b->CreateSub(emm0, vector_constant_0x7f); + llvm::Value* e = vsl.Add(one, b->CreateSIToFP(emm0, vsl.vector_type())); // part2: // if( x < SQRTHF ) { // e -= 1; // x = x + x - 1.0; // } else { x = x - 1.0; } - llvm::Value* mask = vsl.FCmpOLTMask(input, cephes_SQRTHF); - llvm::Value* tmp = vsl.FloatAnd(input, mask); - input = vsl.Sub(input, one); + llvm::Value* mask = vsl.FCmpOLTMask(tmp0, cephes_SQRTHF); + llvm::Value* tmp1 = vsl.FloatAnd(tmp0, mask); + tmp0 = vsl.Sub(tmp0, one); e = vsl.Sub(e, vsl.FloatAnd(mask, one)); - input = vsl.Add(input, tmp); + tmp0 = vsl.Add(tmp0, tmp1); - llvm::Value* x2 = vsl.Mul(input, input); - llvm::Value* x3 = vsl.Mul(x2, input); + llvm::Value* x2 = vsl.Mul(tmp0, tmp0); + llvm::Value* x3 = vsl.Mul(x2, tmp0); llvm::Value *y, *y1, *y2; - y = vsl.MulAdd(input, cephes_log_p0, cephes_log_p1); - y1 = vsl.MulAdd(input, cephes_log_p3, cephes_log_p4); - y2 = vsl.MulAdd(input, cephes_log_p6, cephes_log_p7); - y = vsl.MulAdd(y, input, cephes_log_p2); - y1 = vsl.MulAdd(y1, input, cephes_log_p5); - y2 = 
vsl.MulAdd(y2, input, cephes_log_p8); + y = vsl.MulAdd(tmp0, cephes_log_p0, cephes_log_p1); + y1 = vsl.MulAdd(tmp0, cephes_log_p3, cephes_log_p4); + y2 = vsl.MulAdd(tmp0, cephes_log_p6, cephes_log_p7); + y = vsl.MulAdd(y, tmp0, cephes_log_p2); + y1 = vsl.MulAdd(y1, tmp0, cephes_log_p5); + y2 = vsl.MulAdd(y2, tmp0, cephes_log_p8); y = vsl.MulAdd(y, x3, y1); y = vsl.MulAdd(y, x3, y2); y = vsl.Mul(y, x3); y1 = vsl.Mul(cephes_log_q1, e); - tmp = vsl.Mul(half, x2); + llvm::Value* tmp2 = vsl.Mul(half, x2); y = vsl.Add(y, y1); - input = vsl.Sub(input, tmp); + tmp0 = vsl.Sub(tmp0, tmp2); y2 = vsl.Mul(cephes_log_q2, e); - input = vsl.Add(input, y); - input = vsl.Add(input, y2); + tmp0 = vsl.Add(tmp0, y); + tmp0 = vsl.Add(tmp0, y2); - // Negative arg will be NAN, 0 will be -INF. - llvm::Value* or_lhs = - vsl.FloatAndNot(iszero_mask, vsl.FloatOr(input, invalid_mask)); - llvm::Value* or_rhs = vsl.FloatAnd(iszero_mask, minus_inf); - llvm::Value* result = vsl.FloatOr(or_lhs, or_rhs); + // Contains +/-inf where +/-inf is the correct answer, otherwise 0. + llvm::Value* result_inf = vsl.FloatOr(vsl.FloatAnd(is_zero_mask, minus_inf), + vsl.FloatAnd(is_pos_inf_mask, pos_inf)); - b.CreateRet(result); + // Contains a finite result or nan. This is the correct answer only where + // result_inf is 0, i.e. where the input is neither zero nor +inf. + // + // (This implementation works because 0xffffffff is a nan.) + llvm::Value* result_finite_or_nan = vsl.FloatOr(tmp0, invalid_mask); - DCHECK(!llvm::verifyFunction(*vector_log_function)); - return vector_log_function; + // Combine the above into a final result.
+ return vsl.FloatOr(result_inf, + vsl.FloatAndNot(vsl.FloatOr(is_zero_mask, is_pos_inf_mask), + result_finite_or_nan)); } } // namespace void RewriteIRRuntimeFunctions(llvm::Module* module, bool enable_fast_math) { - auto* tanh_v4f32 = - EmitVectorF32TanhIfNeeded(module, kTanhV4F32SymbolName, - /*vector_width=*/4, enable_fast_math); - auto* tanh_v8f32 = - EmitVectorF32TanhIfNeeded(module, kTanhV8F32SymbolName, - /*vector_width=*/8, enable_fast_math); - - auto* exp_v4f32 = - EmitVectorF32ExpIfNeeded(module, kExpV4F32SymbolName, - /*vector_width=*/4, enable_fast_math); - auto* exp_v8f32 = - EmitVectorF32ExpIfNeeded(module, kExpV8F32SymbolName, - /*vector_width=*/8, enable_fast_math); - - auto* log_v4f32 = - EmitVectorF32LogIfNeeded(module, kLogV4F32SymbolName, - /*vector_width=*/4, enable_fast_math); - auto* log_v8f32 = - EmitVectorF32LogIfNeeded(module, kLogV8F32SymbolName, - /*vector_width=*/8, enable_fast_math); - - // Gather all the call sites, force inline them and then delete the vector - // function bodies. - // - // TODO(b/73081976): Should we avoid inlining these intrinsics in some cases? - - std::vector calls_to_inline; - for (auto* function : - {tanh_v4f32, tanh_v8f32, exp_v4f32, exp_v8f32, log_v4f32, log_v8f32}) { - if (function != nullptr) { - for (auto* user : function->users()) { - calls_to_inline.push_back(llvm::cast(user)); - } - } - } - - for (auto* call_to_inline : calls_to_inline) { - llvm::InlineFunctionInfo inline_function_info; - CHECK(llvm::InlineFunction(call_to_inline, inline_function_info)); - } - - for (auto* function : - {tanh_v4f32, tanh_v8f32, exp_v4f32, exp_v8f32, log_v4f32, log_v8f32}) { - if (function != nullptr) { - function->eraseFromParent(); - } - } + // Curry some params to RewriteCalls. 
+ auto rewrite_calls = + std::bind(RewriteCalls, module, std::placeholders::_1, + std::placeholders::_2, std::placeholders::_3, enable_fast_math); + + rewrite_calls("tanhf", GenerateVF32Tanh, /*vector_width=*/1); + rewrite_calls("llvm.tanh.f32", GenerateVF32Tanh, /*vector_width=*/1); + rewrite_calls(kTanhV4F32SymbolName, GenerateVF32Tanh, /*vector_width=*/4); + rewrite_calls(kTanhV8F32SymbolName, GenerateVF32Tanh, /*vector_width=*/8); + + rewrite_calls("expf", GenerateVF32Exp, /*vector_width=*/1); + rewrite_calls("llvm.exp.f32", GenerateVF32Exp, /*vector_width=*/1); + rewrite_calls(kExpV4F32SymbolName, GenerateVF32Exp, /*vector_width=*/4); + rewrite_calls(kExpV8F32SymbolName, GenerateVF32Exp, /*vector_width=*/8); + + rewrite_calls("logf", GenerateVF32Log, /*vector_width=*/1); + rewrite_calls("llvm.log.f32", GenerateVF32Log, /*vector_width=*/1); + rewrite_calls(kLogV4F32SymbolName, GenerateVF32Log, /*vector_width=*/4); + rewrite_calls(kLogV8F32SymbolName, GenerateVF32Log, /*vector_width=*/8); } } // namespace runtime diff --git a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc index f8441c3e345504616485c6b34b4302acd5cc23a3..a6f4273a5a70aab0bc88383283d2a55b1ecb1681 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc @@ -34,7 +34,7 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name, llvm::Type* index_type) { CHECK_NE(index_type, nullptr); - CHECK(!ShapeUtil::IsTuple(shape_)); + CHECK(!shape_.IsTuple()); CHECK(!ShapeUtil::IsScalar(shape_)); llvm_ir::ForLoopNest loop_nest(loop_name, b_); diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc index ede7f433ca6b2cc5629115f800348be9dfb2b93b..6121d1ca9a5c785cedd947200d3e7e320aa06bc2 100644 --- 
a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc @@ -146,11 +146,9 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount( (opcode == HloOpcode::kConvolution && PotentiallyImplementedAsEigenConvolution(*instruction, target_machine_features_)) || - PotentiallyImplementedAsEigenDot(*instruction, - target_machine_features_) || (opcode == HloOpcode::kFusion && instruction->fusion_kind() != HloInstruction::FusionKind::kLoop) || - ShapeUtil::IsTuple(instruction->shape())) { + instruction->shape().IsTuple()) { return 1; } diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc index f0b65046c14ccec5336abf7c4d05d1d755f783bd..35ae62b42dfa768c6abd0508097d6b235b2ebf54 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc @@ -112,10 +112,10 @@ TEST_F(ParallelTaskAssignmentTest, InfeedOutfeedOperationNotParallelized) { const string hlo_string = R"( HloModule TestTaskParallel_infeed_outfeed ENTRY InfeedOutfeed { - token = token[] after-all() - infeed0 = (u32[12345678,2]{1,0}, token[]) infeed(token) + token0 = token[] after-all() + infeed0 = (u32[12345678,2]{1,0}, token[]) infeed(token0) infeed0.data = u32[12345678,2]{1,0} get-tuple-element((u32[12345678,2]{1,0}, token[]) infeed0), index=0 - ROOT outfeed0 = token[] outfeed(infeed0.data, token) + ROOT outfeed0 = token[] outfeed(infeed0.data, token0) } )"; diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc index 2d9492eacfea34bec3b0f1115e171a5328b7cdc3..6f72ddadf94d4c5b9add2ee66e0f4ac9a8ae9099 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc @@ -69,8 +69,13 @@ 
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_ParallelForkJoin( CHECK_EQ(params, nullptr); CHECK_GT(num_partitions, 1); CHECK_GT(num_partitioned_dims, 0); + CHECK_NE(function_ptr, nullptr); + CHECK_NE(partitions, nullptr); const xla::ExecutableRunOptions* run_options = static_cast(run_options_ptr); + CHECK_NE(run_options, nullptr); + CHECK_NE(run_options->intra_op_thread_pool(), nullptr); + ComputeFunctionType function = reinterpret_cast(function_ptr); // Compute partition stride in 'partitions' array. diff --git a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc index 722aa3120ef4d8c957873ac58c361f19632dde1f..a0667d0d9d1cde246f4b74626859955beeec08b0 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc @@ -15,12 +15,10 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h" #include -#include #include -#include #include +#include #include -#include #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/platform/dynamic_annotations.h" @@ -28,80 +26,14 @@ limitations under the License. #include "tensorflow/core/platform/types.h" namespace { -using tensorflow::int16; using tensorflow::int32; using tensorflow::int64; -using tensorflow::int8; -using tensorflow::uint16; -using tensorflow::uint32; -using tensorflow::uint64; -using tensorflow::uint8; - -template -void KeyValueSort(std::pair* row_to_sort, int64 num_elements) { - std::sort(row_to_sort, row_to_sort + num_elements); -} - -// We would like a total order of floating point numbers so that the -// sort has a predictable behavior in the presence of NaNs. Rather -// than using floating point comparison, we use the following trick: -// If f is a float, and -// x = bit_cast(f); -// y = x < 0 ? 
0x7FFFFFFF - x : x; -// then y is ordered as an int32 such that finite values have the -// obvious order, -0 is ordered before 0, and -NaN and NaN appear at -// the beginning and end of the ordering. -template -CastType Convert(KeyType value) { - CastType casted_value; - memcpy(&casted_value, &value, sizeof(CastType)); - if (casted_value < 0) { - return static_cast(std::numeric_limits::max()) - - casted_value; - } - return casted_value; -} - -template -bool LessThan(KeyType lhs, KeyType rhs) { - return Convert(lhs) < - Convert(rhs); -} - -template <> -void KeyValueSort(std::pair* row_to_sort, int64 num_elements) { - std::stable_sort(row_to_sort, row_to_sort + num_elements, - [](const std::pair& lhs, - const std::pair& rhs) -> bool { - return LessThan(lhs.first, rhs.first); - }); -} - -template <> -void KeyValueSort(std::pair* row_to_sort, int64 num_elements) { - std::stable_sort(row_to_sort, row_to_sort + num_elements, - [](const std::pair& lhs, - const std::pair& rhs) -> bool { - return LessThan(lhs.first, rhs.first); - }); -} - -template <> -void KeyValueSort(std::pair* row_to_sort, - int64 num_elements) { - std::stable_sort(row_to_sort, row_to_sort + num_elements, - [](const std::pair& lhs, - const std::pair& rhs) -> bool { - return LessThan( - Eigen::half_impl::half_to_float(lhs.first), - Eigen::half_impl::half_to_float(rhs.first)); - }); -} +} // namespace -template -void KeyValueSortImpl(KeyType* keys, int64 a, int64 b, int64 c, char** values, - int32 values_count, - int32* values_primitive_type_size_in_bytes) { +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSort( + int64 a, int64 b, int64 c, char** values, int32 values_count, + int32* values_primitive_type_size_in_bytes, + bool (*less_than)(char*, char*)) { // 'values' and 'values_primitive_type_size_in_bytes' are managed by the JIT // code, so msan can't tell they are initialized. 
TF_ANNOTATE_MEMORY_IS_INITIALIZED(values, values_count * sizeof(char*)); @@ -121,8 +53,8 @@ void KeyValueSortImpl(KeyType* keys, int64 a, int64 b, int64 c, char** values, int64 num_iteration_elements = a * c; int64 sort_dimension_offset = c; - std::unique_ptr[]> row_to_sort( - new std::pair[sort_dimension_elements]); + std::unique_ptr indices(new int64[sort_dimension_elements]); + std::iota(indices.get(), indices.get() + sort_dimension_elements, 0); std::unique_ptr reordered_values( new std::string[sort_dimension_elements]); for (int64 index = 0; index < num_iteration_elements; ++index) { @@ -135,24 +67,22 @@ void KeyValueSortImpl(KeyType* keys, int64 a, int64 b, int64 c, char** values, int64 base_offset = index % sort_dimension_offset + (index - index % sort_dimension_offset) * sort_dimension_elements; - // TODO(b/26783907): We could define a custom iterator class that references - // all arrays. Then we could avoid the intermediate copy. However this - // would become more complicated, and it is not clear if the benefit is high - // enough. - for (int64 i = 0; i < sort_dimension_elements; ++i) { - row_to_sort[i] = - std::make_pair(keys[base_offset + i * sort_dimension_offset], i); - } - KeyValueSort(row_to_sort.get(), sort_dimension_elements); - for (int64 i = 0; i < sort_dimension_elements; ++i) { - keys[base_offset + i * sort_dimension_offset] = row_to_sort[i].first; - } - - // Reorder the values according to the order defined by the keys. + std::stable_sort( + indices.get(), indices.get() + sort_dimension_elements, + [&](int64 a, int64 b) { + int64 memory_index_lhs = (base_offset + a * sort_dimension_offset) * + values_primitive_type_size_in_bytes[0]; + int64 memory_index_rhs = (base_offset + b * sort_dimension_offset) * + values_primitive_type_size_in_bytes[0]; + return less_than(values[0] + memory_index_lhs, + values[0] + memory_index_rhs); + }); + + // Reorder the values according to the order defined by 'indices'. 
for (int32 idx = 0; idx < values_count; ++idx) { for (int64 i = 0; i < sort_dimension_elements; ++i) { int64 memory_index = - (base_offset + row_to_sort[i].second * sort_dimension_offset) * + (base_offset + indices[i] * sort_dimension_offset) * values_primitive_type_size_in_bytes[idx]; reordered_values[i] = @@ -168,88 +98,3 @@ void KeyValueSortImpl(KeyType* keys, int64 a, int64 b, int64 c, char** values, } } } -} // namespace - -TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortPRED( - bool* keys, int64 a, int64 b, int64 c, char** values, int32 values_count, - int32* values_primitive_type_size_in_bytes) { - KeyValueSortImpl(keys, a, b, c, values, values_count, - values_primitive_type_size_in_bytes); -} - -TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS8( - int8* keys, int64 a, int64 b, int64 c, char** values, int32 values_count, - int32* values_primitive_type_size_in_bytes) { - KeyValueSortImpl(keys, a, b, c, values, values_count, - values_primitive_type_size_in_bytes); -} - -TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU8( - uint8* keys, int64 a, int64 b, int64 c, char** values, int32 values_count, - int32* values_primitive_type_size_in_bytes) { - KeyValueSortImpl(keys, a, b, c, values, values_count, - values_primitive_type_size_in_bytes); -} - -TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS16( - int16* keys, int64 a, int64 b, int64 c, char** values, int32 values_count, - int32* values_primitive_type_size_in_bytes) { - KeyValueSortImpl(keys, a, b, c, values, values_count, - values_primitive_type_size_in_bytes); -} - -TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU16( - uint16* keys, int64 a, int64 b, int64 c, char** values, int32 values_count, - int32* values_primitive_type_size_in_bytes) { - KeyValueSortImpl(keys, a, b, c, values, values_count, - values_primitive_type_size_in_bytes); -} - -TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortF16( - 
Eigen::half* keys, int64 a, int64 b, int64 c, char** values, - int32 values_count, int32* values_primitive_type_size_in_bytes) { - KeyValueSortImpl(keys, a, b, c, values, values_count, - values_primitive_type_size_in_bytes); -} - -TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS32( - int32* keys, int64 a, int64 b, int64 c, char** values, int32 values_count, - int32* values_primitive_type_size_in_bytes) { - KeyValueSortImpl(keys, a, b, c, values, values_count, - values_primitive_type_size_in_bytes); -} - -TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU32( - uint32* keys, int64 a, int64 b, int64 c, char** values, int32 values_count, - int32* values_primitive_type_size_in_bytes) { - KeyValueSortImpl(keys, a, b, c, values, values_count, - values_primitive_type_size_in_bytes); -} - -TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortF32( - float* keys, int64 a, int64 b, int64 c, char** values, int32 values_count, - int32* values_primitive_type_size_in_bytes) { - KeyValueSortImpl(keys, a, b, c, values, values_count, - values_primitive_type_size_in_bytes); -} - -TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS64( - int64* keys, int64 a, int64 b, int64 c, char** values, int32 values_count, - int32* values_primitive_type_size_in_bytes) { - KeyValueSortImpl(keys, a, b, c, values, values_count, - values_primitive_type_size_in_bytes); -} - -TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU64( - uint64* keys, int64 a, int64 b, int64 c, char** values, int32 values_count, - int32* values_primitive_type_size_in_bytes) { - KeyValueSortImpl(keys, a, b, c, values, values_count, - values_primitive_type_size_in_bytes); -} - -TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortF64( - double* keys, int64 a, int64 b, int64 c, char** values, int32 values_count, - int32* values_primitive_type_size_in_bytes) { - KeyValueSortImpl(keys, a, b, c, values, values_count, - 
values_primitive_type_size_in_bytes); -} diff --git a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h index 7821099386969e855ea1737cf53ef49c15c6e93b..5460af3485b94aaef1a5822a79e4fa325bcb67ea 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h +++ b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h @@ -21,76 +21,19 @@ limitations under the License. extern "C" { -// 'keys' represents a 3-dimensional shape with dimensions [a, b, c]. The 'b' -// dimension of 'keys' is sorted into ascending order. If 'values_count' is <= -// 0, 'values' and 'values_primitive_type_size_in_bytes' can be nullptr. -// If 'values_count' > 0, they contain exactly 'values_count' many elements. -// Each element of 'values' also represents a 3-dimensional shape with -// dimensions [a, b, c], and the size of the primitive type of the i-th shape -// has exactly 'values_primitive_type_size_in_bytes[i]' bytes. The elements in -// each 'values' shape are reordered in such a way that if the element at index -// 'i' in 'keys' was moved to index 'j', the element at index 'i' in a 'values' -// shape is also moved to index 'j' (which means that the same elements -// correspond to each other as before). -extern void __xla_cpu_runtime_KeyValueSortPRED( - bool* keys, tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c, +// Each entry in 'values' represents a 3-dimensional shape with dimensions +// [a, b, c]. The 'b' dimension of the first shape is sorted into ascending +// order according to the results of comparisons using the provided 'less_than' +// function. 'values_count' must be > 0 and specifies the number of entries in +// 'values' and 'values_primitive_type_size_in_bytes'. The size of the primitive +// type of the i-th shape has exactly 'values_primitive_type_size_in_bytes[i]' +// bytes. 
The elements in each 'values' shape are reordered in the same way +// according to the comparisons using the first shape. +extern void __xla_cpu_runtime_KeyValueSort( + tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c, char** values, tensorflow::int32 values_count, - tensorflow::int32* values_primitive_type_size_in_bytes); - -extern void __xla_cpu_runtime_KeyValueSortS8( - tensorflow::int8* keys, tensorflow::int64 a, tensorflow::int64 b, - tensorflow::int64 c, char** values, tensorflow::int32 values_count, - tensorflow::int32* values_primitive_type_size_in_bytes); - -extern void __xla_cpu_runtime_KeyValueSortU8( - tensorflow::uint8* keys, tensorflow::int64 a, tensorflow::int64 b, - tensorflow::int64 c, char** values, tensorflow::int32 values_count, - tensorflow::int32* values_primitive_type_size_in_bytes); - -extern void __xla_cpu_runtime_KeyValueSortS16( - tensorflow::int16* keys, tensorflow::int64 a, tensorflow::int64 b, - tensorflow::int64 c, char** values, tensorflow::int32 values_count, - tensorflow::int32* values_primitive_type_size_in_bytes); - -extern void __xla_cpu_runtime_KeyValueSortU16( - tensorflow::uint16* keys, tensorflow::int64 a, tensorflow::int64 b, - tensorflow::int64 c, char** values, tensorflow::int32 values_count, - tensorflow::int32* values_primitive_type_size_in_bytes); - -extern void __xla_cpu_runtime_KeyValueSortF16( - Eigen::half* keys, tensorflow::int64 a, tensorflow::int64 b, - tensorflow::int64 c, char** values, tensorflow::int32 values_count, - tensorflow::int32* values_primitive_type_size_in_bytes); - -extern void __xla_cpu_runtime_KeyValueSortS32( - tensorflow::int32* keys, tensorflow::int64 a, tensorflow::int64 b, - tensorflow::int64 c, char** values, tensorflow::int32 values_count, - tensorflow::int32* values_primitive_type_size_in_bytes); - -extern void __xla_cpu_runtime_KeyValueSortU32( - tensorflow::uint32* keys, tensorflow::int64 a, tensorflow::int64 b, - tensorflow::int64 c, char** values, tensorflow::int32 
values_count, - tensorflow::int32* values_primitive_type_size_in_bytes); - -extern void __xla_cpu_runtime_KeyValueSortF32( - float* keys, tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c, - char** values, tensorflow::int32 values_count, - tensorflow::int32* values_primitive_type_size_in_bytes); - -extern void __xla_cpu_runtime_KeyValueSortS64( - tensorflow::int64* keys, tensorflow::int64 a, tensorflow::int64 b, - tensorflow::int64 c, char** values, tensorflow::int32 values_count, - tensorflow::int32* values_primitive_type_size_in_bytes); - -extern void __xla_cpu_runtime_KeyValueSortU64( - tensorflow::uint64* keys, tensorflow::int64 a, tensorflow::int64 b, - tensorflow::int64 c, char** values, tensorflow::int32 values_count, - tensorflow::int32* values_primitive_type_size_in_bytes); - -extern void __xla_cpu_runtime_KeyValueSortF64( - double* keys, tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c, - char** values, tensorflow::int32 values_count, - tensorflow::int32* values_primitive_type_size_in_bytes); + tensorflow::int32* values_primitive_type_size_in_bytes, + bool (*less_than)(char*, char*)); } #endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_KEY_VALUE_SORT_H_ diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc b/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc index a71a85913cfef271bc2a226cb0cf2dd4204499a4..fe7e87a197b6cf571195537eaea2898659cd5e2e 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc @@ -23,12 +23,20 @@ limitations under the License. 
#include "tensorflow/core/platform/dynamic_annotations.h" #include "tensorflow/core/platform/types.h" +#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL) +#include "tensorflow/core/kernels/eigen_contraction_kernel.h" +#endif + using tensorflow::int32; using tensorflow::int64; namespace { -template +bool Is16BytesAligned(void* ptr) { + return reinterpret_cast(ptr) % 16 == 0; +} + +template void MatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m, int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { const xla::ExecutableRunOptions* run_options = @@ -46,11 +54,11 @@ void MatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m, std::swap(rhs_rows, rhs_cols); } - const Eigen::TensorMap, Eigen::Aligned> A( - lhs, lhs_rows, lhs_cols); - const Eigen::TensorMap, Eigen::Aligned> B( - rhs, rhs_rows, rhs_cols); - Eigen::TensorMap, Eigen::Aligned> C(out, m, n); + const Eigen::TensorMap, Alignment> A(lhs, lhs_rows, + lhs_cols); + const Eigen::TensorMap, Alignment> B(rhs, rhs_rows, + rhs_cols); + Eigen::TensorMap, Alignment> C(out, m, n); typedef typename Eigen::Tensor::DimensionPair DimPair; int lhs_contract_dim = transpose_lhs ? 0 : 1; @@ -65,14 +73,24 @@ void MatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m, } template -void MatMulImpl(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m, - int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { +void MatMulDispatch(const void* run_options_ptr, T* out, T* lhs, T* rhs, + int64 m, int64 n, int64 k, int32 transpose_lhs, + int32 transpose_rhs) { + bool all_buffers_16b_aligned = + Is16BytesAligned(out) && Is16BytesAligned(lhs) && Is16BytesAligned(rhs); + + if (!all_buffers_16b_aligned) { + MatMul(run_options_ptr, out, lhs, rhs, m, n, k, + transpose_lhs, transpose_rhs); + return; + } + if (m == 1 || n == 1) { // Despite being single threaded, this version of matrix * vector is faster. 
xla::EigenMatVec(out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); } else { - MatMul(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, - transpose_rhs); + MatMul(run_options_ptr, out, lhs, rhs, m, n, k, + transpose_lhs, transpose_rhs); } } @@ -82,20 +100,20 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulF16( const void* run_options_ptr, Eigen::half* out, Eigen::half* lhs, Eigen::half* rhs, int64 m, int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { - MatMulImpl(run_options_ptr, out, lhs, rhs, m, n, k, - transpose_lhs, transpose_rhs); + MatMulDispatch(run_options_ptr, out, lhs, rhs, m, n, k, + transpose_lhs, transpose_rhs); } TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulF32( const void* run_options_ptr, float* out, float* lhs, float* rhs, int64 m, int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { - MatMulImpl(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, - transpose_rhs); + MatMulDispatch(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, + transpose_rhs); } TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulF64( const void* run_options_ptr, double* out, double* lhs, double* rhs, int64 m, int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { - MatMulImpl(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, - transpose_rhs); + MatMulDispatch(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, + transpose_rhs); } diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc index 16692e7f2e6145b2649b67987eef47916e958be2..1f7204e67a413efabd34cd7d88ced4c82ee7a5df 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc @@ -20,12 +20,20 @@ limitations under the License. 
#include "tensorflow/core/platform/dynamic_annotations.h" #include "tensorflow/core/platform/types.h" +#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL) +#include "tensorflow/core/kernels/eigen_contraction_kernel.h" +#endif + using tensorflow::int32; using tensorflow::int64; namespace { -template +bool Is16BytesAligned(void* ptr) { + return reinterpret_cast(ptr) % 16 == 0; +} + +template void MatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m, int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { int64 lhs_rows = m; @@ -40,11 +48,11 @@ void MatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m, std::swap(rhs_rows, rhs_cols); } - const Eigen::TensorMap, Eigen::Aligned> A( - lhs, lhs_rows, lhs_cols); - const Eigen::TensorMap, Eigen::Aligned> B( - rhs, rhs_rows, rhs_cols); - Eigen::TensorMap, Eigen::Aligned> C(out, m, n); + const Eigen::TensorMap, Alignment> A(lhs, lhs_rows, + lhs_cols); + const Eigen::TensorMap, Alignment> B(rhs, rhs_rows, + rhs_cols); + Eigen::TensorMap, Alignment> C(out, m, n); typedef typename Eigen::Tensor::DimensionPair DimPair; int lhs_contract_dim = transpose_lhs ? 
0 : 1; @@ -59,14 +67,24 @@ void MatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m, } template -void SingleThreadedMatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, - int64 m, int64 n, int64 k, int32 transpose_lhs, - int32 transpose_rhs) { +void SingleThreadedMatMulDispatch(const void* run_options_ptr, T* out, T* lhs, + T* rhs, int64 m, int64 n, int64 k, + int32 transpose_lhs, int32 transpose_rhs) { + bool all_buffers_16b_aligned = + Is16BytesAligned(out) && Is16BytesAligned(lhs) && Is16BytesAligned(rhs); + + if (!all_buffers_16b_aligned) { + MatMul(run_options_ptr, out, lhs, rhs, m, n, k, + transpose_lhs, transpose_rhs); + // Done; the aligned paths below must not run on unaligned buffers. + return; + } + if (m == 1 || n == 1) { xla::EigenMatVec(out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); } else { - MatMul(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, - transpose_rhs); + MatMul(run_options_ptr, out, lhs, rhs, m, n, k, + transpose_lhs, transpose_rhs); } } @@ -77,8 +95,8 @@ __xla_cpu_runtime_EigenSingleThreadedMatMulF16( const void* run_options_ptr, Eigen::half* out, Eigen::half* lhs, Eigen::half* rhs, int64 m, int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { - SingleThreadedMatMul(run_options_ptr, out, lhs, rhs, m, n, k, - transpose_lhs, transpose_rhs); + SingleThreadedMatMulDispatch(run_options_ptr, out, lhs, rhs, m, + n, k, transpose_lhs, transpose_rhs); } TF_ATTRIBUTE_NO_SANITIZE_MEMORY void @@ -87,8 +105,8 @@ __xla_cpu_runtime_EigenSingleThreadedMatMulF32(const void* run_options_ptr, float* rhs, int64 m, int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { - SingleThreadedMatMul(run_options_ptr, out, lhs, rhs, m, n, k, - transpose_lhs, transpose_rhs); + SingleThreadedMatMulDispatch(run_options_ptr, out, lhs, rhs, m, n, k, + transpose_lhs, transpose_rhs); } TF_ATTRIBUTE_NO_SANITIZE_MEMORY void @@ -97,6 +115,6 @@ __xla_cpu_runtime_EigenSingleThreadedMatMulF64(const void* run_options_ptr, double* rhs, int64 m, int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { -
SingleThreadedMatMul(run_options_ptr, out, lhs, rhs, m, n, k, - transpose_lhs, transpose_rhs); + SingleThreadedMatMulDispatch(run_options_ptr, out, lhs, rhs, m, n, k, + transpose_lhs, transpose_rhs); } diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index efccadedf27181a4cddf4f1dc3610f7c6db1d821..9c2685674fbc133de1220caef81ac3b60a1c0f7c 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -116,13 +116,26 @@ SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions& target_options, orc_jit_memory_mapper::GetInstance()); result.Resolver = symbol_resolver_; return result; + }, + /*NotifyLoaded=*/ + llvm::orc::LegacyRTDyldObjectLinkingLayer::NotifyLoadedFtor(), + /*NotifyFinalized=*/ + [this](VModuleKeyT, const llvm::object::ObjectFile& object, + const llvm::RuntimeDyld::LoadedObjectInfo& object_info) { + this->NotifyObjectFinalized(object, object_info); + }, + /*NotifyFreed=*/ + [this](VModuleKeyT, const llvm::object::ObjectFile& object) { + this->NotifyObjectFreed(object); }), compile_layer_(object_layer_, CompilerFunctor(target_machine_.get(), &disassembler_, opt_level, optimize_for_size, enable_fast_math, disable_expensive_passes, std::move(pre_optimization_hook), - std::move(post_optimization_hook))) { + std::move(post_optimization_hook))), + gdb_jit_event_listener_( + llvm::JITEventListener::createGDBRegistrationListener()) { VLOG(1) << "CPU target: " << target_machine_->getTargetCPU().str() << " features: " << target_machine_->getTargetFeatureString().str(); } @@ -139,7 +152,7 @@ llvm::JITSymbol SimpleOrcJIT::ResolveRuntimeSymbol(const std::string& name) { } if (func_addr == nullptr) { - VLOG(2) << "Unable to resolve runtime symbol: " << name; + LOG(ERROR) << "Unable to resolve runtime symbol: " << name; return nullptr; } llvm::JITEvaluatedSymbol symbol_info(reinterpret_cast(func_addr), @@ -147,6 +160,20 @@ 
llvm::JITSymbol SimpleOrcJIT::ResolveRuntimeSymbol(const std::string& name) { return symbol_info; } +void SimpleOrcJIT::NotifyObjectFinalized( + const llvm::object::ObjectFile& object, + const llvm::RuntimeDyld::LoadedObjectInfo& object_info) { + uint64_t key = static_cast( + reinterpret_cast(object.getData().data())); + gdb_jit_event_listener_->notifyObjectLoaded(key, object, object_info); +} + +void SimpleOrcJIT::NotifyObjectFreed(const llvm::object::ObjectFile& object) { + uint64_t key = static_cast( + reinterpret_cast(object.getData().data())); + gdb_jit_event_listener_->notifyFreeingObject(key); +} + SimpleOrcJIT::VModuleKeyT SimpleOrcJIT::AddModule( std::unique_ptr module) { auto key = execution_session_.allocateVModule(); @@ -213,18 +240,7 @@ bool RegisterKnownJITSymbols() { REGISTER_CPU_RUNTIME_SYMBOL(ParallelForkJoin); REGISTER_CPU_RUNTIME_SYMBOL(ReleaseInfeedBufferAfterDequeue); REGISTER_CPU_RUNTIME_SYMBOL(ReleaseOutfeedBufferAfterPopulation); - REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortPRED); - REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS8); - REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU8); - REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS16); - REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU16); - REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortF16); - REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS32); - REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU32); - REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortF32); - REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS64); - REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU64); - REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortF64); + REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSort); registry->Register("__gnu_f2h_ieee", reinterpret_cast(__gnu_f2h_ieee)); registry->Register("__gnu_h2f_ieee", reinterpret_cast(__gnu_h2f_ieee)); @@ -296,6 +312,9 @@ bool RegisterKnownJITSymbols() { REGISTER_LIBM_SYMBOL(sin, double (*)(double)); #ifdef __APPLE__ REGISTER_LIBM_SYMBOL(__sincos, void (*)(double, double*, double*)); + registry->Register("__sincosf_stret", + reinterpret_cast(__sincosf_stret)); + 
registry->Register("__sincos_stret", reinterpret_cast(__sincos_stret)); #else REGISTER_LIBM_SYMBOL(sincos, void (*)(double, double*, double*)); #endif @@ -311,6 +330,13 @@ bool RegisterKnownJITSymbols() { registry->Register("memcpy", reinterpret_cast(memcpy)); registry->Register("memmove", reinterpret_cast(memmove)); registry->Register("memset", reinterpret_cast(memset)); + +#ifdef __APPLE__ + registry->Register("__bzero", reinterpret_cast(bzero)); + registry->Register("memset_pattern16", + reinterpret_cast(memset_pattern16)); +#endif + return true; } diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h index 78406ba143570183aea09d79db3f9b708c21bf70..3307c2f93d796bbdcd49af7f68e9f6c388e402ca 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h @@ -21,6 +21,7 @@ limitations under the License. #include #include "llvm/ADT/Triple.h" +#include "llvm/ExecutionEngine/JITEventListener.h" #include "llvm/ExecutionEngine/Orc/Core.h" #include "llvm/ExecutionEngine/Orc/IRCompileLayer.h" #include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h" @@ -99,6 +100,11 @@ class SimpleOrcJIT { private: llvm::JITSymbol ResolveRuntimeSymbol(const std::string& name); + void NotifyObjectFinalized( + const llvm::object::ObjectFile& object, + const llvm::RuntimeDyld::LoadedObjectInfo& object_info); + void NotifyObjectFreed(const llvm::object::ObjectFile& object); + std::vector module_keys_; std::unique_ptr target_machine_; const Disassembler disassembler_; @@ -107,6 +113,15 @@ class SimpleOrcJIT { std::shared_ptr symbol_resolver_; ObjLayerT object_layer_; CompileLayerT compile_layer_; + + // Non owning pointer to a JIT event listener that registers the JIT events + // with an attached GDB. 
+ // + // Note: we get a pointer to this event listener using + // `createGDBRegistrationListener` which makes it look like we're supposed to + // free this, but the function is poorly named and really just returns a + // pointer to a static object. + llvm::JITEventListener* gdb_jit_event_listener_; }; } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc index f8f5f392da8ab3348e63185aecf7b639daacaa42..8b7f843582b697058fe328fe69990122d868ada4 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc @@ -16,7 +16,6 @@ limitations under the License. // Tests that we call into Eigen for dot operations as needed. #include -#include #include #include "absl/strings/str_cat.h" @@ -102,10 +101,10 @@ std::vector GetDotTestCases() { return result; } -INSTANTIATE_TEST_CASE_P(CpuEigenDotOperationTestInstantiation, - CpuEigenDotOperationTest, - ::testing::ValuesIn(GetDotTestCases()), - DotTestSpecToString); +INSTANTIATE_TEST_SUITE_P(CpuEigenDotOperationTestInstantiation, + CpuEigenDotOperationTest, + ::testing::ValuesIn(GetDotTestCases()), + DotTestSpecToString); } // namespace } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc index 5cc6d01c0f15d4209cbc1fb259a0078fb9957f6e..f0f897e9635600b22e0c389ba056899e4d6ab3d4 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc @@ -48,7 +48,7 @@ class InfeedTest : public ClientLibraryTestBase { ASSERT_IS_OK(client_->TransferToInfeed(literal)); XlaBuilder builder(TestName()); Infeed(&builder, literal.shape()); - if (ShapeUtil::IsTuple(literal.shape())) { + if (literal.shape().IsTuple()) { // TODO(b/30609564): Use 
ComputeAndCompareLiteral instead. ComputeAndCompareTuple(&builder, literal, {}); } else { diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc index 9b10c49f4f547edfb2164f98c49cceb031148bdc..9078b8fd1ff6cb0ddac89d5fcd13a9ccfae07763 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc @@ -14,9 +14,9 @@ limitations under the License. ==============================================================================*/ #include -#include #include +#include "absl/strings/ascii.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" #include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" @@ -59,8 +59,9 @@ class CpuUnaryIntrinsicTest string features{spec.features.data(), spec.features.size()}; if (!features.empty()) { - std::replace_if(features.begin(), features.end(), - [](char c) { return c != '_' && !isalnum(c); }, '_'); + std::replace_if( + features.begin(), features.end(), + [](char c) { return c != '_' && !absl::ascii_isalnum(c); }, '_'); } else { features = ""; } @@ -140,10 +141,10 @@ IntrinsicTestSpec CpuUnaryIntrinsicTestCases[] = { HloOpcode::kLog, kTriple_android_arm, "", R"(CHECK: fadd fast <4 x float> )"}}; -INSTANTIATE_TEST_CASE_P(CpuUnaryIntrinsicTestInstantiation, - CpuUnaryIntrinsicTest, - ::testing::ValuesIn(CpuUnaryIntrinsicTestCases), - CpuUnaryIntrinsicTest::Name); +INSTANTIATE_TEST_SUITE_P(CpuUnaryIntrinsicTestInstantiation, + CpuUnaryIntrinsicTest, + ::testing::ValuesIn(CpuUnaryIntrinsicTestCases), + CpuUnaryIntrinsicTest::Name); } // namespace } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc index fa0e09ff6b5694c0e97963b83c6e541b858a1376..0584c0484f810a03ccccd522163f54535440ef8b 100644 --- 
a/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc @@ -31,29 +31,27 @@ HloModule RepeatedConstants while_body { arg_body = f32[2,3,2] parameter(0) ROOT const = f32[2,3,2] constant( - f32[2,3,2] {{{1, 2}, {1001, 1002}, {2001, 2002}}, {{2, 1}, {2001, 3002}, {2001, 2002}}}) } while_cond { arg_cond = f32[2,3,2] parameter(0) - token = token[] after-all() - infeed = (pred[], token[]) infeed(token) + token0 = token[] after-all() + infeed = (pred[], token[]) infeed(token0) ROOT unknown = pred[] get-tuple-element((pred[], token[]) infeed), index=0 } ENTRY main { param = f32[2,3,2] parameter(0) const_a = f32[2,3,2] constant( - f32[2,3,2] {{{1, 2}, {1001, 1002}, {2001, 2002}}, {{2, 1}, {2001, 3002}, {2001, 2002}}}) const_b = f32[2,3,2] while(f32[2,3,2] const_a), condition=while_cond, body=while_body - token = token[] after-all() - out0 = token[] outfeed(f32[2,3,2] const_a, token[] token) - ROOT out1 = token[] outfeed(f32[2,3,2] const_b, token[] token) + token0 = token[] after-all() + out0 = token[] outfeed(f32[2,3,2] const_a, token[] token0) + ROOT out1 = token[] outfeed(f32[2,3,2] const_b, token[] token0) } )"; @@ -82,24 +80,24 @@ HloModule RepeatedConstants while_body { arg_body = (f32[2,1]{1,0}, f32[1]{0}) parameter(0) - ROOT const = (f32[2,1]{1,0}, f32[1]{0}) constant((f32[2,1], f32[1]) ( f32[2,1] { { 1 }, { 2 } }, {2} )) + ROOT const = (f32[2,1]{1,0}, f32[1]{0}) constant(({ { 1 }, { 2 } }, {2} )) } while_cond { arg_cond = (f32[2,1]{1,0}, f32[1]{0}) parameter(0) - token = token[] after-all() - infeed = (pred[], token[]) infeed(token) + token0 = token[] after-all() + infeed = (pred[], token[]) infeed(token0) ROOT unknown = pred[] get-tuple-element((pred[], token[]) infeed), index=0 } ENTRY main { param = f32[2,3,2] parameter(0) - const_a = (f32[2,1]{1,0}, f32[1]{0}) constant((f32[2,1], f32[1]) ( f32[2,1] { { 1 }, { 2 } }, {2} )) + const_a = (f32[2,1]{1,0}, f32[1]{0}) constant(( 
{ { 1 }, { 2 } }, {2} )) const_b = (f32[2,1]{1,0}, f32[1]{0}) while((f32[2,1]{1,0}, f32[1]{0}) const_a), condition=while_cond, body=while_body - token = token[] after-all() - out0 = () outfeed((f32[2,1]{1,0}, f32[1]{0}) const_a, token[] token) - ROOT out1 = () outfeed((f32[2,1]{1,0}, f32[1]{0}) const_b, token[] token) + token0 = token[] after-all() + out0 = () outfeed((f32[2,1]{1,0}, f32[1]{0}) const_a, token[] token0) + ROOT out1 = () outfeed((f32[2,1]{1,0}, f32[1]{0}) const_b, token[] token0) } )"; diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc index a7702c2aeeaff8a46a2c4f2785ccb873ea2c08e5..030bd41c2fc73eac41fe43c1acdf862d5dc97f98 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc @@ -75,8 +75,9 @@ TEST_F(CpuNoAliasTest, Concat) { // the buffers in the HLO module. We'll inspect these loads to ensure that // they have the expected alias information. 
llvm::Module ir_module("test", context); - llvm::Function* func = llvm::cast( - ir_module.getOrInsertFunction("test_fn", llvm::Type::getVoidTy(context))); + llvm::Function* func = llvm::dyn_cast( + ir_module.getOrInsertFunction("test_fn", llvm::Type::getVoidTy(context)) + .getCallee()); llvm::BasicBlock* bb = llvm::BasicBlock::Create(context, "body", func); llvm::IRBuilder<> b(bb); auto* zero = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), 0); diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc index e2c7af541eede5265f274c72f55305549f059839..aab7f0b393881642437f1891256bd138823a3b87 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc @@ -28,12 +28,11 @@ HloModule Outfeed ENTRY main { const_a = f32[2,3,2] constant( - f32[2,3,2] {{{1, 2}, {1001, 1002}, {2001, 2002}}, {{2, 1}, {2001, 3002}, {2001, 2002}}}) - token = token[] after-all() - outfeed = token[] outfeed(f32[2,3,2] const_a, token) + token0 = token[] after-all() + outfeed = token[] outfeed(f32[2,3,2] const_a, token0) ROOT root = () tuple() } )"; diff --git a/tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.cc b/tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.cc new file mode 100644 index 0000000000000000000000000000000000000000..eb6c44b70ab34d0a294880b5de4fe0b3ba5e19e5 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.cc @@ -0,0 +1,1014 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.h" + +#include "tensorflow/compiler/xla/service/cpu/vector_support_library.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" + +namespace xla { +namespace cpu { +namespace { + +using tensorflow::int64; + +// Provides tiled access to an in-memory rank 2 array. +class MemoryTile { + public: + // Constructs a MemoryTile that can operate on tiles consisting of + // `tile_size_along_major_dim` vectors from the matrix `matrix`, starting at + // `major_dim_offset` in the major dimension. The tile size along the minor + // dimension is the vector size, and that is implicitly determined by `vsl`. + MemoryTile(VectorSupportLibrary* vsl, llvm::IRBuilder<>* b, + llvm::Value* matrix, int64 matrix_size_along_minor_dim, + llvm::Value* major_dim_offset, int64 tile_size_along_major_dim) + : vsl_(vsl), b_(b) { + pointers_.reserve(tile_size_along_major_dim); + for (int64 i = 0; i < tile_size_along_major_dim; i++) { + llvm::Value* total_offset = + b->CreateMul(b->getInt64(matrix_size_along_minor_dim), + b->CreateAdd(b->getInt64(i), major_dim_offset)); + pointers_.push_back(vsl_->ComputeOffsetPointer(matrix, total_offset)); + } + } + + // Load a tile consisting of `tile_size_along_major_dim` vectors from position + // {major: `major_dim_offset`, minor: `minor_dim_offset`}. 
+ // + // Note: `major_dim_offset` is a parameter to the constructor. + std::vector LoadTile(llvm::Value* minor_dim_offset) const { + std::vector result; + result.reserve(pointers_.size()); + for (const auto& pointer : pointers_) { + result.push_back(vsl_->LoadVector(pointer, minor_dim_offset)); + } + return result; + } + + // Stores `tile` to position {major: `major_dim_offset`, minor: + // `minor_dim_offset`}. + // + // Note: `major_dim_offset` is a parameter to the constructor. + void StoreTile(absl::Span tile, + llvm::Value* minor_dim_offset) const { + CHECK_EQ(tile.size(), pointers_.size()); + for (int64 i = 0; i < pointers_.size(); i++) { + vsl_->StoreVector(tile[i], pointers_[i], minor_dim_offset); + } + } + + // Loads a tile of size [`tile_size_along_major_dim`, + // `tile_size_along_middle_dim`] from position {major: `major_dim_offset`, + // minor: `minor_dim_offset`} and then broadcasts each element into a vector + // of size vsl_.vector_size(). The (i,j)'th element of the return value is + // the (i,j)'th element in the tile broadcasted into an LLVM vector. + // + // Note: `major_dim_offset` is a parameter to the constructor. + std::vector> LoadBroadcastTile( + llvm::Value* minor_dim_offset, int64 tile_size_along_middle_dim) const { + std::vector> result; + result.resize(pointers_.size()); + for (int64 i = 0; i < pointers_.size(); i++) { + for (int64 j = 0; j < tile_size_along_middle_dim; j++) { + result[i].push_back(vsl_->LoadBroadcast( + pointers_[i], b_->CreateAdd(minor_dim_offset, b_->getInt64(j)))); + } + } + return result; + } + + private: + VectorSupportLibrary* vsl_; + llvm::IRBuilder<>* b_; + std::vector pointers_; +}; + +// The base class for the classes representing the GEMV emitter configurations. +// +// The IR emitted (modulo the LLVM values representing the input and output +// buffers) by the row major and column major GEMV emitters should be a function +// of their configuration. 
This is important because their configuration is +// used as a key to cache the generated IR. +class GemvConfig { + public: + // Mixin for convenience. + template + struct User { + public: + PrimitiveType scalar_type() const { + return derived().config().scalar_type(); + } + int64 tile_rows() const { return derived().config().tile_rows(); } + int64 tile_cols() const { return derived().config().tile_cols(); } + int64 m() const { return derived().config().m(); } + int64 k() const { return derived().config().k(); } + bool has_addend() const { return derived().config().has_addend(); } + + private: + const T& derived() const { return *static_cast(this); } + }; + + PrimitiveType scalar_type() const { return scalar_type_; } + int64 tile_rows() const { return tile_rows_; } + int64 tile_cols() const { return tile_cols_; } + int64 m() const { return m_; } + int64 k() const { return k_; } + bool has_addend() const { return has_addend_; } + + string GetCacheKey() const { + return absl::StrCat(name_, "_", PrimitiveType_Name(scalar_type()), "_", + tile_rows(), "_", tile_cols(), "_", m(), "_", k(), + has_addend() ? "_with_addend" : ""); + } + + protected: + explicit GemvConfig(string name, PrimitiveType scalar_type, int64 tile_rows, + int64 tile_cols, int64 m, int64 k, bool has_addend) + : name_(std::move(name)), + scalar_type_(scalar_type), + tile_rows_(tile_rows), + tile_cols_(tile_cols), + m_(m), + k_(k), + has_addend_(has_addend) {} + + private: + string name_; + PrimitiveType scalar_type_; + int64 tile_rows_; + int64 tile_cols_; + int64 m_; + int64 k_; + bool has_addend_; +}; + +// Computes a dot product between "[M,K]{0,1} lhs" with a [K,1] vector (the +// layout of the vector does not matter). This implementation uses a tiling +// scheme to improve performance.
+// +// We logically separate the LHS matrix into four segments: +// +// +----------------------+---+ +// | | | +// | | | +// | A | B | +// | | | +// | | | +// | | | +// +----------------------+---+ +// | C | D | +// +----------------------+---+ +// +// where A is the largest submatrix of the LHS that can be evenly divided into +// tiles. For each tile in A, assuming tile_rows_ == tile_cols_ == 4, we have: +// +// +---+---+---+---+ +--+--+--+--+ +// |M00|M10|M20|M30| |V0|V1|V2|V3| +// +---+---+---+---+ +--+--+--+--+ +// |M01|M11|M21|M31| and |V0|V1|V2|V3| +// +---+---+---+---+ +--+--+--+--+ +// |M02|M12|M22|M32| |V0|V1|V2|V3| +// +---+---+---+---+ +--+--+--+--+ +// |M03|M13|M23|M33| |V0|V1|V2|V3| +// +---+---+---+---+ +--+--+--+--+ +// +// (Legend: rows are horizontal and columns are vertical; and each column is one +// llvm::Value of a vector type) +// +// where: +// +// a. The left tile is from the column major left matrix. +// b. The right tile is an elementwise broadcast of a [V0, V1, V2, V3] +// vector loaded from the RHS vector. +// +// As we iterate through the column dimension, we compute the change to the +// result vector by an elementwise multiplication between the two tiles above +// followed by a reduction along the major dimension: +// +// +-----------------------------------+ +// | M00*V0 + M10*V1 + M20*V2 + M30*V3 | +// +-----------------------------------+ +// | M01*V0 + M11*V1 + M21*V2 + M31*V3 | +// Result[R:R+4] += +-----------------------------------+ +// | M02*V0 + M12*V1 + M22*V2 + M32*V3 | +// +-----------------------------------+ +// | M03*V0 + M13*V1 + M23*V2 + M33*V3 | +// +-----------------------------------+ +// +// Where R is the starting row for the tile. +// +// We have an inner epilogue loop to deal with the "C" submatrix and an outer +// epilogue loop to deal with the B,D submatrix.
+// +// TODO(sanjoy): We should investigate whether gather loads and scatter stores +// can be used here so that column-major and row-major matrix-vector products +// share the same inner loop. +class ColumnMajorMatrixVectorProductEmitter + : public GemvConfig::User { + public: + class Config : public GemvConfig { + public: + explicit Config(PrimitiveType scalar_type, int64 tile_rows, int64 tile_cols, + int64 m, int64 k, bool has_addend) + : GemvConfig(/*name=*/"col_major_gemv", scalar_type, + /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols, /*m=*/m, + /*k=*/k, /*has_addend=*/has_addend) {} + }; + + ColumnMajorMatrixVectorProductEmitter(const Config& config, llvm::Value* lhs, + llvm::Value* rhs, llvm::Value* addend, + llvm::Value* result, + llvm::IRBuilder<>* b) + : config_(config), + lhs_(lhs), + rhs_(rhs), + addend_(addend), + result_(result), + b_(b), + ksl_(b_), + vsl_(config.scalar_type(), /*vector_size=*/config.tile_rows(), b_, "") { + CHECK(tile_rows() > 0 && IsPowerOfTwo(static_cast(tile_rows()))); + CHECK(!has_addend() || addend != nullptr); + } + + void Emit(); + + const Config& config() const { return config_; } + + private: + void EmitOuterLoopBody(llvm::Value* column, int64 column_count, + bool is_first_column); + + MemoryTile GetLhsMemoryTile(llvm::Value* column_start, int64 column_count) { + return MemoryTile(&vsl_, b_, /*matrix=*/lhs_, + /*matrix_size_along_minor_dim=*/m(), + /*major_dim_offset=*/column_start, + /*tile_size_along_major_dim=*/column_count); + } + + // Load a tile of values from the RHS. For the RHS a "tile" is a contiguous + // sequence of `count` values, each one broadcasted to the vector width.
+ std::vector LoadRhsTile(llvm::Value* offset, int64 count) { + llvm::Value* base_pointer = vsl_.ComputeOffsetPointer(rhs_, offset); + std::vector result; + result.reserve(count); + for (int64 i = 0; i < count; i++) { + result.push_back(vsl_.LoadBroadcast(base_pointer, i)); + } + return result; + } + + void EmitInnerLoopTiled(MemoryTile* lhs_memory_tile, + const std::vector& rhs_tile, + int64 columns, bool is_first_column); + + void EmitInnerLoopEpilogue(llvm::Value* current_tile_col, int64 columns, + bool is_first_tiled_column); + + Config config_; + llvm::Value* lhs_; + llvm::Value* rhs_; + llvm::Value* addend_; + llvm::Value* result_; + llvm::IRBuilder<>* b_; + KernelSupportLibrary ksl_; + VectorSupportLibrary vsl_; +}; + +void ColumnMajorMatrixVectorProductEmitter::EmitOuterLoopBody( + llvm::Value* column, int64 column_count, bool is_first_column) { + MemoryTile lhs_memory_tile = GetLhsMemoryTile(/*column_start=*/column, + /*column_count=*/column_count); + + std::vector rhs_tile = + LoadRhsTile(column, /*count=*/column_count); + EmitInnerLoopTiled(&lhs_memory_tile, rhs_tile, + /*columns=*/column_count, is_first_column); + EmitInnerLoopEpilogue(column, /*columns=*/column_count, is_first_column); +} + +void ColumnMajorMatrixVectorProductEmitter::Emit() { + // See the comment on the class declaration for the algorithm used here. 
+ int64 column_remainder = k() % tile_cols(); + int64 column_limit = k() - column_remainder; + + ksl_.For("dot.outer.tiled", + /*start=*/0, /*end=*/column_limit, /*step=*/tile_cols(), + [&](llvm::Value* column, bool is_first_column) { + EmitOuterLoopBody(column, tile_cols(), is_first_column); + }); + + if (column_remainder != 0) { + EmitOuterLoopBody(b_->getInt64(column_limit), column_remainder, + column_limit == 0); + } +} + +void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopTiled( + MemoryTile* lhs_memory_tile, const std::vector& rhs_tile, + int64 columns, bool is_first_column) { + int64 row_limit = m() - (m() % tile_rows()); + + ksl_.For("dot.inner.tiled", /*start=*/0, /*end=*/row_limit, + /*step=*/tile_rows(), [&](llvm::Value* row) { + std::vector lhs_tile = + lhs_memory_tile->LoadTile(/*minor_dim_offset=*/row); + llvm::Value* accumulator = + is_first_column ? (addend_ ? vsl_.LoadVector(addend_, row) + : vsl_.GetZeroVector()) + : vsl_.LoadVector(result_, row); + for (int i = 0; i < columns; i++) { + accumulator = vsl_.MulAdd(lhs_tile[i], rhs_tile[i], accumulator); + } + vsl_.StoreVector(accumulator, result_, row); + }); +} + +void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( + llvm::Value* current_tile_col, int64 columns, bool is_first_tiled_column) { + int64 row_start = m() - (m() % tile_rows()); + if (row_start == m()) { + return; + } + + llvm::Value* columns_llvm = b_->getInt64(columns); + + // for (col = current_tile_col; col < (columns + current_tile_col); col++) + // for (row = row_start, row < m_; row++) { + // result[row] += lhs[row, col] * rhs[col] + // // Also take into account that if col is 0 then result[row] is not + // // initialized. 
+ // } + + ksl_.For( + "dot.inner.epilg.outer", /*start=*/current_tile_col, + /*end=*/b_->CreateAdd(columns_llvm, current_tile_col), + /*step=*/1, /*peel_first_iteration=*/false, + [&](llvm::Value* col, llvm::Value* is_first_scalar_col) { + llvm::Value* rhs_element = vsl_.LoadScalar(rhs_, col); + llvm::Value* total_offset = b_->CreateMul(col, b_->getInt64(m())); + llvm::Value* lhs_base_pointer = + vsl_.ComputeOffsetPointer(lhs_, total_offset); + ksl_.For( + "dot.inner.epilg.inner", /*start=*/row_start, /*end=*/m(), + /*step=*/1, [&](llvm::Value* scalar_row) { + llvm::Value* product = vsl_.Mul( + vsl_.LoadScalar(lhs_base_pointer, scalar_row), rhs_element); + llvm::Value* setting_result_first_time = b_->CreateAnd( + is_first_scalar_col, b_->getInt1(is_first_tiled_column)); + ksl_.If( + setting_result_first_time, + /*true_block_generator=*/ + [&]() { + if (addend_) { + vsl_.StoreScalar( + vsl_.Add(vsl_.LoadScalar(addend_, scalar_row), + product), + result_, scalar_row); + } else { + vsl_.StoreScalar(product, result_, scalar_row); + } + }, + /*false_block_generator=*/ + [&]() { + vsl_.StoreScalar( + vsl_.Add(vsl_.LoadScalar(result_, scalar_row), product), + result_, scalar_row); + }); + }); + }); +} + +// Computes a dot product between "[M,K]{1,0} lhs" and a [K,1] vector (the +// layout of the vector does not matter). This implementation uses a tiling +// scheme to improve performance. +// +// We logically separate the LHS matrix into four segments: +// +// +----------------------+---+ +// | | | +// | | | +// | A | B | +// | | | +// | | | +// | | | +// +----------------------+---+ +// | C | D | +// +----------------------+---+ +// +// where A is the largest submatrix of the LHS that can be evenly divided into +// tiles.
For each tile in A, assuming tile_rows_ == tile_cols_ == 4, we have: +// +// +---+---+---+---+ +// |M00|M10|M20|M30| +// +---+---+---+---+ +--+--+--+--+ +// |M01|M11|M21|M31| and |V0|V1|V2|V3| +// +---+---+---+---+ +--+--+--+--+ +// |M02|M12|M22|M32| +// +---+---+---+---+ +// |M03|M13|M23|M33| +// +---+---+---+---+ +// +// (Legend: rows are horizontal and columns are vertical; and each row is one +// llvm::Value of a vector type) +// +// where: +// +// a. The left tile is loaded from the row major left matrix. +// b. The right vector is loaded from the RHS vector. +// +// We keep 4 vector accumulators accumulating the following four vector +// expressions as we iterate over the row dimension: +// +// +------+------+------+------+ +// |M0I*V0|M1I*V1|M2I*V2|M3I*V3| for I in [0,4) +// +------+------+------+------+ +// +// In the end we do a horizontal reduction over these 4 vector accumulators to +// get 4 values in the result vector. +// +// We have an inner epilogue loop to deal with the "B" sub-matrix and an outer +// epilogue loop to deal with the C,D submatrix. 
+class RowMajorMatrixVectorProductEmitter + : public GemvConfig::User { + public: + class Config : public GemvConfig { + public: + explicit Config(PrimitiveType scalar_type, int64 tile_rows, int64 tile_cols, + int64 m, int64 k, bool has_addend) + : GemvConfig(/*name=*/"row_major_gemv", scalar_type, + /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols, /*m=*/m, + /*k=*/k, /*has_addend=*/has_addend) {} + }; + + RowMajorMatrixVectorProductEmitter(const Config& config, llvm::Value* lhs, + llvm::Value* rhs, llvm::Value* addend, + llvm::Value* result, llvm::IRBuilder<>* b) + : config_(config), + lhs_(lhs), + rhs_(rhs), + addend_(addend), + result_(result), + b_(b), + ksl_(b_), + vsl_(scalar_type(), /*vector_size=*/tile_cols(), b_, "") { + CHECK(tile_cols() > 0 && IsPowerOfTwo(static_cast(tile_cols()))); + CHECK(!has_addend() || addend != nullptr); + } + + void Emit(); + + const Config& config() const { return config_; } + + private: + MemoryTile GetLhsMemoryTile(llvm::Value* row_start, int64 row_count) { + return MemoryTile(&vsl_, b_, /*matrix=*/lhs_, + /*matrix_size_along_minor_dim=*/k(), + /*major_dim_offset=*/row_start, + /*tile_size_along_major_dim=*/row_count); + } + + void EmitOuterLoopBody(llvm::Value* row, int64 row_count); + + void EmitInnerLoopTiled(MemoryTile* lhs_memory_tile, int64 rows, + std::vector* vector_accumulators); + + void EmitInnerLoopEpilogue(llvm::Value* current_tile_row, int64 rows, + std::vector* scalar_accumulators); + + Config config_; + llvm::Value* lhs_; + llvm::Value* rhs_; + llvm::Value* addend_; + llvm::Value* result_; + llvm::IRBuilder<>* b_; + KernelSupportLibrary ksl_; + VectorSupportLibrary vsl_; +}; + +void RowMajorMatrixVectorProductEmitter::EmitOuterLoopBody(llvm::Value* row, + int64 row_count) { + MemoryTile lhs_memory_tile = GetLhsMemoryTile(/*row_start=*/row, + /*row_count=*/row_count); + std::vector vector_accumulators; + std::vector scalar_accumulators; + for (int i = 0; i < row_count; i++) { + 
vector_accumulators.emplace_back(&vsl_, vsl_.GetZeroVector()); + scalar_accumulators.emplace_back(&vsl_, vsl_.GetZeroScalar()); + } + EmitInnerLoopTiled(&lhs_memory_tile, /*rows=*/row_count, + &vector_accumulators); + EmitInnerLoopEpilogue(/*current_tile_row=*/row, /*rows=*/row_count, + &scalar_accumulators); + + std::vector accumulator_values; + std::transform( + vector_accumulators.begin(), vector_accumulators.end(), + std::back_inserter(accumulator_values), + [](const VectorVariable& vector_var) { return vector_var.Get(); }); + + std::vector horizontal_sums; + if (row_count == vsl_.vector_size()) { + if (addend_) { + horizontal_sums = vsl_.ComputeHorizontalSums( + std::move(accumulator_values), vsl_.LoadVector(addend_, row)); + } else { + horizontal_sums = + vsl_.ComputeHorizontalSums(std::move(accumulator_values)); + } + } else { + horizontal_sums = vsl_.ComputeHorizontalSums(std::move(accumulator_values)); + } + + for (int i = 0; i < row_count; i++) { + llvm::Value* result_value = + vsl_.Add(horizontal_sums[i], scalar_accumulators[i].Get()); + llvm::Value* offset = b_->CreateAdd(b_->getInt64(i), row); + if (addend_ && row_count != vsl_.vector_size()) { + result_value = vsl_.Add(vsl_.LoadScalar(addend_, offset), result_value); + } + vsl_.StoreScalar(result_value, result_, offset); + } +} + +void RowMajorMatrixVectorProductEmitter::Emit() { + // See the comment on the class declaration for the algorithm used here. 
+ int64 row_remainder = m() % tile_rows(); + int64 row_limit = m() - row_remainder; + + ksl_.For("dot.outer.tiled", + /*start=*/0, /*end=*/row_limit, /*step=*/tile_rows(), + [&](llvm::Value* row) { EmitOuterLoopBody(row, tile_rows()); }); + + if (row_remainder != 0) { + EmitOuterLoopBody(b_->getInt64(row_limit), row_remainder); + } +} + +void RowMajorMatrixVectorProductEmitter::EmitInnerLoopTiled( + MemoryTile* lhs_memory_tile, int64 rows, + std::vector* vector_accumulators) { + int64 column_limit = k() - (k() % tile_cols()); + + ksl_.For("dot.inner.tiled", /*start=*/0, /*end=*/column_limit, + /*step=*/tile_cols(), [&](llvm::Value* col) { + std::vector lhs_tile = + lhs_memory_tile->LoadTile(/*minor_dim_offset=*/col); + llvm::Value* rhs_value = vsl_.LoadVector(rhs_, col); + for (int i = 0; i < rows; i++) { + llvm::Value* old_sum = (*vector_accumulators)[i].Get(); + (*vector_accumulators)[i].Set( + vsl_.Add(old_sum, vsl_.Mul(rhs_value, lhs_tile[i]))); + } + }); +} + +void RowMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( + llvm::Value* current_tile_row, int64 rows, + std::vector* scalar_accumulators) { + int64 column_start = k() - (k() % tile_cols()); + if (column_start == k()) { + return; + } + + for (int r = 0; r < rows; r++) { + llvm::Value* total_offset = b_->CreateMul( + b_->CreateAdd(b_->getInt64(r), current_tile_row), b_->getInt64(k())); + llvm::Value* lhs_base_pointer = + vsl_.ComputeOffsetPointer(lhs_, total_offset); + ksl_.For("dot.inner.epilg.inner", /*start=*/column_start, /*end=*/k(), + /*step=*/1, [&](llvm::Value* scalar_col) { + llvm::Value* product = + vsl_.Mul(vsl_.LoadScalar(lhs_base_pointer, scalar_col), + vsl_.LoadScalar(rhs_, scalar_col)); + llvm::Value* old_value = (*scalar_accumulators)[r].Get(); + (*scalar_accumulators)[r].Set(vsl_.Add(old_value, product)); + }); + } +} + +// This class implements a tiled matrix multiplication algorithm, intended for +// multiplying small matrices that don't need cache tiling. 
+// +// In the future this can be used as the innermost GEBP loop in a GEMM kernel as +// described in "Goto, Kazushige, and Robert A. van de Geijn. "Anatomy of +// high-performance matrix multiplication." ACM Transactions on Mathematical +// Software (TOMS) 34.3 (2008): 12.". +// +// This only supports canonical dot operations (i.e. where the lhs contraction +// dimension is 1 and the rhs contraction dimension is 0) over row major +// matrices. +class TiledSmallGemmEmitter { + public: + // Describe the dimensions of the kernel. + class Dimensions { + public: + explicit Dimensions(int64 m, int64 k, int64 n) : m_(m), k_(k), n_(n) {} + + int64 m() const { return m_; } + int64 k() const { return k_; } + int64 n() const { return n_; } + + string ToString() const { return absl::StrCat(m(), "x", k(), "x", n()); } + + private: + const int64 m_; + const int64 k_; + const int64 n_; + }; + + // Represents the configuration of the emitter. The LLVM IR emitted by the + // emitter, modulo the LLVM values holding the input and output buffers, must + // be a function of the instance of `Config` passed to it. + // + // `dims` holds the matrix multiplication dimensions. + // + // `max_vectorization_width` is the maximum vector width (i.e. the width of + // the largest vector register we will use). This can be larger than the + // largest vector register supported by the machine -- LLVM will legalize + // these large vector widths into legally sized vectors. + // + // `max_vector_count` is the maximum number of vectors of size + // `max_vectorization_width` that we will attempt to process at once. + // + // `min_vectorization_width` is the smallest vector width the emitter will use + // -- below that it will devolve to using a scalar loop. + // + // The innermost reduction loop executes the matrix multiply in tiles of size + // [`tile_size_m`, `tile_size_k`] from the LHS and [`tile_size_k`, + // vector width] in the RHS.
+  class Config {
+   public:
+    // Stores the element type, the M/K/N extents, the three vectorization
+    // limits and the M/K tile sizes.  No validation happens here; the
+    // TiledSmallGemmEmitter constructor CHECKs the values it relies on.
+    explicit Config(PrimitiveType scalar_type, Dimensions dims,
+                    int64 max_vectorization_width, int64 max_vector_count,
+                    int64 min_vectorization_width, int64 tile_size_m,
+                    int64 tile_size_k)
+        : scalar_type_(scalar_type),
+          dims_(dims),
+          max_vectorization_width_(max_vectorization_width),
+          max_vector_count_(max_vector_count),
+          min_vectorization_width_(min_vectorization_width),
+          tile_size_m_(tile_size_m),
+          tile_size_k_(tile_size_k) {}
+
+    // Returns a string identifying this configuration.  EmitSmallGemm hands it
+    // to KernelSupportLibrary::EmitAndCallOutlinedKernel (presumably as the
+    // identity of the outlined kernel -- confirm against that helper).
+    //
+    // Every field the emitter reads has to be part of the key.
+    // `max_vector_count` determines the vector widths HandleResiduesOnN
+    // iterates over, so it is included here; leaving it out would let two
+    // configurations that differ only in that field share one key.
+    string GetCacheKey() const {
+      return absl::StrCat("gemm_", PrimitiveType_Name(scalar_type()), "_",
+                          dims().ToString(), "_", max_vectorization_width(),
+                          "_", max_vector_count(), "_",
+                          min_vectorization_width(), "_", tile_size_m(), "_",
+                          tile_size_k());
+    }
+
+    PrimitiveType scalar_type() const { return scalar_type_; }
+    Dimensions dims() const { return dims_; }
+    int64 max_vectorization_width() const { return max_vectorization_width_; }
+    int64 max_vector_count() const { return max_vector_count_; }
+    int64 min_vectorization_width() const { return min_vectorization_width_; }
+
+    int64 tile_size_m() const { return tile_size_m_; }
+    int64 tile_size_k() const { return tile_size_k_; }
+
+   private:
+    PrimitiveType scalar_type_;
+    Dimensions dims_;
+    int64 max_vectorization_width_;
+    int64 max_vector_count_;
+    int64 min_vectorization_width_;
+    int64 tile_size_m_;
+    int64 tile_size_k_;
+  };
+
+  // Creates an instance of TiledSmallGemmEmitter that matrix-multiplies
+  // `lhs` with `rhs` and stores the result in `result`.
+ explicit TiledSmallGemmEmitter(Config config, llvm::Value* lhs, + llvm::Value* rhs, llvm::Value* result, + llvm::IRBuilder<>* b) + : lhs_(lhs), + rhs_(rhs), + result_(result), + config_(config), + b_(b), + ksl_(b_) { + CHECK(max_vectorization_width() > 0 && + IsPowerOfTwo(static_cast(max_vectorization_width()))); + CHECK_GT(max_vector_count(), 0); + CHECK(min_vectorization_width() > 0 && + IsPowerOfTwo(static_cast(min_vectorization_width()))); + CHECK_GE(max_vectorization_width(), min_vectorization_width()); + CHECK_GT(tile_size_k(), 0); + } + + void Emit(); + + private: + // The HandleResiduesOnX helpers split the iteration space for dimension X + // into a multiple of the tile size on dimension X and an epilogue. These + // helpers ultimately call into `EmitTiledGemm` for emitting the + // tiled GEMM kernel. + + void HandleResiduesOnN(); + void HandleResiduesOnK(VectorSupportLibrary* vsl, llvm::Value* n_start, + llvm::Value* n_end); + void HandleResiduesOnM(VectorSupportLibrary* vsl, int64 tile_size_k, + llvm::Value* k_start, llvm::Value* k_end, + llvm::Value* n_start, llvm::Value* n_end); + + // This emits a tiled GEMM kernel. For a detailed description see the comment + // on the implementation. 
+ void EmitTiledGemm(VectorSupportLibrary* vsl, int64 tile_size_k, + llvm::Value* k_start, llvm::Value* k_end, + llvm::Value* n_start, llvm::Value* n_end, + int64 tile_size_m, llvm::Value* m_start, + llvm::Value* m_end); + + llvm::Value* GetInt64(int64 value) { return b_->getInt64(value); } + + Config config() const { return config_; } + Dimensions dims() const { return config().dims(); } + + int64 max_vectorization_width() const { + return config().max_vectorization_width(); + } + int64 max_vector_count() const { return config().max_vector_count(); } + int64 min_vectorization_width() const { + return config().min_vectorization_width(); + } + int64 tile_size_m() const { return config().tile_size_m(); } + int64 tile_size_k() const { return config().tile_size_k(); } + PrimitiveType scalar_type() const { return config().scalar_type(); } + + llvm::Value* lhs_; + llvm::Value* rhs_; + llvm::Value* result_; + Config config_; + + llvm::IRBuilder<>* b_; + KernelSupportLibrary ksl_; +}; + +void TiledSmallGemmEmitter::Emit() { HandleResiduesOnN(); } + +void TiledSmallGemmEmitter::HandleResiduesOnN() { + // We can only iterate the `n` dimension for an extent that is divisible by + // the vectorization width. So we emit an outer loop that first processes the + // largest extent in `n` that is divisible by max_vectorization_width, then + // the largest remaining extent that is divisible by max_vectorization_width / + // 2 etc. 
+ + int64 current_vectorization_width = + max_vector_count() * max_vectorization_width(); + int64 current_vector_count = max_vector_count(); + + int64 n_start = 0; + while (n_start != dims().n() && + current_vectorization_width >= min_vectorization_width()) { + int64 n_end = dims().n() - (dims().n() % current_vectorization_width); + if (n_start != n_end) { + VectorSupportLibrary vsl(scalar_type(), current_vectorization_width, b_, + "gemm"); + HandleResiduesOnK(&vsl, GetInt64(n_start), GetInt64(n_end)); + n_start = n_end; + } + if (current_vector_count == 1) { + current_vectorization_width /= 2; + } else { + current_vector_count--; + current_vectorization_width = + current_vector_count * max_vectorization_width(); + } + } + + if (n_start != dims().n()) { + VectorSupportLibrary vsl(scalar_type(), 1, b_, "gemm"); + ksl_.For("epi.n", n_start, dims().n(), 1, [&](llvm::Value* n_i) { + llvm::Value* n_i_next = b_->CreateAdd(n_i, b_->getInt64(1)); + HandleResiduesOnK(&vsl, n_i, n_i_next); + }); + } +} + +void TiledSmallGemmEmitter::HandleResiduesOnK(VectorSupportLibrary* vsl, + llvm::Value* n_start, + llvm::Value* n_end) { + int64 k_start = 0; + int64 k_end = dims().k() - (dims().k() % tile_size_k()); + if (k_end != k_start) { + HandleResiduesOnM(vsl, tile_size_k(), GetInt64(k_start), GetInt64(k_end), + n_start, n_end); + k_start = k_end; + } + + if (k_start != dims().k()) { + HandleResiduesOnM(vsl, dims().k() - k_start, GetInt64(k_start), + GetInt64(dims().k()), n_start, n_end); + } +} + +void TiledSmallGemmEmitter::HandleResiduesOnM( + VectorSupportLibrary* vsl, int64 tile_size_k, llvm::Value* k_start, + llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end) { + const int64 m_end = dims().m() - dims().m() % tile_size_m(); + EmitTiledGemm(vsl, tile_size_k, k_start, k_end, n_start, n_end, tile_size_m(), + GetInt64(0), GetInt64(m_end)); + + if (m_end != dims().m()) { + EmitTiledGemm(vsl, tile_size_k, k_start, k_end, n_start, n_end, + dims().m() - m_end, 
GetInt64(m_end), GetInt64(dims().m())); + } +} + +// The loop structure is: +// +// Iterate over dimension M as m: +// Iterate over dimension N as n: +// Iterate over dimension K as k: +// OutputTile[m,n] += Dot(LhsTile[m,k], RhsTile[k,n]) +// +// I.e. just a tiled version of a "naive" GEMM. +// +// The tiling scheme is as follows: +// +// Let the LHS be: +// +// +----+----+----+ +// | a0 | b0 | c0 | . +// +----+----+----+ . +// | a1 | b1 | c1 | . +// +----+----+----+ +// .. .. +// +// and the RHS be: +// +// +----+----+----+----+ +// | p0 | p1 | p2 | p3 | . +// +----+----+----+----+ . +// | q0 | q1 | q2 | q3 | . +// +----+----+----+----+ +// | r0 | r1 | r2 | r3 | . +// +----+----+----+----+ . +// ...... ...... +// +// and let tile_size_m=2, tile_size_k=3 and the vector width (implicitly denoted +// by `vsl`) be 4. Then we want to matrix multiply this tile to get a [2,4] +// matrix that we can increment the result matrix by. +// +// First broadcast each row in the LHS to 3 vectors of width 4, giving us a rank +// 3 array, L, of dimension [2,3,4]: +// +// L[0,_,_] * L[1,_,_] +// * +// +----+----+----+----+ * +----+----+----+----+ +// | a0 | a0 | a0 | a0 | * | a1 | a1 | a1 | a1 | +// +----+----+----+----+ * +----+----+----+----+ +// | b0 | b0 | b0 | b0 | * | b1 | b1 | b1 | b1 | +// +----+----+----+----+ * +----+----+----+----+ +// | c0 | c0 | c0 | c0 | * | c1 | c1 | c1 | c1 | +// +----+----+----+----+ * +----+----+----+----+ +// +// +// Then we FMA L[0,_,_] with the RHS to get the first row of the result and +// L[1,_,_] with the RHS to get the second row of the result.
For example, +// L[0,_,_] is computed as: +// +// +----+----+----+----+ +----+----+----+----+ +// | a0 | a0 | a0 | a0 | * | p0 | p1 | p2 | p3 | + +// +----+----+----+----+ +----+----+----+----+ +// +// +----+----+----+----+ +----+----+----+----+ +// | b0 | b0 | b0 | b0 | * | q0 | q1 | q2 | q3 | + +// +----+----+----+----+ +----+----+----+----+ +// +// +----+----+----+----+ +----+----+----+----+ +// | c0 | c0 | c0 | c0 | * | r0 | r1 | r2 | r3 | +// +----+----+----+----+ +----+----+----+----+ +// +// to get: +// +// +-------------------+-------------------+-------------------+--------- +// | a0*p0+b0*q0+c0*r0 | a0*p1+b0*q1+c0*r1 | a0*p2+b0*q2+c0*r2 | ... +// +-------------------+-------------------+-------------------+--------- +void TiledSmallGemmEmitter::EmitTiledGemm( + VectorSupportLibrary* vsl, int64 tile_size_k, llvm::Value* k_start, + llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end, + int64 tile_size_m, llvm::Value* m_start, llvm::Value* m_end) { + ksl_.For("dot.m", m_start, m_end, tile_size_m, [&](llvm::Value* m_i) { + MemoryTile result_memory_tile(vsl, b_, /*matrix=*/result_, + /*matrix_size_along_minor_dim=*/dims().n(), + /*major_dim_offset=*/m_i, + /*tile_size_along_major_dim=*/tile_size_m); + MemoryTile lhs_memory_tile(vsl, b_, /*matrix=*/lhs_, + /*matrix_size_along_minor_dim=*/dims().k(), + /*major_dim_offset=*/m_i, + /*tile_size_along_major_dim=*/tile_size_m); + ksl_.For( + "dot.n", n_start, n_end, vsl->vector_size(), [&](llvm::Value* n_i) { + TileVariable result_tile_var(vsl, result_memory_tile.LoadTile(n_i)); + ksl_.For("dot.k", k_start, k_end, tile_size_k, [&](llvm::Value* k_i) { + MemoryTile rhs_memory_tile(vsl, b_, rhs_, dims().n(), k_i, + tile_size_k); + std::vector> lhs_tile = + lhs_memory_tile.LoadBroadcastTile(k_i, tile_size_k); + std::vector rhs_tile = rhs_memory_tile.LoadTile(n_i); + std::vector result_tile = result_tile_var.Get(); + for (int64 r_m_i = 0; r_m_i < tile_size_m; r_m_i++) { + for (int64 r_k_i = 0; r_k_i < 
tile_size_k; r_k_i++) { + result_tile[r_m_i] = + vsl->MulAdd(lhs_tile[r_m_i][r_k_i], rhs_tile[r_k_i], + result_tile[r_m_i]); + } + } + result_tile_var.Set(result_tile); + }); + + result_memory_tile.StoreTile(result_tile_var.Get(), n_i); + }); + }); +} + +} // namespace + +void EmitRowMajorGemv(PrimitiveType scalar_type, int64 tile_rows, + int64 tile_cols, int64 m, int64 k, llvm::Value* lhs, + llvm::Value* rhs, llvm::Value* addend, + llvm::Value* result, llvm::IRBuilder<>* b, + bool enable_fast_math, bool optimize_for_size) { + RowMajorMatrixVectorProductEmitter::Config config( + /*scalar_type=*/scalar_type, + /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols, + /*m=*/m, /*k=*/k, /*has_addend=*/addend != nullptr); + + KernelSupportLibrary::EmitAndCallOutlinedKernel( + /*enable_fast_math=*/enable_fast_math, + /*optimize_for_size=*/optimize_for_size, b, config.GetCacheKey(), lhs, + rhs, addend, result, + [&](llvm::Value* lhs, llvm::Value* rhs, llvm::Value* addend, + llvm::Value* result) { + RowMajorMatrixVectorProductEmitter emitter(config, lhs, rhs, addend, + result, b); + emitter.Emit(); + }); +} + +void EmitColumnMajorGemv(PrimitiveType scalar_type, int64 tile_rows, + int64 tile_cols, int64 m, int64 k, llvm::Value* lhs, + llvm::Value* rhs, llvm::Value* addend, + llvm::Value* result, llvm::IRBuilder<>* b, + bool enable_fast_math, bool optimize_for_size) { + ColumnMajorMatrixVectorProductEmitter::Config config( + /*scalar_type=*/scalar_type, + /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols, + /*m=*/m, /*k=*/k, /*has_addend=*/addend != nullptr); + + KernelSupportLibrary::EmitAndCallOutlinedKernel( + /*enable_fast_math=*/enable_fast_math, + /*optimize_for_size=*/optimize_for_size, b, config.GetCacheKey(), lhs, + rhs, addend, result, + [&](llvm::Value* lhs, llvm::Value* rhs, llvm::Value* addend, + llvm::Value* result) { + ColumnMajorMatrixVectorProductEmitter emitter(config, lhs, rhs, addend, + result, b); + emitter.Emit(); + }); +} + +void EmitSmallGemm(PrimitiveType 
scalar_type, int64 m, int64 k, int64 n, + int64 max_vectorization_width, int64 max_vector_count, + int64 min_vectorization_width, int64 tile_size_m, + int64 tile_size_k, llvm::Value* lhs, llvm::Value* rhs, + llvm::Value* result, llvm::IRBuilder<>* b, + bool enable_fast_math, bool optimize_for_size) { + TiledSmallGemmEmitter::Config config( + /*scalar_type=*/scalar_type, + TiledSmallGemmEmitter::Dimensions{/*m=*/m, /*k=*/k, /*n=*/n}, + /*max_vectorization_width=*/max_vectorization_width, + /*max_vector_count=*/max_vector_count, + /*min_vectorization_width=*/min_vectorization_width, + /*tile_size_m=*/tile_size_m, /*tile_size_k=*/tile_size_k); + + KernelSupportLibrary::EmitAndCallOutlinedKernel( + /*enable_fast_math=*/enable_fast_math, + /*optimize_for_size=*/optimize_for_size, b, config.GetCacheKey(), lhs, + rhs, result, + [&](llvm::Value* lhs, llvm::Value* rhs, llvm::Value* result) { + TiledSmallGemmEmitter small_gemm_emitter(config, /*lhs=*/lhs, + /*rhs=*/rhs, + /*result=*/result, b); + small_gemm_emitter.Emit(); + }); +} + +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.h b/tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.h new file mode 100644 index 0000000000000000000000000000000000000000..0a82326cc3704bce8c122261383249c60eda1f3a --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.h @@ -0,0 +1,55 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TILED_DOT_EMITTER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TILED_DOT_EMITTER_H_ + +#include "llvm/IR/IRBuilder.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace cpu { + +// These routines emit LLVM IR implementing tiled GEMM and GEMV routines. + +void EmitRowMajorGemv(PrimitiveType scalar_type, tensorflow::int64 tile_rows, + tensorflow::int64 tile_cols, tensorflow::int64 m, + tensorflow::int64 k, llvm::Value* lhs, llvm::Value* rhs, + llvm::Value* addend, llvm::Value* result, + llvm::IRBuilder<>* b, bool enable_fast_math, + bool optimize_for_size); + +void EmitColumnMajorGemv(PrimitiveType scalar_type, tensorflow::int64 tile_rows, + tensorflow::int64 tile_cols, tensorflow::int64 m, + tensorflow::int64 k, llvm::Value* lhs, + llvm::Value* rhs, llvm::Value* addend, + llvm::Value* result, llvm::IRBuilder<>* b, + bool enable_fast_math, bool optimize_for_size); + +void EmitSmallGemm(PrimitiveType scalar_type, tensorflow::int64 m, + tensorflow::int64 k, tensorflow::int64 n, + tensorflow::int64 max_vectorization_width, + tensorflow::int64 max_vector_count, + tensorflow::int64 min_vectorization_width, + tensorflow::int64 tile_size_m, tensorflow::int64 tile_size_k, + llvm::Value* lhs, llvm::Value* rhs, llvm::Value* result, + llvm::IRBuilder<>* b, bool enable_fast_math, + bool optimize_for_size); + +} // namespace cpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TILED_DOT_EMITTER_H_ diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.h b/tensorflow/compiler/xla/service/cpu/vector_support_library.h index 5690d2be2fe3e21c96b51a5226e0b29148217fd1..c444fd7d4aa88fa21b1aa2b2f058bd689b234b15 100644 --- 
a/tensorflow/compiler/xla/service/cpu/vector_support_library.h +++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.h @@ -114,6 +114,9 @@ class VectorSupportLibrary { // raison d'etre) less cluttered. llvm::Value* FCmpEQMask(llvm::Value* lhs, llvm::Value* rhs); + llvm::Value* FCmpEQMask(llvm::Value* lhs, const llvm::APFloat& rhs) { + return FCmpEQMask(lhs, GetConstantFloat(lhs->getType(), rhs)); + } llvm::Value* FCmpULEMask(llvm::Value* lhs, llvm::Value* rhs); llvm::Value* FCmpOLTMask(llvm::Value* lhs, llvm::Value* rhs); llvm::Value* FCmpOLTMask(llvm::Value* lhs, const llvm::APFloat& rhs) { diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index e84bf00153aa28df29d8df486b92654feab4afbf..c54e954c222ff7ca9c0739ec8a55b9d79b74a437 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -105,9 +105,10 @@ class DfsHloVisitorBase { } virtual Status HandleConvolution(HloInstructionPtr hlo) = 0; virtual Status HandleFft(HloInstructionPtr fft) = 0; - virtual Status HandleCrossReplicaSum(HloInstructionPtr hlo) = 0; + virtual Status HandleAllReduce(HloInstructionPtr hlo) = 0; virtual Status HandleAllToAll(HloInstructionPtr hlo) = 0; virtual Status HandleCollectivePermute(HloInstructionPtr hlo) = 0; + virtual Status HandleReplicaId(HloInstructionPtr hlo) = 0; virtual Status HandleGetDimensionSize(HloInstructionPtr hlo) = 0; virtual Status HandleCompare(HloInstructionPtr hlo) { return HandleElementwiseBinary(hlo); diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index 80ea5be298aea44a0f424398da74c4e478f10346..33b2cc3fb098ec0d92f68756526fcc4a761d7149 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -91,7 +91,7 @@ class DfsHloVisitorWithDefaultBase 
Status HandleFft(HloInstructionPtr fft) override { return DefaultAction(fft); } - Status HandleCrossReplicaSum(HloInstructionPtr crs) override { + Status HandleAllReduce(HloInstructionPtr crs) override { return DefaultAction(crs); } Status HandleAllToAll(HloInstructionPtr hlo) override { @@ -100,6 +100,9 @@ class DfsHloVisitorWithDefaultBase Status HandleCollectivePermute(HloInstructionPtr hlo) override { return DefaultAction(hlo); } + Status HandleReplicaId(HloInstructionPtr hlo) override { + return DefaultAction(hlo); + } Status HandleRng(HloInstructionPtr random) override { return DefaultAction(random); } diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default_test.cc b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default_test.cc index 825e1436f0ec6d49b555e5e3e9c2c7a19fb7b062..70173d43d79e931b75f131ad380ad98359cc78b8 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default_test.cc +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default_test.cc @@ -73,15 +73,14 @@ ENTRY TestComputation { abs = f32[] abs(arg) add = f32[] add(arg, gte) broadcast = f32[42] broadcast(add), dimensions={} - slice = f32[0] slice(broadcast), slice={[1:2]} + slice = f32[1] slice(broadcast), slice={[1:2]} copy = f32[] copy(arg) eq = pred[] equal-to(arg, gte) neg = f32[] negate(arg) ROOT convert = f64[] convert(f32[] arg) })"; std::unique_ptr module = - HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()) - .ConsumeValueOrDie(); + ParseAndReturnVerifiedModule(hlo_string).ConsumeValueOrDie(); ElementwiseTestVisitor visitor; TF_EXPECT_OK(module->entry_computation()->Accept(&visitor)); } diff --git a/tensorflow/compiler/xla/service/dot_decomposer.cc b/tensorflow/compiler/xla/service/dot_decomposer.cc index b2ba2617902104bfea06713332fa1c2aedea536d..e8bc6d05716a2ef02e0280e86c7df4ac22fe78c4 100644 --- a/tensorflow/compiler/xla/service/dot_decomposer.cc +++ b/tensorflow/compiler/xla/service/dot_decomposer.cc @@ -15,6 +15,8 @@ 
limitations under the License. #include "tensorflow/compiler/xla/service/dot_decomposer.h" +#include "absl/algorithm/container.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -156,29 +158,192 @@ Status DecomposeBatchDot(HloInstruction* dot) { return computation->ReplaceInstruction(dot, new_dot); } +// Convert a dot into a canonical form where non-contracting and contracting +// dimensions are reshaped together and batch dimensions are the most major +// dimensions. This requires transposing and reshaping the lhs and rhs and +// reshaping the output back to the original shape. +Status CanonicalizeDot(HloInstruction* original_dot) { + auto computation = original_dot->parent(); + const auto& original_dnums = original_dot->dot_dimension_numbers(); + const int64 num_batch_dims = original_dnums.lhs_batch_dimensions_size(); + const int64 num_contracting_dims = + original_dnums.lhs_contracting_dimensions_size(); + + const auto& lhs_shape = original_dot->operand(0)->shape(); + const int64 lhs_rank = lhs_shape.rank(); + const int64 num_lhs_non_contracting_dims = + lhs_rank - num_batch_dims - num_contracting_dims; + + std::vector lhs_non_contracting_dims; + lhs_non_contracting_dims.reserve(num_lhs_non_contracting_dims); + int64 lhs_contracting_size = 1; + int64 lhs_non_contracting_size = 1; + std::vector batch_dim_sizes; + batch_dim_sizes.reserve(num_batch_dims); + for (int64 i = 0; i < lhs_rank; ++i) { + if (absl::c_linear_search(original_dnums.lhs_contracting_dimensions(), i)) { + lhs_contracting_size *= lhs_shape.dimensions(i); + } else if (absl::c_linear_search(original_dnums.lhs_batch_dimensions(), + i)) { + batch_dim_sizes.push_back(lhs_shape.dimensions(i)); + } else { + lhs_non_contracting_dims.push_back(i); + lhs_non_contracting_size *= lhs_shape.dimensions(i); + } + } + // The canonical form of the
lhs is + // [BatchDims, NonContractingDims, ContractingsDims] + std::vector lhs_transpose; + lhs_transpose.reserve(lhs_rank); + lhs_transpose.insert(lhs_transpose.end(), + original_dnums.lhs_batch_dimensions().begin(), + original_dnums.lhs_batch_dimensions().end()); + lhs_transpose.insert(lhs_transpose.end(), lhs_non_contracting_dims.begin(), + lhs_non_contracting_dims.end()); + lhs_transpose.insert(lhs_transpose.end(), + original_dnums.lhs_contracting_dimensions().begin(), + original_dnums.lhs_contracting_dimensions().end()); + HloInstruction* transposed_lhs = + computation->AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::PermuteDimensions(InversePermutation(lhs_transpose), + lhs_shape), + original_dot->mutable_operand(0), lhs_transpose)); + std::vector lhs_reshape_dims = batch_dim_sizes; + lhs_reshape_dims.push_back(lhs_non_contracting_size); + lhs_reshape_dims.push_back(lhs_contracting_size); + // Reshape the contracting and non-contracting dimensions together. + HloInstruction* reshaped_lhs = + computation->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(lhs_shape.element_type(), lhs_reshape_dims), + transposed_lhs)); + + const auto& rhs_shape = original_dot->operand(1)->shape(); + const int64 rhs_rank = rhs_shape.rank(); + const int64 num_rhs_non_contracting_dims = + rhs_rank - num_batch_dims - num_contracting_dims; + std::vector rhs_non_contracting_dims; + rhs_non_contracting_dims.reserve(num_rhs_non_contracting_dims); + int64 rhs_non_contracting_size = 1; + int64 rhs_contracting_size = 1; + for (int64 i = 0; i < rhs_rank; ++i) { + if (absl::c_linear_search(original_dnums.rhs_contracting_dimensions(), i)) { + rhs_contracting_size *= rhs_shape.dimensions(i); + } else if (!absl::c_linear_search(original_dnums.rhs_batch_dimensions(), + i)) { + rhs_non_contracting_dims.push_back(i); + rhs_non_contracting_size *= rhs_shape.dimensions(i); + } + } + + // The canonical form of the rhs is + // [BatchDims, ContractingsDims, 
NonContractingDims] + std::vector rhs_transpose; + rhs_transpose.reserve(rhs_rank); + rhs_transpose.insert(rhs_transpose.end(), + original_dnums.rhs_batch_dimensions().begin(), + original_dnums.rhs_batch_dimensions().end()); + rhs_transpose.insert(rhs_transpose.end(), + original_dnums.rhs_contracting_dimensions().begin(), + original_dnums.rhs_contracting_dimensions().end()); + rhs_transpose.insert(rhs_transpose.end(), rhs_non_contracting_dims.begin(), + rhs_non_contracting_dims.end()); + HloInstruction* transposed_rhs = + computation->AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::PermuteDimensions(InversePermutation(rhs_transpose), + rhs_shape), + original_dot->mutable_operand(1), rhs_transpose)); + + std::vector rhs_reshape_dims = batch_dim_sizes; + rhs_reshape_dims.push_back(rhs_contracting_size); + rhs_reshape_dims.push_back(rhs_non_contracting_size); + // Reshape the contracting and non-contracting dimensions together. + HloInstruction* reshaped_rhs = + computation->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(rhs_shape.element_type(), rhs_reshape_dims), + transposed_rhs)); + + std::vector dot_dims = batch_dim_sizes; + dot_dims.push_back(lhs_non_contracting_size); + dot_dims.push_back(rhs_non_contracting_size); + + DotDimensionNumbers dot_dnums; + for (int64 i = 0; i < num_batch_dims; ++i) { + dot_dnums.add_lhs_batch_dimensions(i); + dot_dnums.add_rhs_batch_dimensions(i); + } + dot_dnums.add_lhs_contracting_dimensions(num_batch_dims + 1); + dot_dnums.add_rhs_contracting_dimensions(num_batch_dims); + + HloInstruction* dot = computation->AddInstruction(HloInstruction::CreateDot( + ShapeUtil::MakeShape(original_dot->shape().element_type(), dot_dims), + reshaped_lhs, reshaped_rhs, dot_dnums, original_dot->precision_config())); + + return computation->ReplaceInstruction( + original_dot, computation->AddInstruction(HloInstruction::CreateReshape( + original_dot->shape(), dot))); +} + } // namespace StatusOr 
DotDecomposer::Run(HloModule* module) { XLA_VLOG_LINES(2, "DotDecomposer ENTRY\n" + module->ToString()); - // Gather all batch Dot operations. - std::vector batch_dots; + // Gather all non-canonical Dot operations. + std::vector non_canonical_dots; for (auto* computation : module->MakeNonfusionComputations()) { for (auto* instruction : computation->instructions()) { if (instruction->opcode() != HloOpcode::kDot) { continue; } const DotDimensionNumbers& dnums = instruction->dot_dimension_numbers(); - if (dnums.lhs_batch_dimensions_size() > 0 && decompose_batch_dot_) { - batch_dots.push_back(instruction); + // A dot is not canonical if there is more than one contracting + // dimension. + if (dnums.lhs_contracting_dimensions_size() > 1) { + non_canonical_dots.push_back(instruction); + continue; + } + if (dnums.lhs_batch_dimensions().empty() && + dnums.lhs_contracting_dimensions().empty()) { + non_canonical_dots.push_back(instruction); + continue; + } + if (dnums.lhs_batch_dimensions().empty()) { + continue; + } + std::vector canonical_batch_dims( + dnums.lhs_batch_dimensions_size()); + absl::c_iota(canonical_batch_dims, 0); + if (!absl::c_equal(dnums.lhs_batch_dimensions(), canonical_batch_dims) || + !absl::c_equal(dnums.rhs_batch_dimensions(), canonical_batch_dims)) { + non_canonical_dots.push_back(instruction); } } } - // Decompose each batch Dot in 'batch_dots'. 
bool changed = false; - for (auto* dot : batch_dots) { - TF_RETURN_IF_ERROR(DecomposeBatchDot(dot)); + for (auto* dot : non_canonical_dots) { + TF_RETURN_IF_ERROR(CanonicalizeDot(dot)); changed = true; } + + if (decompose_batch_dot_) { + std::vector batch_dots; + for (auto* computation : module->MakeNonfusionComputations()) { + for (auto* instruction : computation->instructions()) { + if (instruction->opcode() != HloOpcode::kDot) { + continue; + } + const DotDimensionNumbers& dnums = instruction->dot_dimension_numbers(); + if (!dnums.lhs_batch_dimensions().empty()) { + batch_dots.push_back(instruction); + } + } + } + // Decompose each batch Dot in 'batch_dots'. + + for (auto* dot : batch_dots) { + TF_RETURN_IF_ERROR(DecomposeBatchDot(dot)); + changed = true; + } + } XLA_VLOG_LINES(2, "DotDecompose EXIT\n" + module->ToString()); return changed; } diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc new file mode 100644 index 0000000000000000000000000000000000000000..de3b508064bfadd88396f050142e682de2294434 --- /dev/null +++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc @@ -0,0 +1,616 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/while_util.h" + +namespace xla { + +namespace { +bool IsTrivialWindowDimension(const WindowDimension& window_dimension) { + return window_dimension.size() == 1 && window_dimension.stride() == 1 && + window_dimension.padding_low() == 0 && + window_dimension.padding_high() == 0 && + window_dimension.window_dilation() == 1 && + window_dimension.base_dilation() == 1; +} +} // namespace + +class DynamicDimensionInferenceVisitor : public DfsHloVisitorWithDefault { + public: + explicit DynamicDimensionInferenceVisitor( + const DynamicParameterBinding& param_bindings, + DynamicDimensionInference* parent) + : param_bindings_(param_bindings), parent_(parent) {} + + Status DefaultAction(HloInstruction* hlo) override; + + static Status Run(HloComputation* computation, + const DynamicParameterBinding& param_bindings, + DynamicDimensionInference* parent) { + DynamicDimensionInferenceVisitor visitor(param_bindings, parent); + return computation->Accept(&visitor); + } + + Status HandleParameter(HloInstruction* hlo) override; + + Status HandleReduce(HloInstruction* hlo) override; + + Status HandleDot(HloInstruction* hlo) override; + + Status HandleTuple(HloInstruction* hlo) override; + + Status HandleTranspose(HloInstruction* hlo) override; + + Status HandleReshape(HloInstruction* hlo) override; + + Status HandlePad(HloInstruction* hlo) override; + + Status HandleBroadcast(HloInstruction* hlo) override; + + Status HandleGetDimensionSize(HloInstruction* hlo) override; + + Status HandleSelect(HloInstruction* hlo) override; + + Status 
HandleConvolution(HloInstruction* hlo) override; + + Status HandleReduceWindow(HloInstruction* hlo) override; + + Status HandleSelectAndScatter(HloInstruction* hlo) override; + + Status HandleGetTupleElement(HloInstruction* hlo) override; + + Status HandleElementwiseUnary(HloInstruction* hlo) override; + + Status HandleElementwiseBinary(HloInstruction* hlo) override; + + Status HandleWhile(HloInstruction* hlo) override; + + private: + using OperandDynamicDimensionFn = std::function; + + Status ForEachOperandDynamicDimension(HloInstruction* inst, + const OperandDynamicDimensionFn&); + + // Pass through a dynamic dimension from the input to the output with the same + // value and index in the shape. This is a helper function to handle trivial + // instructions like elementwise operations. + Status PassThroughDynamicDimension(HloInstruction*); + + // The dynamic parameter bindings of this computation. + const DynamicParameterBinding& param_bindings_; + + // A pointer to DynamicDimensionInference, used to update the dynamic mapping. 
+ DynamicDimensionInference* parent_; +}; + +Status DynamicDimensionInferenceVisitor::DefaultAction(HloInstruction* hlo) { + return ForEachOperandDynamicDimension( + hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, + int64 operand_index, HloInstruction* dynamic_size) { + return UnimplementedStrCat( + "Asked to propagate a dynamic dimension from hlo ", + operand->ToString(), "@", index.ToString(), "@", dimension, + " to hlo ", hlo->ToString(), ", which is not implemented."); + }); +} + +Status DynamicDimensionInferenceVisitor::HandleGetTupleElement( + HloInstruction* hlo) { + return ForEachOperandDynamicDimension( + hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, + int64 operand_index, HloInstruction* dynamic_size) { + if (hlo->tuple_index() == index[0]) { + ShapeIndex new_index = + ShapeIndexView(index).ConsumeFront().ToShapeIndex(); + parent_->SetDynamicSize(hlo, new_index, dimension, dynamic_size); + } + return Status::OK(); + }); +} + +Status DynamicDimensionInferenceVisitor::HandleTuple(HloInstruction* hlo) { + return ForEachOperandDynamicDimension( + hlo, [&](HloInstruction*, ShapeIndex index, int64 dimension, + int64 operand_index, HloInstruction* dynamic_size) { + index.push_front(operand_index); + parent_->SetDynamicSize(hlo, index, dimension, dynamic_size); + return Status::OK(); + }); +} + +Status DynamicDimensionInferenceVisitor::HandleBroadcast(HloInstruction* hlo) { + return ForEachOperandDynamicDimension( + hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, + int64 operand_index, HloInstruction* dynamic_size) { + int64 broadcast_dim = hlo->dimensions(dimension); + parent_->SetDynamicSize(hlo, index, broadcast_dim, dynamic_size); + return Status::OK(); + }); +} + +Status DynamicDimensionInferenceVisitor::HandlePad(HloInstruction* hlo) { + return ForEachOperandDynamicDimension( + hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, + int64 operand_index, HloInstruction* 
dynamic_size) { + if (operand_index != 0) { + return Unimplemented( + "Dynamic dimension on padding value is not supported"); + } + const PaddingConfig_PaddingConfigDimension& padding_config = + hlo->padding_config().dimensions(dimension); + if (padding_config.interior_padding() == 0 && + padding_config.edge_padding_low() == 0 && + padding_config.edge_padding_high() == 0) { + parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size); + return Status::OK(); + } else { + return Unimplemented( + "Dynamic dimension propagation on padding dimension is not " + "supported."); + } + }); +} + +Status DynamicDimensionInferenceVisitor::HandleReduce(HloInstruction* hlo) { + return ForEachOperandDynamicDimension( + hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, + int64 operand_index, HloInstruction* dynamic_size) { + HloInstruction* reduce = hlo; + int64 operand_count = reduce->operand_count(); + CHECK_EQ(operand_count % 2, 0); + if (operand_index >= operand_count / 2) { + // Init values don't have dynamic sizes. + return Status::OK(); + } + if ((absl::c_count(reduce->dimensions(), dimension) != 0)) { + // Dimension is to be reduced; stop tracing. + return Status::OK(); + } + + // Find out the new dynamic dimension after reduce. 
+ int64 dimensions_not_reduced_count = 0; + for (int i = 0; i < operand->shape().rank(); ++i) { + if (dimension == i) { + parent_->SetDynamicSize(reduce, {}, dimensions_not_reduced_count, + dynamic_size); + + return Status::OK(); + } + if (absl::c_count(reduce->dimensions(), i) == 0) { + dimensions_not_reduced_count++; + } + } + + return Status::OK(); + }); +} + +Status DynamicDimensionInferenceVisitor::HandleDot(HloInstruction* hlo) { + return ForEachOperandDynamicDimension( + hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, + int64 operand_index, HloInstruction* dynamic_size) { + HloInstruction* dot = hlo; + const DotDimensionNumbers& dimension_numbers = + dot->dot_dimension_numbers(); + // A map from the operand dimensions to result dimension. + absl::flat_hash_map result_dim_mapping; + int64 current_result_dims = 0; + std::unordered_set batch_dims( + dimension_numbers.rhs_batch_dimensions().begin(), + dimension_numbers.rhs_batch_dimensions().end()); + + for (int64 i : dimension_numbers.rhs_batch_dimensions()) { + result_dim_mapping[i] = current_result_dims++; + } + + for (int64 i = 0; i < dot->operand(0)->shape().rank(); i++) { + if (!absl::c_linear_search( + dimension_numbers.lhs_contracting_dimensions(), i)) { + if (operand_index == 0) { + result_dim_mapping[i] = current_result_dims; + } + current_result_dims++; + } + } + + for (int64 i = 0; i < dot->operand(1)->shape().rank(); i++) { + if (!absl::c_linear_search( + dimension_numbers.rhs_contracting_dimensions(), i) && + !absl::c_linear_search(dimension_numbers.rhs_batch_dimensions(), + i)) { + if (operand_index == 1) { + result_dim_mapping[i] = current_result_dims; + } + current_result_dims++; + } + } + + // Check if the operand dim is in the result shape. If so, add another + // work item to trace that dimension. 
+ auto iter = result_dim_mapping.find(dimension); + if (iter != result_dim_mapping.end()) { + parent_->SetDynamicSize(dot, {}, iter->second, dynamic_size); + } + + return Status::OK(); + }); +} + +Status DynamicDimensionInferenceVisitor::HandleTranspose(HloInstruction* hlo) { + return ForEachOperandDynamicDimension( + hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, + int64 operand_index, HloInstruction* dynamic_size) { + parent_->SetDynamicSize(hlo, {}, hlo->dimensions()[dimension], + dynamic_size); + return Status::OK(); + }); +} + +Status DynamicDimensionInferenceVisitor::HandleConvolution( + HloInstruction* hlo) { + return ForEachOperandDynamicDimension( + hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, + int64 operand_index, HloInstruction* dynamic_size) { + HloInstruction* conv = hlo; + const ConvolutionDimensionNumbers& dimension_numbers = + conv->convolution_dimension_numbers(); + + if (operand_index == 0) { + if (dimension == dimension_numbers.input_batch_dimension()) { + parent_->SetDynamicSize(conv, {}, + dimension_numbers.output_batch_dimension(), + dynamic_size); + return Status::OK(); + } + + if (dimension == dimension_numbers.input_feature_dimension()) { + return Status::OK(); + } + } else { + if (dimension == dimension_numbers.kernel_input_feature_dimension()) { + return Status::OK(); + } + } + + return Unimplemented("Dynamic Spatial Convolution is not supported: %s", + conv->ToString()); + }); +} + +Status DynamicDimensionInferenceVisitor::HandleGetDimensionSize( + HloInstruction*) { + // Dynamic dimension doesn't propagate through GetDimensionSize: + // + // Input: F32[x, y, z] + // | + // GetDimensionSize(1): U32[] + // + // The returned value is a scalar, which doesn't have any dynamic dimension in + // the shape (although the value contains the real size of the dynamic + // dimension of the input). 
+ return Status::OK(); +} + +Status DynamicDimensionInferenceVisitor::PassThroughDynamicDimension( + HloInstruction* hlo) { + return ForEachOperandDynamicDimension( + hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, + int64 operand_index, HloInstruction* dynamic_size) { + parent_->SetDynamicSize(hlo, index, dimension, dynamic_size); + return Status::OK(); + }); +} + +Status DynamicDimensionInferenceVisitor::HandleElementwiseUnary( + HloInstruction* hlo) { + return PassThroughDynamicDimension(hlo); +} + +Status DynamicDimensionInferenceVisitor::HandleSelect(HloInstruction* hlo) { + return PassThroughDynamicDimension(hlo); +} + +Status DynamicDimensionInferenceVisitor::HandleElementwiseBinary( + HloInstruction* hlo) { + return PassThroughDynamicDimension(hlo); +} + +Status DynamicDimensionInferenceVisitor::HandleReshape(HloInstruction* hlo) { + return ForEachOperandDynamicDimension( + hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, + int64 operand_index, HloInstruction* dynamic_size) { + HloInstruction* reshape = hlo; + std::vector> unmodified_dims = + ShapeUtil::DimensionsUnmodifiedByReshape(operand->shape(), + reshape->shape()); + for (auto& unmodified : unmodified_dims) { + if (unmodified.first == dimension) { + parent_->SetDynamicSize(reshape, {}, unmodified.second, + dynamic_size); + return Status::OK(); + } + } + return Unimplemented( + "Dynamic Reshape on modified dimensions is yet not supported: %s", + reshape->ToString()); + }); +} + +Status DynamicDimensionInferenceVisitor::HandleReduceWindow( + HloInstruction* hlo) { + return ForEachOperandDynamicDimension( + hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, + int64 operand_index, HloInstruction* dynamic_size) { + HloInstruction* reduce_window = hlo; + const WindowDimension& window_dimension = + reduce_window->window().dimensions(dimension); + + if (!IsTrivialWindowDimension(window_dimension)) { + return Unimplemented( + "Dynamic Spatial reduce 
window is not supported: %s", + reduce_window->ToString()); + } + + parent_->SetDynamicSize(reduce_window, {}, dimension, dynamic_size); + + return Status::OK(); + }); +} + +Status DynamicDimensionInferenceVisitor::HandleSelectAndScatter( + HloInstruction* hlo) { + return ForEachOperandDynamicDimension( + hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, + int64 operand_index, HloInstruction* dynamic_size) { + HloInstruction* select_and_scatter = hlo; + const WindowDimension& window_dimension = + select_and_scatter->window().dimensions(dimension); + + if (!IsTrivialWindowDimension(window_dimension)) { + return Unimplemented( + "Dynamic Spatial select and scatter is not supported: %s", + select_and_scatter->ToString()); + } + + parent_->SetDynamicSize(select_and_scatter, {}, dimension, + dynamic_size); + + return Status::OK(); + }); +} + +Status DynamicDimensionInferenceVisitor::HandleWhile(HloInstruction* hlo) { + // While loop is handled by passing dynamic size hlos as parameters into the + // hlo while loop. This is done by replacing the original while with a new + // one. + // + // Before: + // + // op1 = ... + // op2 = ... + // op1_x = ... // dynamic dimension size of op1 + // while = while(op1, op2) + // + // + // After: + // + // op1 = ... + // op2 = ... + // op1_x = ... // dynamic dimension size of op1 + // while = while(op1, op2, op1_x) + // + // In the above graph, op_x is the bound of the dynamic dimension size of op1 + // and is wired into the while loop as new parameter. + // + // TODO(b/119843103): Once we implement dynamic bounds in XLA backend, dynamic + // bound can be propagated through native xla values instead of relying on + // additional parameter. + + // dynamic_size_to_operand_id_index_map keeps track of dynamic size operations + // to their operand ids in the new while loop. 
+ absl::flat_hash_map + dynamic_size_to_operand_id_index_map; + + // operands_to_add collects dynamic sizes that need to be added to the while + // loop as parameters. Note that a dynamic size is ignored if it is already + // part of the parameter. i.e.: + // + // We don't do: + // + // op1 = ... + // op2 = ... + // op_x = ... // dynamic dimension size of both op1 and op2 + // while = while(op1, op2, op_x, op_x) // 4 parameters + // + // But we do: + // + // op1 = ... + // op2 = ... + // op_x = ... // dynamic dimension size of both op1 and op2 + // while = while(op1, op2, op_x) + // + // An alternative is to do this in a while loop CSE pass. + // + std::vector operands_to_add; + int64 operand_count = hlo->shape().tuple_shapes_size(); + TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension( + hlo, [&](HloInstruction*, ShapeIndex, int64, int64, + HloInstruction* dynamic_size) { + const HloInstruction* tuple_operand = hlo->operand(0); + for (int64 i = 0; i < tuple_operand->operand_count(); ++i) { + if (dynamic_size == tuple_operand->operand(i)) { + dynamic_size_to_operand_id_index_map[dynamic_size] = i; + return Status::OK(); + } + } + auto iter = dynamic_size_to_operand_id_index_map.find(dynamic_size); + if (iter == dynamic_size_to_operand_id_index_map.end()) { + operands_to_add.push_back(dynamic_size); + dynamic_size_to_operand_id_index_map[dynamic_size] = operand_count++; + } + return Status::OK(); + })); + + if (!operands_to_add.empty()) { + // Only replace the while loop if there are new parameters to add. + HloInstruction* old_tuple_operand = hlo->mutable_operand(0); + TF_ASSIGN_OR_RETURN( + WhileUtil::MakeInstructionsLiveInResult result, + WhileUtil::MakeInstructionsLiveIn(hlo, operands_to_add)); + // WhileUtil creates a new while hlo and tuple. Update the dynamic size + // mapping for the newly created tuple. 
+ HloInstruction* new_tuple_operand = + result.new_while_instr->mutable_operand(0); + parent_->CopyMapping(/*from=*/old_tuple_operand, /*to=*/new_tuple_operand); + hlo = result.new_while_instr; + } + + // We have replaced the while loop, now set the dynamic dimensions for the + // newly created while loop so that the hlos that consumes the while loop can + // see the dynamic dimensions. Also sets the dynamic parameter binding for + // running inference in the while loop. + DynamicParameterBinding binding_for_while; + TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension( + hlo, [&](HloInstruction*, ShapeIndex index, int64 dimension, + int64 operand_index, HloInstruction* dynamic_size) { + DynamicParameterBinding::DynamicParameter dynamic_parameter{ + operand_index, + {dynamic_size_to_operand_id_index_map[dynamic_size]}}; + DynamicParameterBinding::DynamicDimension dynamic_dimension{ + operand_index, index, dimension}; + TF_RETURN_IF_ERROR( + binding_for_while.Bind(dynamic_parameter, dynamic_dimension)); + parent_->SetDynamicSize(hlo, index, dimension, dynamic_size); + return Status::OK(); + })); + + // Run inference in while body and condition. 
+ TF_RETURN_IF_ERROR(DynamicDimensionInferenceVisitor::Run( + hlo->while_body(), binding_for_while, parent_)); + TF_RETURN_IF_ERROR(DynamicDimensionInferenceVisitor::Run( + hlo->while_condition(), binding_for_while, parent_)); + + return Status::OK(); +} + +Status DynamicDimensionInferenceVisitor::HandleParameter(HloInstruction* hlo) { + return param_bindings_.ForEachBinding( + [&](const DynamicParameterBinding::DynamicParameter& dynamic_parameter, + const DynamicParameterBinding::DynamicDimension& dynamic_dimension) { + if (dynamic_dimension.parameter_num != hlo->parameter_number()) { + return Status::OK(); + } + HloComputation* computation = hlo->parent(); + HloInstruction* target_parameter = + computation->parameter_instruction(dynamic_dimension.parameter_num); + + HloInstruction* dynamic_size = + computation->parameter_instruction(dynamic_parameter.parameter_num); + for (int64 i : dynamic_parameter.parameter_index) { + dynamic_size = + computation->AddInstruction(HloInstruction::CreateGetTupleElement( + ShapeUtil::GetSubshape(dynamic_size->shape(), {i}), + dynamic_size, i)); + } + + parent_->SetDynamicSize(target_parameter, + dynamic_dimension.parameter_index, + dynamic_dimension.dimension, dynamic_size); + return Status::OK(); + }); +} + +Status DynamicDimensionInferenceVisitor::ForEachOperandDynamicDimension( + HloInstruction* inst, const OperandDynamicDimensionFn& fn) { + for (int64 operand_index = 0; operand_index < inst->operand_count(); + ++operand_index) { + auto iter = + parent_->per_hlo_dynamic_dimensions_.find(inst->operand(operand_index)); + if (iter != parent_->per_hlo_dynamic_dimensions_.end()) { + for (auto& dynamic_dimension : iter->second) { + HloInstruction* dynamic_size = parent_->GetDynamicSize( + dynamic_dimension.inst, dynamic_dimension.index, + dynamic_dimension.dim); + TF_RETURN_IF_ERROR(fn(dynamic_dimension.inst, dynamic_dimension.index, + dynamic_dimension.dim, operand_index, + dynamic_size)); + } + } + } + return Status::OK(); +} + 
+void DynamicDimensionInference::CopyMapping(HloInstruction* from, + HloInstruction* to) { + auto iter = per_hlo_dynamic_dimensions_.find(from); + if (iter != per_hlo_dynamic_dimensions_.end()) { + for (auto& dynamic_dimension : iter->second) { + HloInstruction* dynamic_size = + GetDynamicSize(dynamic_dimension.inst, dynamic_dimension.index, + dynamic_dimension.dim); + SetDynamicSize(to, dynamic_dimension.index, dynamic_dimension.dim, + dynamic_size); + } + } +} + +/* static */ +StatusOr DynamicDimensionInference::Run( + HloModule* module) { + VLOG(2) << "Param Config " << module->dynamic_parameter_binding().ToString(); + DynamicDimensionInference inference(module); + TF_RETURN_IF_ERROR(inference.AnalyzeDynamicDimensions()); + return inference; +} + +string DynamicDimensionInference::ToString() const { + std::vector pieces; + pieces.push_back("DynamicDimensionInference: "); + for (const auto& mapping : dynamic_mapping_) { + const DynamicDimension& dynamic_dimension = mapping.first; + pieces.push_back(absl::StrFormat( + " -- instruction %s at %s has dim %lld as dynamic" + " dimension, which is represented by instruction %s", + dynamic_dimension.inst->ToString(), dynamic_dimension.index.ToString(), + dynamic_dimension.dim, mapping.second->ToString())); + } + return absl::StrJoin(pieces, "\n"); +} + +DynamicDimensionInference::DynamicDimensionInference(HloModule* module) + : module_(module) {} + +Status DynamicDimensionInference::AnalyzeDynamicDimensions() { + return DynamicDimensionInferenceVisitor::Run( + module_->entry_computation(), module_->dynamic_parameter_binding(), this); +} + +HloInstruction* DynamicDimensionInference::GetDynamicSize( + HloInstruction* inst, const ShapeIndex& index, int64 dim) const { + auto iter = dynamic_mapping_.find(DynamicDimension{inst, index, dim}); + if (iter != dynamic_mapping_.end()) { + return iter->second; + } + return nullptr; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference.h 
b/tensorflow/compiler/xla/service/dynamic_dimension_inference.h new file mode 100644 index 0000000000000000000000000000000000000000..d0f2998328f3028ccbd5b33690a514371a03b5a1 --- /dev/null +++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference.h @@ -0,0 +1,119 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_DIMENSION_INFERENCE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_DIMENSION_INFERENCE_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/macros.h" + +namespace xla { + +// DynamicDimensionInference analyzes each HLO instruction in a graph and +// inferences which dimensions are dynamic and which scalar instructions +// represent the runtime real size of those dynamic dimensions. 
+class DynamicDimensionInference { + public: + static StatusOr Run(HloModule* module); + + string ToString() const; + + // If the dimension `dim` of instruction `inst` at `index` has a dynamic size, + // returns a scalar HloInstruction that represents the runtime size of that + // dimension. Otherwise returns nullptr. + HloInstruction* GetDynamicSize(HloInstruction* inst, const ShapeIndex& index, + int64 dim) const; + + friend class DynamicDimensionInferenceVisitor; + + private: + explicit DynamicDimensionInference(HloModule* module); + + // DynamicDimension is used as a key in the dynamic key-value mapping. It + // unambiguously represents a dynamic dimension of an instruction at a given + // index. + struct DynamicDimension { + // HloInstruction that holds the dimension. + HloInstruction* inst; + // Subshape of the instruction that holds the dimension. + ShapeIndex index; + // The dimension number of the dynamic dimension at the given index of a given + // instruction. + int64 dim; + + // Artifacts needed to make this struct able to be used as a `key` in absl + // maps. "friend" keywords are added so these functions can be found through + // ADL. + template + friend H AbslHashValue(H h, const DynamicDimension& m) { + return H::combine(std::move(h), m.inst, m.index, m.dim); + } + + friend bool operator==(const DynamicDimension& lhs, + const DynamicDimension& rhs) { + return lhs.inst == rhs.inst && lhs.index == rhs.index && + lhs.dim == rhs.dim; + } + }; + + // Update the dynamic mapping so that we know dimension `dim` of instruction + // `inst` at `index` has a dynamic size, and its runtime size is represented + // by a scalar instruction `size`. 
+ void SetDynamicSize(HloInstruction* inst, const ShapeIndex& index, int64 dim, + HloInstruction* size) { + dynamic_mapping_.try_emplace(DynamicDimension{inst, index, dim}, size); + auto iter = per_hlo_dynamic_dimensions_.try_emplace(inst); + iter.first->second.emplace(DynamicDimension{inst, index, dim}); + } + + // Copies the internal mapping from instruction `from` to instruction `to`. + // This is useful when an instruction is replaced by the other during the + // inferencing process. + void CopyMapping(HloInstruction* from, HloInstruction* to); + + // AnalyzeDynamicDimensions starts the analysis of the dynamic dimensions in + // module_. + Status AnalyzeDynamicDimensions(); + + // HloModule being analyzed. + HloModule* module_; + + // dynamic_mapping_ holds the result of the analysis. It maps a dynamic + // dimension to a scalar HloInstruction that represents the real dynamic size + // of the dynamic dimension. + using DynamicMapping = absl::flat_hash_map; + DynamicMapping dynamic_mapping_; + + // A convenient mapping from an hlo to the set of dynamic dimensions that it + // holds. + using PerHloDynamicDimensions = + absl::flat_hash_map>; + PerHloDynamicDimensions per_hlo_dynamic_dimensions_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_DIMENSION_INFERENCE_H_ diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc b/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..597cdf27c3318b3cf8bd5bb5f9b3239cf23a4c73 --- /dev/null +++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc @@ -0,0 +1,641 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h" + +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_runner.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test_benchmark.h" + +namespace op = xla::testing::opcode_matchers; + +namespace xla { +namespace { + +// Fixture that owns one module per test and the inference result computed +// for it. +class DynamicDimensionInferenceTest : public HloTestBase { + protected: + DynamicDimensionInferenceTest() : HloTestBase() { + module_ = CreateNewVerifiedModule(); + } + + // Runs DynamicDimensionInference on module_ and stores the result in + // inference_. Returns the failure status of the analysis, if any. + Status RunInference() { + hlo_graph_dumper::MaybeDumpHloModule(*module_, "Before dynamic dimension inference"); + TF_ASSIGN_OR_RETURN(DynamicDimensionInference inference, + DynamicDimensionInference::Run(module_.get())); + + inference_ = absl::make_unique<DynamicDimensionInference>(inference); + return Status::OK(); + } + + // Adds a scalar F32 "add" computation to module_ and returns it. + HloComputation* GetAdd() { + auto embedded_builder =
HloComputation::Builder("add"); + auto lhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {}), "lhs")); + auto rhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {}), "rhs")); + embedded_builder.AddInstruction( + HloInstruction::CreateBinary(lhs->shape(), HloOpcode::kAdd, lhs, rhs)); + return module_->AddEmbeddedComputation(embedded_builder.Build()); + } + + // Adds a scalar F32 "greater-or-equal" computation (PRED result) to module_ + // and returns it. + HloComputation* GetGe() { + auto embedded_builder = HloComputation::Builder("ge"); + auto lhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {}), "lhs")); + auto rhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {}), "rhs")); + embedded_builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGe, lhs, rhs)); + return module_->AddEmbeddedComputation(embedded_builder.Build()); + } + + // Module under test; created in the constructor. + std::unique_ptr<HloModule> module_; + // Result of the most recent successful RunInference() call. + std::unique_ptr<DynamicDimensionInference> inference_; + // Shape used for every dynamic-size parameter in these tests. + const Shape scalar_shape_ = ShapeUtil::MakeShape(S32, {}); +}; + +TEST_F(DynamicDimensionInferenceTest, ParamTest) { + auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {1, 2, 2}); + + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, input_shape, "param")); + auto param2 = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape_, "param")); + + module_->AddEntryComputation(builder.Build()); + // Set up dynamic parameter binding.
+ TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 1})); + + TF_ASSERT_OK(RunInference()); + EXPECT_EQ(inference_->GetDynamicSize(param, {}, 1), param2); + EXPECT_EQ(inference_->GetDynamicSize(param, {}, 0), nullptr); + EXPECT_EQ(inference_->GetDynamicSize(param2, {}, 0), nullptr); +} + +TEST_F(DynamicDimensionInferenceTest, ParamTestTuple) { + auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {1, 2, 2}); + + auto param = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeTupleShape({input_shape, scalar_shape_}), "param")); + + module_->AddEntryComputation(builder.Build()); + // Set up dynamic parameter binding. + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{0, {1}}, + DynamicParameterBinding::DynamicDimension{0, {0}, 1})); + + TF_ASSERT_OK(RunInference()); + EXPECT_THAT(inference_->GetDynamicSize(param, {0}, 1), + op::GetTupleElement(param, 1)); + + EXPECT_EQ(inference_->GetDynamicSize(param, {0}, 0), nullptr); +} + +TEST_F(DynamicDimensionInferenceTest, GetTupleElement) { + // When data flows through GTE, the dynamic dimension size keeps the + // same, and the index has its front popped. + auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {1, 2, 2}); + + auto param = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeTupleShape({input_shape, scalar_shape_}), "param")); + + auto gte = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(input_shape, param, 0)); + + module_->AddEntryComputation(builder.Build()); + // Set up dynamic parameter binding. 
+ TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{0, {1}}, + DynamicParameterBinding::DynamicDimension{0, {0}, 1})); + + TF_ASSERT_OK(RunInference()); + EXPECT_THAT(inference_->GetDynamicSize(param, {0}, 1), + op::GetTupleElement(param, 1)); + + EXPECT_THAT(inference_->GetDynamicSize(gte, {}, 1), + op::GetTupleElement(param, 1)); + + EXPECT_EQ(inference_->GetDynamicSize(param, {0}, 0), nullptr); +} + +TEST_F(DynamicDimensionInferenceTest, ElementwiseTest) { + // When data flows through elementwise, the dynamic dimension size keeps the + // same. + auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {1, 2, 2}); + + auto data_param = builder.AddInstruction( + HloInstruction::CreateParameter(0, input_shape, "data_param")); + auto size_param = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape_, "size_param")); + + auto* negate = builder.AddInstruction( + HloInstruction::CreateUnary(input_shape, HloOpcode::kNegate, data_param)); + + module_->AddEntryComputation(builder.Build()); + // Set up dynamic parameter binding. 
+ TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 1})); + + TF_ASSERT_OK(RunInference()); + EXPECT_EQ(inference_->GetDynamicSize(negate, {}, 1), size_param); +} + +TEST_F(DynamicDimensionInferenceTest, ReduceTestI) { + auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {1, 2, 2}); + auto reduce_shape = ShapeUtil::MakeShape(F32, {2}); + + auto data_param = builder.AddInstruction( + HloInstruction::CreateParameter(0, input_shape, "data_param")); + auto size_param = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape_, "size_param")); + + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(input_shape, HloOpcode::kNegate, data_param)); + + auto init = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0))); + + auto reduce = builder.AddInstruction(HloInstruction::CreateReduce( + reduce_shape, negate, init, {0, 2}, GetAdd())); + + module_->AddEntryComputation(builder.Build()); + + // Set up dynamic parameter binding. + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 1})); + + TF_ASSERT_OK(RunInference()); + EXPECT_EQ(inference_->GetDynamicSize(reduce, {}, 0), size_param); +} + +TEST_F(DynamicDimensionInferenceTest, ReduceTestII) { + // Same as ReduceTestI, but only reduce one dimension. 
+ auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {1, 2, 2}); + auto reduce_shape = ShapeUtil::MakeShape(F32, {1, 2}); + + auto data_param = builder.AddInstruction( + HloInstruction::CreateParameter(0, input_shape, "data_param")); + auto size_param = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape_, "size_param")); + + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(input_shape, HloOpcode::kNegate, data_param)); + + auto init = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0))); + + auto reduce = builder.AddInstruction( + HloInstruction::CreateReduce(reduce_shape, negate, init, {1}, GetAdd())); + + module_->AddEntryComputation(builder.Build()); + + // Set up dynamic parameter binding. + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 2})); + + TF_ASSERT_OK(RunInference()); + EXPECT_EQ(inference_->GetDynamicSize(reduce, {}, 1), size_param); + EXPECT_EQ(inference_->GetDynamicSize(reduce, {}, 0), nullptr); +} + +TEST_F(DynamicDimensionInferenceTest, DotTest) { + auto builder = HloComputation::Builder(TestName()); + constexpr int xdim = 3; + constexpr int ydim = 2; + constexpr int zdim = 1; + auto xy_shape = ShapeUtil::MakeShape(F32, {xdim, ydim}); + auto yz_shape = ShapeUtil::MakeShape(F32, {ydim, zdim}); + auto xz_shape = ShapeUtil::MakeShape(F32, {xdim, zdim}); + + auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, xy_shape, "A")); + auto* b_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, yz_shape, "B")); + auto* size_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/2, scalar_shape_, "size_param")); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + 
dot_dnums.add_rhs_contracting_dimensions(0); + auto dot = builder.AddInstruction( + HloInstruction::CreateDot(xz_shape, a_param, b_param, dot_dnums, + HloTestBase::DefaultPrecisionConfig(2))); + + module_->AddEntryComputation(builder.Build()); + + // Set up dynamic parameter binding for non-contracting dimension. + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{2, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 0})); + + // Set up binding for contracting dimensions. + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{2, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 1})); + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{2, {}}, + DynamicParameterBinding::DynamicDimension{1, {}, 0})); + + TF_ASSERT_OK(RunInference()); + EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 0), size_param); + EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 1), nullptr); +} + +TEST_F(DynamicDimensionInferenceTest, ConvolutionTest) { + auto builder = HloComputation::Builder(TestName()); + constexpr int xdim = 3; + constexpr int ydim = 2; + constexpr int zdim = 1; + auto xy_shape = ShapeUtil::MakeShape(F32, {xdim, ydim}); + auto yz_shape = ShapeUtil::MakeShape(F32, {ydim, zdim}); + auto zx_shape = ShapeUtil::MakeShape(F32, {zdim, xdim}); + + auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, xy_shape, "A")); + auto* b_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, yz_shape, "B")); + auto* size_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/2, scalar_shape_, "size_param")); + + auto dnums = XlaBuilder::CreateDefaultConvDimensionNumbers(0); + + dnums.set_kernel_input_feature_dimension(0); + dnums.set_kernel_output_feature_dimension(1); + dnums.set_input_batch_dimension(0); + 
dnums.set_output_batch_dimension(1); + dnums.set_output_feature_dimension(0); + + Window window; + + auto* conv = builder.AddInstruction(HloInstruction::CreateConvolve( + zx_shape, a_param, b_param, /*feature_group_count=*/1, + /*batch_group_count=*/1, window, dnums, + HloTestBase::DefaultPrecisionConfig(2))); + + module_->AddEntryComputation(builder.Build()); + + // Set up dynamic parameter binding for non-contracting dimension. + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{2, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 0})); + + // Set up binding for contracting dimensions. + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{2, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 1})); + + TF_ASSERT_OK(RunInference()); + EXPECT_EQ(inference_->GetDynamicSize(conv, {}, 1), size_param); + EXPECT_EQ(inference_->GetDynamicSize(conv, {}, 0), nullptr); +} + +TEST_F(DynamicDimensionInferenceTest, TransposeTest) { + // Test the ability to trace unmodified dimensions + auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {1, 2, 3}); + auto output_shape = ShapeUtil::MakeShape(F32, {3, 2, 1}); + + auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, input_shape, "A")); + auto* size_param_1 = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, scalar_shape_, "size_param")); + auto* size_param_2 = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/2, scalar_shape_, "size_param")); + auto* size_param_3 = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/3, scalar_shape_, "size_param")); + + auto* transpose = builder.AddInstruction( + HloInstruction::CreateTranspose(output_shape, a_param, {2, 1, 0})); + + module_->AddEntryComputation(builder.Build()); + + 
TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 0})); + + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{2, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 1})); + + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{3, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 2})); + + TF_ASSERT_OK(RunInference()); + EXPECT_EQ(inference_->GetDynamicSize(transpose, {}, 0), size_param_3); + EXPECT_EQ(inference_->GetDynamicSize(transpose, {}, 1), size_param_2); + EXPECT_EQ(inference_->GetDynamicSize(transpose, {}, 2), size_param_1); +} + +TEST_F(DynamicDimensionInferenceTest, ReshapeTest) { + // Test the ability to trace unmodified reshape dimensions. + auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {2, 3, 4, 5, 6}); + auto output_shape = ShapeUtil::MakeShape(F32, {6, 4, 1, 5, 2, 3}); + + auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, input_shape, "A")); + auto* size_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, scalar_shape_, "size_param")); + + auto* reshape = builder.AddInstruction( + HloInstruction::CreateReshape(output_shape, a_param)); + + module_->AddEntryComputation(builder.Build()); + + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 2})); + + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 3})); + + TF_ASSERT_OK(RunInference()); + EXPECT_EQ(inference_->GetDynamicSize(reshape, {}, 0), nullptr); + EXPECT_EQ(inference_->GetDynamicSize(reshape, {}, 1), size_param); + 
EXPECT_EQ(inference_->GetDynamicSize(reshape, {}, 2), nullptr); + EXPECT_EQ(inference_->GetDynamicSize(reshape, {}, 3), size_param); + EXPECT_EQ(inference_->GetDynamicSize(reshape, {}, 4), nullptr); + EXPECT_EQ(inference_->GetDynamicSize(reshape, {}, 5), nullptr); +} + +TEST_F(DynamicDimensionInferenceTest, ReshapeTestUnimplemented) { + // Test that inference fails with UNIMPLEMENTED when the dynamic size is + // bound to input dimension 1. Presumably this is because the reshape does + // not keep that dimension as a standalone output dimension (input + // dimensions {2, 3} become output dimension {6}). + auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {2, 3, 4, 5, 6}); + auto output_shape = ShapeUtil::MakeShape(F32, {6, 4, 1, 5, 2, 3}); + + auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, input_shape, "A")); + + builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, scalar_shape_, "size_param")); + + builder.AddInstruction(HloInstruction::CreateReshape(output_shape, a_param)); + + module_->AddEntryComputation(builder.Build()); + + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 1})); + + Status status = RunInference(); + EXPECT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED); +} + +TEST_F(DynamicDimensionInferenceTest, BroadcastTest) { + // Test the ability to trace broadcast dimension.
+ auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {2}); + auto output_shape = ShapeUtil::MakeShape(F32, {3, 2, 4}); + + auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, input_shape, "A")); + auto* size_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, scalar_shape_, "size_param")); + + auto* broadcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(output_shape, a_param, {1})); + + module_->AddEntryComputation(builder.Build()); + + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 0})); + + TF_ASSERT_OK(RunInference()); + EXPECT_EQ(inference_->GetDynamicSize(broadcast, {}, 0), nullptr); + EXPECT_EQ(inference_->GetDynamicSize(broadcast, {}, 1), size_param); + EXPECT_EQ(inference_->GetDynamicSize(broadcast, {}, 2), nullptr); +} + +TEST_F(DynamicDimensionInferenceTest, WhileTest) { + // Test the ability to trace into while loops. 
+ auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {2, 4, 4}); + auto output_shape = ShapeUtil::MakeShape(F32, {2, 2, 2}); + auto tuple_shape = ShapeUtil::MakeTupleShape({input_shape, input_shape}); + + // Body: + // + // Param + // | | + // GTE1 GTE2 + // | | + // ADD + auto body_builder = HloComputation::Builder("body"); + auto body_param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "param")); + auto gte_0 = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(input_shape, body_param, 0)); + auto gte_1 = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(input_shape, body_param, 1)); + auto add = body_builder.AddInstruction( + HloInstruction::CreateBinary(input_shape, HloOpcode::kAdd, gte_0, gte_1)); + body_builder.AddInstruction(HloInstruction::CreateTuple({add, add})); + + HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build()); + + auto cond_builder = HloComputation::Builder("condition"); + cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "param")); + cond_builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloComputation* condition = + module_->AddEmbeddedComputation(cond_builder.Build()); + + // Entry: + // + // Param + // | + // While + auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, tuple_shape, "A")); + auto* size_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, scalar_shape_, "size_param")); + builder.AddInstruction( + HloInstruction::CreateWhile(tuple_shape, condition, body, a_param)); + + module_->AddEntryComputation(builder.Build()); + + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {0}, 0})); + + 
TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {1}, 0})); + + // Test that dynamic dimension inference does the right thing. A lambda is + // used here since we want to test twice by running inference again + // (idempotency). + auto test_dynamic_dimension = [&]() { + HloInstruction* while_hlo = nullptr; + // The while hlo has been replaced, find the new one. + for (HloInstruction* inst : module_->entry_computation()->instructions()) { + if (inst->opcode() == HloOpcode::kWhile) { + while_hlo = inst; + } + } + ASSERT_NE(while_hlo, nullptr); + // The original while shape has 2 parameters. With dynamic size passed in + // as an extra parameter, the tuple should have 3 elements. + EXPECT_EQ(while_hlo->shape().tuple_shapes_size(), 3); + HloInstruction* add = nullptr; + for (HloInstruction* inst : while_hlo->while_body()->instructions()) { + if (inst->opcode() == HloOpcode::kAdd) { + add = inst; + } + } + EXPECT_NE(add, nullptr); + EXPECT_NE(inference_->GetDynamicSize(add, {}, 0), nullptr); + EXPECT_EQ(inference_->GetDynamicSize(while_hlo, {0}, 0), size_param); + EXPECT_EQ(inference_->GetDynamicSize(while_hlo, {1}, 0), size_param); + }; + + TF_ASSERT_OK(RunInference()); + test_dynamic_dimension(); + TF_ASSERT_OK(RunInference()); + test_dynamic_dimension(); +} + +TEST_F(DynamicDimensionInferenceTest, ReduceWindowBatchTest) { + // Test the ability to trace reduce window batch dimensions. + auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {2, 4, 4}); + auto output_shape = ShapeUtil::MakeShape(F32, {2, 2, 2}); + + Window window; + // First dimension is unchanged. 
+ WindowDimension* batch_dim = window.add_dimensions(); + batch_dim->set_size(1); + batch_dim->set_stride(1); + batch_dim->set_padding_low(0); + batch_dim->set_padding_high(0); + batch_dim->set_window_dilation(1); + batch_dim->set_base_dilation(1); + + // Second and third dimension are reduced. + for (int64 i = 0; i < 2; ++i) { + WindowDimension* dim = window.add_dimensions(); + dim->set_size(2); + dim->set_stride(2); + dim->set_padding_low(0); + dim->set_padding_high(0); + dim->set_window_dilation(1); + dim->set_base_dilation(1); + } + + auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, input_shape, "A")); + auto* size_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, scalar_shape_, "size_param")); + + auto init = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0))); + + auto* reduce_window = + builder.AddInstruction(HloInstruction::CreateReduceWindow( + output_shape, a_param, init, window, GetAdd())); + + module_->AddEntryComputation(builder.Build()); + + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 0})); + + TF_ASSERT_OK(RunInference()); + EXPECT_EQ(inference_->GetDynamicSize(reduce_window, {}, 0), size_param); +} + +TEST_F(DynamicDimensionInferenceTest, SelectAndScatterTest) { + // Test the ability to trace select and scatter batch dimensions. + auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {2, 4, 4}); + auto source_shape = ShapeUtil::MakeShape(F32, {2, 2, 2}); + + Window window; + // First dimension is unchanged. 
+ WindowDimension* batch_dim = window.add_dimensions(); + batch_dim->set_size(1); + batch_dim->set_stride(1); + batch_dim->set_padding_low(0); + batch_dim->set_padding_high(0); + batch_dim->set_window_dilation(1); + batch_dim->set_base_dilation(1); + + // Second and third dimension are reduced. + for (int64 i = 0; i < 2; ++i) { + WindowDimension* dim = window.add_dimensions(); + dim->set_size(2); + dim->set_stride(2); + dim->set_padding_low(0); + dim->set_padding_high(0); + dim->set_window_dilation(1); + dim->set_base_dilation(1); + } + + auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, input_shape, "A")); + auto* size_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, scalar_shape_, "size_param")); + auto* source = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/2, source_shape, "B")); + + auto init = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0))); + + auto* sns = builder.AddInstruction(HloInstruction::CreateSelectAndScatter( + input_shape, a_param, GetGe(), window, source, init, GetAdd())); + + module_->AddEntryComputation(builder.Build()); + + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 0})); + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{2, {}, 0})); + + TF_ASSERT_OK(RunInference()); + EXPECT_EQ(inference_->GetDynamicSize(sns, {}, 0), size_param); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/dynamic_index_splitter.cc b/tensorflow/compiler/xla/service/dynamic_index_splitter.cc new file mode 100644 index 0000000000000000000000000000000000000000..e34adfd2d2bbb7214cfa2da28291b133538845e5 --- /dev/null +++ 
b/tensorflow/compiler/xla/service/dynamic_index_splitter.cc @@ -0,0 +1,99 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/dynamic_index_splitter.h" + +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" + +namespace xla { + +StatusOr DynamicIndexSplitter::Run(HloModule* module) { + bool changed = false; + + std::vector computations = + module->MakeNonfusionComputations(); + for (HloComputation* computation : computations) { + for (HloInstruction* dynamic_op : computation->MakeInstructionPostOrder()) { + switch (dynamic_op->opcode()) { + case HloOpcode::kDynamicSlice: + case HloOpcode::kDynamicUpdateSlice: + break; + default: + continue; + } + auto parent = dynamic_op->parent(); + bool is_update = dynamic_op->opcode() == HloOpcode::kDynamicUpdateSlice; + int64 num_indices = dynamic_op->operand(0)->shape().rank(); + + if (num_indices == 0) { + // If the operand 
rank is 0, directly replace R0 DS/DUS with the + // operand (for DS) or update (for DUS). + if (is_update) { + TF_CHECK_OK(parent->ReplaceInstruction( + dynamic_op, dynamic_op->mutable_operand(1))); + } else { + TF_CHECK_OK(parent->ReplaceInstruction( + dynamic_op, dynamic_op->mutable_operand(0))); + } + changed = true; + continue; + } + + int64 index_operand_number = Cast(dynamic_op) + ->first_index_operand_number(); + auto index_operand = dynamic_op->mutable_operand(index_operand_number); + if (ShapeUtil::IsScalar(index_operand->shape())) { + // This DS/DUS already uses scalar indices. + continue; + } + TF_RET_CHECK(index_operand->shape().rank() == 1); + auto index_element_type = index_operand->shape().element_type(); + std::vector index_array; + for (int64 dim = 0; dim < num_indices; ++dim) { + auto slice = parent->AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(index_element_type, {1}), index_operand, {dim}, + {dim + 1}, {1})); + auto bitcast = parent->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(index_element_type, {}), slice)); + index_array.push_back(bitcast); + } + auto new_dynamic_op = + is_update + ? 
HloInstruction::CreateDynamicUpdateSlice( + dynamic_op->shape(), dynamic_op->mutable_operand(0), + dynamic_op->mutable_operand(1), absl::MakeSpan(index_array)) + : HloInstruction::CreateDynamicSlice( + dynamic_op->shape(), dynamic_op->mutable_operand(0), + absl::MakeSpan(index_array), + dynamic_op->dynamic_slice_sizes()); + TF_CHECK_OK(parent->ReplaceWithNewInstruction(dynamic_op, + std::move(new_dynamic_op))); + changed = true; + } + } + return changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/dynamic_index_splitter.h b/tensorflow/compiler/xla/service/dynamic_index_splitter.h new file mode 100644 index 0000000000000000000000000000000000000000..3c12e3a4af287ad2272a08ba54cd99c2cad9d451 --- /dev/null +++ b/tensorflow/compiler/xla/service/dynamic_index_splitter.h @@ -0,0 +1,37 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_INDEX_SPLITTER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_INDEX_SPLITTER_H_ + +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + +// Convert R1 index operands to DynamicSlice and DynamicUpdateSlice ops into +// separate scalars. 
+class DynamicIndexSplitter : public HloModulePass { + public: + DynamicIndexSplitter() = default; + absl::string_view name() const override { return "dynamic-index-splitter"; } + StatusOr Run(HloModule* module) override; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_INDEX_SPLITTER_H_ diff --git a/tensorflow/compiler/xla/service/dynamic_index_splitter_test.cc b/tensorflow/compiler/xla/service/dynamic_index_splitter_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..98029d1faff7d669730f6b66e38fcefece70f0eb --- /dev/null +++ b/tensorflow/compiler/xla/service/dynamic_index_splitter_test.cc @@ -0,0 +1,134 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/dynamic_index_splitter.h" + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" + +namespace xla { +namespace { + +namespace op = xla::testing::opcode_matchers; +class DynamicIndexSplitterTest : public HloTestBase {}; + +TEST_F(DynamicIndexSplitterTest, DynamicSlice) { + const char* const kDynamicSlice = R"( + HloModule DynamicSlice_module + + ENTRY entry (operand: s32[4,5,6], indices: s32[3]) -> s32[1,1,1] { + operand = s32[4,5,6] parameter(0) + indices = s32[3] parameter(1) + ROOT dynamic-slice = s32[1,1,1] dynamic-slice(operand, indices), dynamic_slice_sizes={1,1,1} + } + )"; + + HloModuleConfig config; + DebugOptions debug_options = config.debug_options(); + debug_options.set_xla_allow_scalar_index_dynamic_ops(true); + config.set_debug_options(debug_options); + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kDynamicSlice, config)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + DynamicIndexSplitter().Run(module.get())); + EXPECT_TRUE(changed); + ASSERT_THAT(module->entry_computation()->root_instruction(), + op::DynamicSlice(op::Parameter(0), + op::Reshape(op::Slice(op::Parameter(1))), + op::Reshape(op::Slice(op::Parameter(1))), + op::Reshape(op::Slice(op::Parameter(1))))); + + for (int i = 0; i < 3; ++i) { + const HloInstruction* slice = module->entry_computation() + ->root_instruction() + ->operand(i + 1) + ->operand(0); + EXPECT_EQ(slice->slice_starts(0), i); + EXPECT_EQ(slice->slice_limits(0), i + 1); + } +} + +TEST_F(DynamicIndexSplitterTest, DynamicUpdateSlice) { + const char* const 
kDynamicUpdateSlice = R"( + HloModule DynamicUpdatedSlice_module + + ENTRY entry (operand: s32[4,5,6], indices: s32[3], update: s32[1,1,1]) -> s32[4,5,6] { + operand = s32[4,5,6] parameter(0) + indices = s32[3] parameter(1) + update = s32[1,1,1] parameter(2) + ROOT dynamic-update-slice = s32[4,5,6] dynamic-update-slice(operand, update, indices) + } + )"; + + HloModuleConfig config; + DebugOptions debug_options = config.debug_options(); + debug_options.set_xla_allow_scalar_index_dynamic_ops(true); + config.set_debug_options(debug_options); + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseHloString(kDynamicUpdateSlice, config)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + DynamicIndexSplitter().Run(module.get())); + EXPECT_TRUE(changed); + ASSERT_THAT(module->entry_computation()->root_instruction(), + op::DynamicUpdateSlice(op::Parameter(0), op::Parameter(2), + op::Reshape(op::Slice(op::Parameter(1))), + op::Reshape(op::Slice(op::Parameter(1))), + op::Reshape(op::Slice(op::Parameter(1))))); + + for (int i = 0; i < 3; ++i) { + const HloInstruction* slice = module->entry_computation() + ->root_instruction() + ->operand(i + 2) + ->operand(0); + EXPECT_EQ(slice->slice_starts(0), i); + EXPECT_EQ(slice->slice_limits(0), i + 1); + } +} + +TEST_F(DynamicIndexSplitterTest, AlreadyScalar) { + const char* const kDynamicSlice = R"( + HloModule DynamicSlice_module + + ENTRY entry (operand: s32[4,5,6], index.0: s32[], index.1: s32[], index.2: s32[]) -> s32[1,1,1] { + operand = s32[4,5,6] parameter(0) + index.0 = s32[] parameter(1) + index.1 = s32[] parameter(2) + index.2 = s32[] parameter(3) + ROOT dynamic-slice = s32[1,1,1] dynamic-slice(operand, index.0, index.1, index.2), dynamic_slice_sizes={1,1,1} + } + )"; + + HloModuleConfig config; + DebugOptions debug_options = config.debug_options(); + debug_options.set_xla_allow_scalar_index_dynamic_ops(true); + config.set_debug_options(debug_options); + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kDynamicSlice, config)); + 
TF_ASSERT_OK_AND_ASSIGN(bool changed, + DynamicIndexSplitter().Run(module.get())); + EXPECT_FALSE(changed); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::DynamicSlice(op::Parameter(0), op::Parameter(1), + op::Parameter(2), op::Parameter(3))); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/dynamic_padder.cc b/tensorflow/compiler/xla/service/dynamic_padder.cc new file mode 100644 index 0000000000000000000000000000000000000000..4db280f817141bd52e3a5b9564600a618f81aeac --- /dev/null +++ b/tensorflow/compiler/xla/service/dynamic_padder.cc @@ -0,0 +1,161 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/compiler/xla/service/dynamic_padder.h" + +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" + +#include "absl/container/flat_hash_set.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h" +#include "tensorflow/compiler/xla/service/hlo_dce.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/util.h" + +#include "tensorflow/core/lib/core/errors.h" + +namespace xla { + +namespace { + +// ChooseIdentityValue looks at the instruction and returns an identity value +// which, when padded, doesn't change the result of the instruction. +// +// nullptr is returned if padding doesn't need to be reset. +StatusOr ChooseIdentityValue(HloInstruction* inst) { + HloComputation* comp = inst->parent(); + // Padding on elementwise operation doesn't affect the result of the effective + // data. + if (inst->IsElementwise()) { + return nullptr; + } + + switch (inst->opcode()) { + case HloOpcode::kReduce: + case HloOpcode::kReduceWindow: { + // Because of the way we do reduce, we already require the `init` operand + // of hlo reduce instruction to be identity value. Here we reuse the + // operand. + return inst->mutable_operand(1); + } + + case HloOpcode::kConvolution: + case HloOpcode::kDot: { + // Use 0 as padding value for convolution and dot.
+ PrimitiveType ptype = inst->shape().element_type(); + return comp->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(ptype))); + } + + case HloOpcode::kPad: { + return inst->mutable_operand(1); + } + case HloOpcode::kParameter: + case HloOpcode::kGetDimensionSize: + case HloOpcode::kReshape: + case HloOpcode::kTuple: + case HloOpcode::kAllReduce: + case HloOpcode::kBroadcast: + return nullptr; + default: + return UnimplementedStrCat("Unimplimented padding for instruction: ", + inst->ToString()); + } +} + +} // namespace + +StatusOr DynamicPadder::Run(HloModule* module) { + bool changed = false; + VLOG(2) << "Pre DynamicPadder HLO:"; + XLA_VLOG_LINES(2, module->ToString()); + TF_ASSIGN_OR_RETURN(DynamicDimensionInference dynamic_dimension_inference, + DynamicDimensionInference::Run(module)); + + for (HloComputation* computation : module->computations()) { + for (HloInstruction* inst : computation->instructions()) { + for (int64 operand_num = 0; operand_num < inst->operand_count(); + ++operand_num) { + HloInstruction* operand = inst->mutable_operand(operand_num); + if (!operand->shape().IsArray()) { + continue; + } + for (int64 dim = 0; dim < operand->shape().rank(); ++dim) { + HloInstruction* dynamic_size = + dynamic_dimension_inference.GetDynamicSize(operand, {}, dim); + if (dynamic_size == nullptr) { + continue; + } + VLOG(1) << "Has dynamic dimension of operand" << operand_num << " @" + << dim; + TF_ASSIGN_OR_RETURN(HloInstruction * identity_value, + ChooseIdentityValue(inst)); + if (identity_value == nullptr) { + continue; + } + + // For each dimension, first generates a mask representing the + // effective area of data and padded area of data using iota and + // dynamic_size. 
For example, given a dimension of 7 elements and 5 + // effective elements: + // + // iota = [0, 1, 2, 3, 4, 5, 6] + // broadcast_dynamic_size = [5, 5, 5, 5, 5, 5, 5] + // mask = lt(iota, broadcast_dynamic_size) = [t, t, t, t, t, f, f] + // + // Once the mask is generated, the input data is then padded using the + // mask and pad value. + // + const Shape mask_shape = + ShapeUtil::ChangeElementType(operand->shape(), xla::U32); + const Shape pred_shape = + ShapeUtil::ChangeElementType(operand->shape(), xla::PRED); + HloInstruction* iota = computation->AddInstruction( + HloInstruction::CreateIota(mask_shape, dim)); + + HloInstruction* broadcasted_effective_size = + computation->AddInstruction(HloInstruction::CreateBroadcast( + mask_shape, dynamic_size, {})); + HloInstruction* pred = computation->AddInstruction( + HloInstruction::CreateBinary(pred_shape, HloOpcode::kLt, iota, + broadcasted_effective_size)); + + HloInstruction* broadcasted_identity_value = + computation->AddInstruction(HloInstruction::CreateBroadcast( + operand->shape(), identity_value, {})); + HloInstruction* padded = + computation->AddInstruction(HloInstruction::CreateTernary( + operand->shape(), HloOpcode::kSelect, pred, operand, + broadcasted_identity_value)); + TF_RETURN_IF_ERROR(inst->ReplaceOperandWith(operand_num, padded)); + operand = inst->mutable_operand(operand_num); + changed = true; + } + } + } + } + HloDCE dce; + TF_ASSIGN_OR_RETURN(changed, dce.Run(module)); + VLOG(2) << "Post DynamicPadder HLO:"; + XLA_VLOG_LINES(2, module->ToString()); + return changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/dynamic_padder.h b/tensorflow/compiler/xla/service/dynamic_padder.h new file mode 100644 index 0000000000000000000000000000000000000000..509269f7f56746fa5516ad917a04221587c6dcca --- /dev/null +++ b/tensorflow/compiler/xla/service/dynamic_padder.h @@ -0,0 +1,44 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_PADDER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_PADDER_H_ + +#include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// With bounded shapes, only part of the shape contains effective data and the +// rest contains padded data, whose value can be anything depending on the +// source of the data. When a bounded shape is directly consumed by an +// instruction that collapses dimensions (reduce for example), the padding data +// would affect the result of the instruction. +// +// DynamicPadder uses DynamicDimensionInference to detect bounded shapes in a +// hlo module, it then inserts certain instructions to reset the padding into an +// identity value so that it doesn't affect the result of subsequent +// instruction. For example, it'd reset the padding to 0 before a bounded shape +// is consumed by a reduce-sum.
+class DynamicPadder : public HloModulePass { + public: + absl::string_view name() const override { return "dynamic_padder"; } + + StatusOr Run(HloModule* module) override; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_PADDER_H_ diff --git a/tensorflow/compiler/xla/service/dynamic_padder_test.cc b/tensorflow/compiler/xla/service/dynamic_padder_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..55a11286e4596d87c330315322cae704fc5cd707 --- /dev/null +++ b/tensorflow/compiler/xla/service/dynamic_padder_test.cc @@ -0,0 +1,152 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/dynamic_padder.h" + +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_runner.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test_benchmark.h" + +namespace op = xla::testing::opcode_matchers; + +namespace xla { +namespace { + +class DynamicPadderTest : public HloTestBase { + protected: + DynamicPadderTest() : HloTestBase() { module_ = CreateNewVerifiedModule(); } + + StatusOr RunPadder() { + hlo_graph_dumper::MaybeDumpHloModule(*module_, "Before padder"); + + DynamicPadder padder; + + return padder.Run(module_.get()); + } + + void ExpectPadded(const HloInstruction* inst) { + EXPECT_THAT(inst, + op::Select(op::Lt(op::Iota(), op::Broadcast(op::Parameter())), + ::testing::_, op::Broadcast())); + } + + HloComputation* GetScalarAddComputation() { + auto embedded_builder = HloComputation::Builder("add"); + auto lhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {}), "lhs")); + auto rhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {}), "rhs")); + embedded_builder.AddInstruction( + HloInstruction::CreateBinary(lhs->shape(), HloOpcode::kAdd, lhs, rhs)); + return 
module_->AddEmbeddedComputation(embedded_builder.Build()); + } + + std::unique_ptr module_; + const Shape scalar_shape_ = ShapeUtil::MakeShape(U32, {}); +}; + +TEST_F(DynamicPadderTest, ReduceTest) { + auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {1, 2, 2}); + auto reduce_shape = ShapeUtil::MakeShape(F32, {2}); + + auto data_param = builder.AddInstruction( + HloInstruction::CreateParameter(0, input_shape, "data_param")); + builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape_, "size_param")); + + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(input_shape, HloOpcode::kNegate, data_param)); + + auto init = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0))); + + auto reduce = builder.AddInstruction(HloInstruction::CreateReduce( + reduce_shape, negate, init, {0, 2}, GetScalarAddComputation())); + + module_->AddEntryComputation(builder.Build()); + + // Set up dynamic parameter binding. 
+ TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 1})); + + TF_ASSERT_OK(RunPadder().status()); + + ExpectPadded(reduce->operand(0)); +} + +TEST_F(DynamicPadderTest, ConvolutionTest) { + auto builder = HloComputation::Builder(TestName()); + constexpr int xdim = 3; + constexpr int ydim = 2; + constexpr int zdim = 1; + auto xy_shape = ShapeUtil::MakeShape(F32, {xdim, ydim}); + auto yz_shape = ShapeUtil::MakeShape(F32, {ydim, zdim}); + auto zx_shape = ShapeUtil::MakeShape(F32, {zdim, xdim}); + + auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, xy_shape, "A")); + auto* b_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, yz_shape, "B")); + builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/2, scalar_shape_, "size_param")); + + auto dnums = XlaBuilder::CreateDefaultConvDimensionNumbers(0); + + dnums.set_kernel_input_feature_dimension(0); + dnums.set_kernel_output_feature_dimension(1); + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(1); + dnums.set_output_feature_dimension(0); + + Window window; + + auto* conv = builder.AddInstruction(HloInstruction::CreateConvolve( + zx_shape, a_param, b_param, /*feature_group_count=*/1, + /*batch_group_count=*/1, window, dnums, + HloTestBase::DefaultPrecisionConfig(2))); + + module_->AddEntryComputation(builder.Build()); + + // Set up dynamic parameter binding for non-contracting dimension. + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{2, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 0})); + + // Set up binding for contracting dimensions. 
+ TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{2, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 1})); + + TF_ASSERT_OK(RunPadder().status()); + + ExpectPadded(conv->operand(0)); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/dynamic_parameter_binding.cc b/tensorflow/compiler/xla/service/dynamic_parameter_binding.cc index c8bfc8905064bcd7b68fe259fbcc1546ff083dbd..7f0ae692f7414dbdcccda8b287c9059bcf920df1 100644 --- a/tensorflow/compiler/xla/service/dynamic_parameter_binding.cc +++ b/tensorflow/compiler/xla/service/dynamic_parameter_binding.cc @@ -29,7 +29,8 @@ Status DynamicParameterBinding::Bind( } absl::optional -DynamicParameterBinding::GetBinding(const DynamicDimension& dynamic_dimension) { +DynamicParameterBinding::GetBinding( + const DynamicDimension& dynamic_dimension) const { auto param_iter = bindings_.find(dynamic_dimension); if (param_iter == bindings_.end()) { return absl::nullopt; @@ -70,7 +71,7 @@ StatusOr DynamicParameterBinding::CreateFromProto( int64 target_param_num = binding.target_param_num(); ShapeIndex target_param_index(binding.target_param_index().begin(), binding.target_param_index().end()); - int64 target_dim_num = binding.target_param_num(); + int64 target_dim_num = binding.target_param_dim_num(); TF_RETURN_IF_ERROR( result.Bind(DynamicParameter{dynamic_param_num, dynamic_param_index}, @@ -111,7 +112,8 @@ Status DynamicParameterBinding::Verify(const HloModule& module) const { return ForEachBinding([&](const DynamicParameter& dynamic_parameter, const DynamicDimension& dynamic_dimension) -> Status { - TF_RET_CHECK(dynamic_parameter.parameter_num < entry->num_parameters()); + TF_RET_CHECK(dynamic_parameter.parameter_num >= 0 && + dynamic_parameter.parameter_num < entry->num_parameters()); TF_RET_CHECK(dynamic_dimension.parameter_num < entry->num_parameters()); TF_RET_CHECK(ShapeUtil::IndexIsValid( 
entry->parameter_instruction(dynamic_parameter.parameter_num)->shape(), @@ -121,10 +123,11 @@ Status DynamicParameterBinding::Verify(const HloModule& module) const { dynamic_dimension.parameter_index)); TF_RET_CHECK( dynamic_dimension.dimension < - ShapeUtil::Rank(ShapeUtil::GetSubshape( + ShapeUtil::GetSubshape( entry->parameter_instruction(dynamic_dimension.parameter_num) ->shape(), - dynamic_dimension.parameter_index))); + dynamic_dimension.parameter_index) + .rank()); return Status::OK(); }); } diff --git a/tensorflow/compiler/xla/service/dynamic_parameter_binding.h b/tensorflow/compiler/xla/service/dynamic_parameter_binding.h index dd474d8eed1b2c30ddb8f624a864198c74eacaba..57af2c43d3c65f7340e6a9f04e5abbf052ebceea 100644 --- a/tensorflow/compiler/xla/service/dynamic_parameter_binding.h +++ b/tensorflow/compiler/xla/service/dynamic_parameter_binding.h @@ -89,7 +89,7 @@ class DynamicParameterBinding { // // Returns nullopt if the binding is not set. absl::optional GetBinding( - const DynamicDimension& dynamic_dimension); + const DynamicDimension& dynamic_dimension) const; using BindingFn = std::functionToProto(); + TF_ASSERT_OK_AND_ASSIGN(*binding, + DynamicParameterBinding::CreateFromProto(proto)); + } +}; TEST_F(DynamicParameterBindingTest, SimpleBinding) { // 'b' is a dynamic shape; 'a' represents the real size of b's first @@ -56,15 +64,20 @@ ENTRY main { binding.Bind(DynamicParameterBinding::DynamicParameter{0, {}}, DynamicParameterBinding::DynamicDimension{1, {}, 0})); - absl::optional param = - binding.GetBinding( - DynamicParameterBinding::DynamicDimension{/*parameter_num=*/1, - /*parameter_index=*/{}, - /*dimension=*/0}); - EXPECT_TRUE(param); - EXPECT_EQ(param->parameter_num, 0); - EXPECT_EQ(param->parameter_index, ShapeIndex({})); - TF_EXPECT_OK(binding.Verify(*module)); + auto test = [&](const DynamicParameterBinding& binding) { + absl::optional param = + binding.GetBinding( + DynamicParameterBinding::DynamicDimension{/*parameter_num=*/1, + 
/*parameter_index=*/{}, + /*dimension=*/0}); + EXPECT_TRUE(param); + EXPECT_EQ(param->parameter_num, 0); + EXPECT_EQ(param->parameter_index, ShapeIndex({})); + TF_EXPECT_OK(binding.Verify(*module)); + }; + test(binding); + SerializeAndDeserialize(&binding); + test(binding); } TEST_F(DynamicParameterBindingTest, TupleBinding) { @@ -89,16 +102,21 @@ ENTRY main { binding.Bind(DynamicParameterBinding::DynamicParameter{0, {0}}, DynamicParameterBinding::DynamicDimension{0, {1}, 0})); - absl::optional param = - binding.GetBinding( - DynamicParameterBinding::DynamicDimension{/*parameter_num=*/0, - /*parameter_index=*/{1}, - /*dimension=*/0}); - - EXPECT_TRUE(param); - EXPECT_EQ(param->parameter_num, 0); - EXPECT_EQ(param->parameter_index, ShapeIndex({0})); - TF_EXPECT_OK(binding.Verify(*module)); + auto test = [&](const DynamicParameterBinding& binding) { + absl::optional param = + binding.GetBinding( + DynamicParameterBinding::DynamicDimension{/*parameter_num=*/0, + /*parameter_index=*/{1}, + /*dimension=*/0}); + + EXPECT_TRUE(param); + EXPECT_EQ(param->parameter_num, 0); + EXPECT_EQ(param->parameter_index, ShapeIndex({0})); + TF_EXPECT_OK(binding.Verify(*module)); + }; + test(binding); + SerializeAndDeserialize(&binding); + test(binding); } TEST_F(DynamicParameterBindingTest, TupleBindingWithMultiDimension) { @@ -127,26 +145,35 @@ ENTRY main { binding.Bind(DynamicParameterBinding::DynamicParameter{0, {0}}, DynamicParameterBinding::DynamicDimension{0, {1}, 1})); - absl::optional param = - binding.GetBinding( - DynamicParameterBinding::DynamicDimension{/*parameter_num=*/0, - /*parameter_index=*/{1}, - /*dimension=*/0}); - - EXPECT_TRUE(param); - EXPECT_EQ(param->parameter_num, 0); - EXPECT_EQ(param->parameter_index, ShapeIndex({0})); - - absl::optional param2 = - binding.GetBinding( - DynamicParameterBinding::DynamicDimension{/*parameter_num=*/0, - /*parameter_index=*/{1}, - /*dimension=*/0}); - EXPECT_TRUE(param2); - EXPECT_EQ(param2->parameter_num, 0); - 
EXPECT_EQ(param2->parameter_index, ShapeIndex({0})); - - TF_EXPECT_OK(binding.Verify(*module)); + auto test = [&](const DynamicParameterBinding& binding) { + absl::optional param = + binding.GetBinding( + DynamicParameterBinding::DynamicDimension{/*parameter_num=*/0, + /*parameter_index=*/{1}, + /*dimension=*/0}); + + EXPECT_TRUE(param); + EXPECT_EQ(param->parameter_num, 0); + EXPECT_EQ(param->parameter_index, ShapeIndex({0})); + + absl::optional param2 = + + binding.GetBinding( + DynamicParameterBinding::DynamicDimension{/*parameter_num=*/0, + /*parameter_index=*/{1}, + /*dimension=*/0}); + EXPECT_TRUE(param2); + EXPECT_EQ(param2->parameter_num, 0); + EXPECT_EQ(param2->parameter_index, ShapeIndex({0})); + TF_EXPECT_OK(binding.Verify(*module)); + }; + + test(binding); + + SerializeAndDeserialize(&binding); + + // Test the binding again after deserialization. + test(binding); } } // namespace diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index 6f1f95f2e9082649b6ca9cc0da5c238e15b77c10..727e0bfa52d45b6f8c67d7d04613e4865f18a53c 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -812,11 +812,14 @@ StatusOr ElementalIrEmitter::EmitComplexBinaryOp( auto c = EmitExtractReal(rhs_value); auto d = EmitExtractImag(rhs_value); auto aa_p_bb = FAdd(FMul(a, a), FMul(b, b)); + auto zero = llvm::ConstantFP::get(a->getType(), 0); auto one_half = llvm::ConstantFP::get(a->getType(), 0.5); + auto one = llvm::ConstantFP::get(a->getType(), 1); auto half_c = FMul(one_half, c); TF_ASSIGN_OR_RETURN(auto aa_p_bb_to_half_c, EmitPow(component_type, aa_p_bb, half_c)); + auto neg_d = FNeg(d); TF_ASSIGN_OR_RETURN(auto arg_lhs, EmitAtan2(component_type, b, a)); auto neg_d_arg_lhs = FMul(neg_d, arg_lhs); @@ -828,7 +831,13 @@ StatusOr ElementalIrEmitter::EmitComplexBinaryOp( auto q = FAdd(FMul(c, arg_lhs), FMul(half_d, ln_aa_p_bb)); 
TF_ASSIGN_OR_RETURN(auto cos_q, EmitCos(component_type, q)); TF_ASSIGN_OR_RETURN(auto sin_q, EmitSin(component_type, q)); - return EmitComposeComplex(op, FMul(coeff, cos_q), FMul(coeff, sin_q)); + // 0^c is 0 if d is 0 and c > 0. 0^0 is defined to be 1.0, see + // Branch Cuts for Complex Elementary Functions or Much Ado About + // Nothing's Sign Bit, W. Kahan, Section 10. + return Select( + And(And(FCmpOEQ(aa_p_bb, zero), FCmpOEQ(d, zero)), FCmpOLE(zero, c)), + EmitComposeComplex(op, Select(FCmpOEQ(zero, c), one, zero), zero), + EmitComposeComplex(op, FMul(coeff, cos_q), FMul(coeff, sin_q))); } default: return Unimplemented("binary complex op '%s'", @@ -1327,9 +1336,9 @@ llvm_ir::IrArray::Index ElementalIrEmitter::ElementwiseSourceIndex( // If implicit broadcast is needed, the source dimensions that are broadcast // have index 0. - CHECK_EQ(ShapeUtil::Rank(operand_shape), ShapeUtil::Rank(hlo.shape())); + CHECK_EQ(operand_shape.rank(), hlo.shape().rank()); llvm_ir::IrArray::Index source_index(target_index.GetType()); - for (int64 i = 0; i < ShapeUtil::Rank(hlo.shape()); ++i) { + for (int64 i = 0; i < hlo.shape().rank(); ++i) { if (hlo.shape().dimensions(i) == operand_shape.dimensions(i)) { source_index.push_back(target_index[i]); } else { @@ -1750,7 +1759,7 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicSlice( const llvm_ir::IrArray::Index& index) { // Emit IR to read dynamic start indices from hlo->operand(1). const HloInstruction* input_hlo = hlo->operand(0); - const int64 rank = ShapeUtil::Rank(input_hlo->shape()); + const int64 rank = input_hlo->shape().rank(); // Use the same index type for all tensor accesses in the same kernel. 
llvm::Type* index_type = index.GetType(); llvm_ir::IrArray::Index slice_start_index(index_type, rank); @@ -1758,9 +1767,18 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicSlice( auto index_typed_const = [&](uint64 c) -> llvm::Constant* { return llvm::ConstantInt::get(index_type, c); }; - llvm_ir::IrArray::Index dim_index(1, index_typed_const(i)); - TF_ASSIGN_OR_RETURN(llvm::Value * start_index_value, - operand_to_generator.at(hlo->operand(1))(dim_index)); + // TODO(b/118437727): Remove the R1 path. + llvm::Value* start_index_value; + if (hlo->operand(1)->shape().rank() == 1) { + llvm_ir::IrArray::Index dim_index(1, index_typed_const(i)); + TF_ASSIGN_OR_RETURN(start_index_value, + operand_to_generator.at(hlo->operand(1))(dim_index)); + } else { + llvm_ir::IrArray::Index zero_index(index_type); + TF_ASSIGN_OR_RETURN( + start_index_value, + operand_to_generator.at(hlo->operand(1 + i))(zero_index)); + } // Clamp the start index so that the sliced portion fits in the operand: // start_index = clamp(start_index, 0, operand_dim_size - output_dim_size) @@ -1893,7 +1911,7 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicUpdateSlice( const HloInstruction* update_hlo = hlo->operand(1); const HloInstruction* start_hlo = hlo->operand(2); // Calculate slice start/end indices. 
- const int64 rank = ShapeUtil::Rank(input_hlo->shape()); + const int64 rank = input_hlo->shape().rank(); llvm_ir::IrArray::Index slice_start_index(index.GetType(), rank); llvm_ir::IrArray::Index slice_limit_index(index.GetType(), rank); // Slice intersection gathers (ANDs) conditions on all ranks for which @@ -1905,9 +1923,19 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicUpdateSlice( auto index_typed_const = [&](uint64 c) -> llvm::Constant* { return llvm::ConstantInt::get(index_type, c); }; - llvm_ir::IrArray::Index dim_index(1, index_typed_const(i)); - TF_ASSIGN_OR_RETURN(llvm::Value * start_index_value, - operand_to_generator.at(start_hlo)(dim_index)); + + llvm::Value* start_index_value; + // TODO(b/118437727): Remove the R1 path. + if (hlo->operand(2)->shape().rank() == 1) { + llvm_ir::IrArray::Index dim_index(1, index_typed_const(i)); + TF_ASSIGN_OR_RETURN(start_index_value, + operand_to_generator.at(hlo->operand(2))(dim_index)); + } else { + llvm_ir::IrArray::Index zero_index(index_type); + TF_ASSIGN_OR_RETURN( + start_index_value, + operand_to_generator.at(hlo->operand(2 + i))(zero_index)); + } // Clamp the start index so that the update region fits in the operand. // start_index = clamp(start_index, 0, input_dim_size - update_dim_size) @@ -2225,7 +2253,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( auto* iota = Cast(hlo); PrimitiveType element_type = iota->shape().element_type(); IrArray::Index elem_index = - ShapeUtil::Rank(iota->shape()) > 1 + iota->shape().rank() > 1 ? 
target_index.SourceIndexOfBroadcast( iota->shape(), ShapeUtil::MakeShapeWithDescendingLayout( diff --git a/tensorflow/compiler/xla/service/gather_expander.cc b/tensorflow/compiler/xla/service/gather_expander.cc index 01cef499665c050d4453382289168276028e1d26..590942cddcdd138981ee829f090ae17b0d038e1a 100644 --- a/tensorflow/compiler/xla/service/gather_expander.cc +++ b/tensorflow/compiler/xla/service/gather_expander.cc @@ -153,10 +153,9 @@ static StatusOr> GatherLoopBody( dim_numbers.index_vector_dim() == gather.operand(1)->shape().dimensions_size()); - TF_ASSIGN_OR_RETURN( - HloInstruction * induction_var_as_vector, + HloInstruction* induction_var_as_vector = MakeBroadcastHlo(induction_var, /*broadcast_dimensions=*/{}, - /*result_shape_bounds=*/{1})); + /*result_shape_bounds=*/{1}); HloInstruction* index_vector; @@ -222,7 +221,7 @@ static StatusOr> GatherLoopBody( {operand, start_indices, updated_accumulator}}; } -static StatusOr CreateGatherLoopAccumulatorInitValue( +static HloInstruction* CreateGatherLoopAccumulatorInitValue( HloComputation* computation, PrimitiveType element_type, absl::Span slice_sizes, int64 gather_loop_trip_count, const GatherDimensionNumbers& dim_numbers) { @@ -332,12 +331,10 @@ StatusOr GatherExpander::ExpandGather( CHECK_EQ(gather_loop_trip_count, canonical_start_indices->shape().dimensions(0)); - TF_ASSIGN_OR_RETURN( - HloInstruction * accumulator_init, - CreateGatherLoopAccumulatorInitValue( - computation, output_shape.element_type(), - gather_instr->gather_slice_sizes(), gather_loop_trip_count, - gather_instr->gather_dimension_numbers())); + HloInstruction* accumulator_init = CreateGatherLoopAccumulatorInitValue( + computation, output_shape.element_type(), + gather_instr->gather_slice_sizes(), gather_loop_trip_count, + gather_instr->gather_dimension_numbers()); StatusOr> gather_loop_result_or_error = WhileUtil::MakeCountedLoop( diff --git a/tensorflow/compiler/xla/service/gather_expander_test.cc 
b/tensorflow/compiler/xla/service/gather_expander_test.cc index a3102368cb1dba15da7422337666d278cef775ab..e1ea5c39d58b6d23b076740626ca0ad63dc341ee 100644 --- a/tensorflow/compiler/xla/service/gather_expander_test.cc +++ b/tensorflow/compiler/xla/service/gather_expander_test.cc @@ -89,7 +89,7 @@ ENTRY main { // an implementation detail from WhileUtil::MakeCountedLoop). const Shape& while_shape = while_instr->shape(); - ASSERT_TRUE(ShapeUtil::IsTuple(while_shape)); + ASSERT_TRUE(while_shape.IsTuple()); ASSERT_EQ(ShapeUtil::TupleElementCount(while_shape), 4); EXPECT_TRUE(ShapeUtil::SameDimensions( diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc index bec02e14f951c6d905b7329be5c02896984279d0..7d450f4b53cdea209f2ef10ba785be6ec3b8bf8d 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc @@ -83,7 +83,7 @@ Status GenericTransferManager::TransferLiteralFromDeviceInternal( TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( device_buffer.on_host_shape(), [&](const Shape& subshape, const ShapeIndex& index) -> Status { - if (ShapeUtil::IsArray(subshape)) { + if (subshape.IsArray()) { TF_RETURN_IF_ERROR(executor->SynchronousMemcpyD2H( /*source=*/device_buffer.buffer(index), /*size=*/GetByteSizeRequirement(subshape), @@ -120,7 +120,7 @@ Status GenericTransferManager::TransferLiteralToDeviceAsync( device_buffer.on_host_shape(), [&](const Shape& device_subshape, const ShapeIndex& index) -> Status { se::DeviceMemoryBase device_memory = device_buffer.buffer(index); - if (ShapeUtil::IsArray(device_subshape)) { + if (device_subshape.IsArray()) { TF_RET_CHECK(GetByteSizeRequirement(device_subshape) == device_memory.size()); // Element is array-shaped: transfer array data to device buffer. 
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index bfd1b6cb1492f5cb709e2ecefe73782094e26f5e..85fb2dd47abdad7073bf15a2f8b974a3ae0f01e4 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -3,6 +3,11 @@ load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library") +load( + "//tensorflow/core:platform/default/build_config_root.bzl", + "tf_cuda_tests_tags", +) +load("//tensorflow:tensorflow.bzl", "tf_cc_test") licenses(["notice"]) # Apache 2.0 @@ -24,12 +29,6 @@ filegroup( ]), ) -load("//tensorflow:tensorflow.bzl", "tf_cc_test") -load( - "//tensorflow/core:platform/default/build_config_root.bzl", - "tf_cuda_tests_tags", -) - xla_proto_library( name = "backend_configs", srcs = ["backend_configs.proto"], @@ -94,8 +93,8 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_reachability", - "//tensorflow/core:lib", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", ], ) @@ -135,6 +134,8 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/compiler/xla/service/llvm_ir:tuple_ops", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@llvm//:core", @@ -263,7 +264,9 @@ cc_library( "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:span", ], @@ -362,6 +365,7 @@ cc_library( 
"//tensorflow/core/platform/default/build_config:stream_executor_cuda", # build_cleaner: keep "//tensorflow/stream_executor", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -551,6 +555,44 @@ cc_library( ], ) +cc_library( + name = "gpu_sanitize_constant_names", + srcs = ["gpu_sanitize_constant_names.cc"], + hdrs = ["gpu_sanitize_constant_names.h"], + deps = [ + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "gpu_sanitize_constant_names_test", + srcs = ["gpu_sanitize_constant_names_test.cc"], + tags = tf_cuda_tests_tags(), + deps = [ + ":gpu_sanitize_constant_names", + ":ir_emission_utils", + "//tensorflow/compiler/xla:shape_layout", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:computation_layout", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + "@com_google_absl//absl/strings", + ], +) + cc_library( name = "fusion_merger", srcs = ["fusion_merger.cc"], @@ -675,6 +717,7 @@ cc_library( ":gpu_hlo_schedule", ":gpu_hlo_support_checker", ":gpu_layout_assignment", + ":gpu_sanitize_constant_names", ":instruction_fusion", ":ir_emission_utils", ":ir_emitter", @@ -694,6 +737,9 @@ cc_library( 
"//tensorflow/compiler/xla/service:buffer_liveness", "//tensorflow/compiler/xla/service:call_inliner", "//tensorflow/compiler/xla/service:conditional_simplifier", + "//tensorflow/compiler/xla/service:convolution_group_converter", + "//tensorflow/compiler/xla/service:dot_decomposer", + "//tensorflow/compiler/xla/service:dynamic_index_splitter", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:flatten_call_graph", "//tensorflow/compiler/xla/service:hlo", @@ -711,6 +757,7 @@ cc_library( "//tensorflow/compiler/xla/service:llvm_compiler", "//tensorflow/compiler/xla/service:reduce_precision_insertion", "//tensorflow/compiler/xla/service:reshape_mover", + "//tensorflow/compiler/xla/service:sort_simplifier", "//tensorflow/compiler/xla/service:transpose_folding", "//tensorflow/compiler/xla/service:tuple_simplifier", "//tensorflow/compiler/xla/service:while_loop_constant_sinking", @@ -724,6 +771,7 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:regexp_internal", "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/stream_executor/cuda:cuda_diagnostics", "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -1004,14 +1052,10 @@ cc_library( srcs = ["variadic_op_splitter.cc"], hdrs = ["variadic_op_splitter.h"], deps = [ - ":ir_emission_utils", - "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/core:lib", "@com_google_absl//absl/strings", diff --git a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc index 
528209abc75777440163c2e1512658b8ad36315b..eb59ee5a1d47b6b706ef3f53a76069b3538eb6b7 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc +++ b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -57,16 +58,16 @@ StatusOr> BufferAllocations::Builder::Build( // If buffer #i's address is already registered (e.g. external arguments or // result buffers), use that registered buffer. - if (registered_buffers_.count(i)) { - se::DeviceMemoryBase address = FindOrDie(registered_buffers_, i); - if (reinterpret_cast(address.opaque()) % expected_alignment != + if (se::DeviceMemoryBase* address = + tensorflow::gtl::FindOrNull(registered_buffers_, i)) { + if (reinterpret_cast(address->opaque()) % expected_alignment != 0) { return InternalError( "Address of registered buffer %d must be a multiple of %x, but " "was %p", - i, kEntryParameterAlignBytes, address.opaque()); + i, kEntryParameterAlignBytes, address->opaque()); } - buffer_allocations->SetBuffer(i, FindOrDie(registered_buffers_, i)); + buffer_allocations->SetBuffer(i, *address); continue; } diff --git a/tensorflow/compiler/xla/service/gpu/buffer_allocations.h b/tensorflow/compiler/xla/service/gpu/buffer_allocations.h index 14186b8faa68ad8492ea4863fcd7bd746e2eae48..9413ac2cff7c8d3ec4be6662569c580060bf1173 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_allocations.h +++ b/tensorflow/compiler/xla/service/gpu/buffer_allocations.h @@ -20,6 +20,7 @@ limitations under the License. 
#include #include +#include "absl/container/flat_hash_map.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" @@ -52,7 +53,8 @@ class BufferAllocations { DeviceMemoryAllocator* memory_allocator); private: - std::map registered_buffers_; + absl::flat_hash_map + registered_buffers_; }; ~BufferAllocations(); diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc index 6d6780fa1c7b0c636eb771c40e74f074cd8c4c4b..309b0aca64954e64509d731dce28ce9d8da4ee43 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc @@ -146,7 +146,8 @@ tensorflow::mutex_lock LockGpu(const se::StreamExecutor* stream_exec) { // trouble, but we may want to revisit this if we ever find a model where // caching would speed up compilation a lot. StatusOr -CudnnConvAlgorithmPicker::PickBestAlgorithm(HloCustomCallInstruction* instr) { +CudnnConvAlgorithmPicker::PickBestAlgorithm( + const HloCustomCallInstruction* instr) { // TODO(timshen): for now only check fp16. It can be expanded to other types, // with some work on the HLO routines. const bool cross_check_enabled = @@ -249,12 +250,13 @@ CudnnConvAlgorithmPicker::PickBestAlgorithm(HloCustomCallInstruction* instr) { VLOG(3) << "Trying algorithm " << AlgorithmToString(alg) << " for " << instr->ToString(); - backend_config.set_algorithm(alg.algo_id()); - backend_config.set_tensor_ops_enabled(alg.tensor_ops_enabled()); - TF_RETURN_IF_ERROR(instr->set_backend_config(backend_config)); + // Use assignment instead of brace-list to make GCC 4.9 happy. 
+ RunConvOptions options; + options.profile_result = &profile_result; + options.algo_override = alg; bool launch_ok = RunCudnnConv(instr, absl::MakeSpan(operand_buffers), result_buffer, - &scratch_allocator, &stream, &profile_result) + &scratch_allocator, &stream, options) .ok(); if (launch_ok && profile_result.is_valid()) { diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h b/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h index 642af787afc71586d722ecc7e529ed8b3fa64d33..4991db0948589e479a202f4082d96df275f6e088 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h @@ -56,7 +56,8 @@ class CudnnConvAlgorithmPicker : public HloModulePass { StatusOr RunOnComputation(HloComputation* computation); StatusOr RunOnInstruction(HloInstruction* instr); - StatusOr PickBestAlgorithm(HloCustomCallInstruction* instr); + StatusOr PickBestAlgorithm( + const HloCustomCallInstruction* instr); se::StreamExecutor* stream_exec_; // never null DeviceMemoryAllocator* allocator_; // may be null diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores.cc index 5aa4f839f4be5f1060480fea98775f8ffada0bdd..958e0b9c6e7b7885f87b90d61ee5b3bbf6ab2702 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores.cc @@ -50,10 +50,10 @@ static HloInstruction* PadInstruction(HloInstruction* instr, auto* zero = comp->AddInstruction( HloInstruction::CreateConstant(LiteralUtil::Zero(shape.element_type()))); - PaddingConfig pad_config = MakeNoPaddingConfig(ShapeUtil::Rank(shape)); + PaddingConfig pad_config = MakeNoPaddingConfig(shape.rank()); bool added_padding = false; - for (int64 dim = 0; dim < ShapeUtil::Rank(shape); ++dim) { + for (int64 dim = 0; dim < shape.rank(); ++dim) 
{ if (shape.dimensions(dim) == new_shape.dimensions(dim)) { continue; } diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_padding_legalization.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_padding_legalization.cc index d7829045cc127deaa4c2c9b705dca5285d704af2..17d0f7aa7bf6031148aae79f74f7878d6fca9574 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_padding_legalization.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_padding_legalization.cc @@ -43,13 +43,14 @@ bool IsForwardConvolutionCanonical(const HloInstruction& conv) { // dilation), returns kPad and/or kSlice instructions that explicitly apply the // padding; otherwise returns the original input operand. When there is both // positive padding (including dilation) and negative padding, we insert both -// kPad and kSlice. +// kPad and kSlice. Modifies 'conv_window' accordingly if any padding was moved +// into a kPad or kSlice op. HloInstruction* MaybePaddedAndSlicedInput( - const Window& conv_window, const ConvolutionDimensionNumbers& conv_dnums, + Window* conv_window, const ConvolutionDimensionNumbers& conv_dnums, HloInstruction* input) { HloComputation* computation = input->parent(); - if (!window_util::HasSymmetricPadding(conv_window) || - window_util::HasBaseDilation(conv_window)) { + if (!window_util::HasSymmetricPadding(*conv_window) || + window_util::HasBaseDilation(*conv_window)) { // If padding is uneven or has dilation, we insert a kPad instruction that // applies positive padding and dilation. 
// @@ -62,12 +63,21 @@ HloInstruction* MaybePaddedAndSlicedInput( MakeNoPaddingConfig(input->shape().dimensions_size()); for (size_t i = 0; i < conv_dnums.input_spatial_dimensions().size(); ++i) { int64 dim = conv_dnums.input_spatial_dimensions(i); - padding_config.mutable_dimensions(dim)->set_edge_padding_low( - std::max(0LL, conv_window.dimensions(i).padding_low())); - padding_config.mutable_dimensions(dim)->set_edge_padding_high( - std::max(0LL, conv_window.dimensions(i).padding_high())); - padding_config.mutable_dimensions(dim)->set_interior_padding( - conv_window.dimensions(i).base_dilation() - 1); + if (conv_window->dimensions(i).padding_low() > 0) { + padding_config.mutable_dimensions(dim)->set_edge_padding_low( + conv_window->dimensions(i).padding_low()); + conv_window->mutable_dimensions(i)->set_padding_low(0); + } + if (conv_window->dimensions(i).padding_high() > 0) { + padding_config.mutable_dimensions(dim)->set_edge_padding_high( + conv_window->dimensions(i).padding_high()); + conv_window->mutable_dimensions(i)->set_padding_high(0); + } + if (conv_window->dimensions(i).base_dilation() != 1) { + padding_config.mutable_dimensions(dim)->set_interior_padding( + conv_window->dimensions(i).base_dilation() - 1); + conv_window->mutable_dimensions(i)->set_base_dilation(1); + } } PrimitiveType element_type = input->shape().element_type(); HloInstruction* padding = computation->AddInstruction( @@ -75,7 +85,7 @@ HloInstruction* MaybePaddedAndSlicedInput( input = MakePadHlo(input, padding, padding_config).ValueOrDie(); } - if (window_util::HasNegativePadding(conv_window)) { + if (window_util::HasNegativePadding(*conv_window)) { // If the window has negative padding, insert a kSlice that explicitly // applies negative padding. 
// @@ -89,10 +99,14 @@ HloInstruction* MaybePaddedAndSlicedInput( int64 dim = conv_dnums.input_spatial_dimensions(i); // If dimension "dim" has negative padding, increase the start index or // decrement the limit index by the amount of negative padding. - start_indices[dim] += - std::max(0LL, -conv_window.dimensions(i).padding_low()); - limit_indices[dim] -= - std::max(0LL, -conv_window.dimensions(i).padding_high()); + if (conv_window->dimensions(i).padding_low() < 0) { + start_indices[dim] += -conv_window->dimensions(i).padding_low(); + conv_window->mutable_dimensions(i)->set_padding_low(0); + } + if (conv_window->dimensions(i).padding_high() < 0) { + limit_indices[dim] -= -conv_window->dimensions(i).padding_high(); + conv_window->mutable_dimensions(i)->set_padding_high(0); + } } input = @@ -140,25 +154,22 @@ bool CudnnConvPaddingLegalization::CanonicalizeForwardConvolution( // Insert slices and/or pads between the convolution and its input and/or // kernel operand. + Window new_conv_window = conv->window(); HloInstruction* new_input = MaybePaddedAndSlicedInput( - conv->window(), conv->convolution_dimension_numbers(), + &new_conv_window, conv->convolution_dimension_numbers(), conv->mutable_operand(0)); HloInstruction* new_kernel = - MaybePaddedKernel(conv->window(), conv->convolution_dimension_numbers(), + MaybePaddedKernel(new_conv_window, conv->convolution_dimension_numbers(), conv->mutable_operand(1)); - // Remove the padding from convolution's window field. These paddings are - // made explicit with the inserted pads. - Window new_conv_window = conv->window(); + // Remove the window dilation from convolution's window field. These paddings + // are made explicit with the pads inserted by MaybePaddedKernel(). for (size_t i = 0; i < new_conv_window.dimensions_size(); ++i) { WindowDimension* dim = new_conv_window.mutable_dimensions(i); // The size of the kernel may have changed so update the Window to match. 
dim->set_size(new_kernel->shape().dimensions( conv->convolution_dimension_numbers().kernel_spatial_dimensions(i))); - dim->set_padding_low(0); - dim->set_padding_high(0); - dim->set_base_dilation(1); dim->set_window_dilation(1); } @@ -208,7 +219,7 @@ bool CudnnConvPaddingLegalization::CanonicalizeBackwardFilterConvolution( Window new_backward_conv_window = backward_conv->window(); // input_padding_config is the config of the kPad to be inserted. PaddingConfig input_padding_config = - MakeNoPaddingConfig(ShapeUtil::Rank(input->shape())); + MakeNoPaddingConfig(input->shape().rank()); ConvolutionDimensionNumbers backward_conv_dnums = backward_conv->convolution_dimension_numbers(); for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) { diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter_test.cc index 443883a89f66a747def1049bc5afb53fec3c2409..dbcdc2b075bc72f3194af8e555faabb1511376e0 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter_test.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter_test.cc @@ -109,9 +109,11 @@ TEST_F(CudnnConvRewriterTest, BackwardFilterConvolve) { auto* conv = builder.AddInstruction(HloInstruction::CreateConvolve( ShapeInference::InferConvolveShape( activations->shape(), gradients->shape(), /*feature_group_count=*/1, - conv_window, tf_default_dnums_for_backward_filter_) + /*batch_group_count=*/1, conv_window, + tf_default_dnums_for_backward_filter_) .ConsumeValueOrDie(), - activations, gradients, /*feature_group_count=*/1, conv_window, + activations, gradients, /*feature_group_count=*/1, + /*batch_group_count=*/1, conv_window, tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2))); OpMetadata metadata; @@ -147,9 +149,11 @@ TEST_F(CudnnConvRewriterTest, builder.AddInstruction(HloInstruction::CreateConvolve( ShapeInference::InferConvolveShape( activations->shape(), gradients->shape(), 
/*feature_group_count=*/1, - conv_window, tf_default_dnums_for_backward_filter_) + /*batch_group_count=*/1, conv_window, + tf_default_dnums_for_backward_filter_) .ConsumeValueOrDie(), - activations, gradients, /*feature_group_count=*/1, conv_window, + activations, gradients, /*feature_group_count=*/1, + /*batch_group_count=*/1, conv_window, tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2))); auto module = CreateNewVerifiedModule(); @@ -179,7 +183,7 @@ TEST_F(CudnnConvRewriterTest, BackwardFilterConvolveWithPaddedActivations) { } builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {32, 3, 3, 32}), activations, gradients, - /*feature_group_count=*/1, conv_window, + /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window, tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2))); auto module = CreateNewVerifiedModule(); @@ -209,7 +213,7 @@ TEST_F(CudnnConvRewriterTest, BackwardFilterConvolveWithPaddedGradients) { } builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {320, 3, 3, 192}), activations, gradients, - /*feature_group_count=*/1, conv_window, + /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window, tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2))); auto module = CreateNewVerifiedModule(); @@ -238,7 +242,7 @@ TEST_F(CudnnConvRewriterTest, BackwardFilterConvolveWithUnevenPadding) { } builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {32, 2, 2, 32}), activations, gradients, - /*feature_group_count=*/1, conv_window, + /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window, tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2))); auto module = CreateNewVerifiedModule(); @@ -283,13 +287,15 @@ TEST_F(CudnnConvRewriterTest, BackwardInputConvolveEvenPadding) { HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {4, 3, 16, 16}), /*lhs=*/output, - 
/*rhs=*/reverse_kernel, /*feature_group_count=*/1, conv_window, - conv_dnums, DefaultPrecisionConfig(2))); + /*rhs=*/reverse_kernel, /*feature_group_count=*/1, + /*batch_group_count=*/1, conv_window, conv_dnums, + DefaultPrecisionConfig(2))); // Verify the convolution's shape is consistent with ShapeInference. CHECK(ShapeUtil::Compatible( conv->shape(), ShapeInference::InferConvolveShape( output->shape(), reverse_kernel->shape(), - /*feature_group_count=*/1, conv_window, conv_dnums) + /*feature_group_count=*/1, /*batch_group_count=*/1, + conv_window, conv_dnums) .ValueOrDie())); auto module = CreateNewVerifiedModule(); @@ -332,10 +338,12 @@ TEST_F(CudnnConvRewriterTest, BackwardInputConvolve1x1Filter) { builder.AddInstruction(HloInstruction::CreateConvolve( ShapeInference::InferConvolveShape(output->shape(), kernel->shape(), - /*feature_group_count=*/1, conv_window, + /*feature_group_count=*/1, + /*batch_group_count=*/1, conv_window, tf_default_dnums_for_backward_input_) .ConsumeValueOrDie(), - /*lhs=*/output, /*rhs=*/kernel, /*feature_group_count=*/1, conv_window, + /*lhs=*/output, /*rhs=*/kernel, /*feature_group_count=*/1, + /*batch_group_count=*/1, conv_window, tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2))); auto module = CreateNewVerifiedModule(); @@ -365,11 +373,12 @@ TEST_F(CudnnConvRewriterTest, builder.AddInstruction(HloInstruction::CreateConvolve( ShapeInference::InferConvolveShape( output->shape(), kernel->shape(), /*feature_group_count=*/1, - default_conv_window_, tf_default_dnums_for_backward_input_) + /*batch_group_count=*/1, default_conv_window_, + tf_default_dnums_for_backward_input_) .ConsumeValueOrDie(), /*lhs=*/output, /*rhs=*/kernel, /*feature_group_count=*/1, - default_conv_window_, tf_default_dnums_for_backward_input_, - DefaultPrecisionConfig(2))); + /*batch_group_count=*/1, default_conv_window_, + tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2))); auto module = CreateNewVerifiedModule(); HloComputation* 
entry_computation = @@ -415,15 +424,15 @@ TEST_F(CudnnConvRewriterTest, BackwardInputConvolveUnevenPaddingOnGradients) { } HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {20, 10, 10, 192}), output, reverse_kernel, - /*feature_group_count=*/1, conv_window, + /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window, tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2))); // Verify the convolution's shape is consistent with ShapeInference. CHECK(ShapeUtil::Compatible( - conv->shape(), - ShapeInference::InferConvolveShape( - output->shape(), reverse_kernel->shape(), /*feature_group_count=*/1, - conv_window, tf_default_dnums_for_backward_input_) - .ValueOrDie())); + conv->shape(), ShapeInference::InferConvolveShape( + output->shape(), reverse_kernel->shape(), + /*feature_group_count=*/1, /*batch_group_count=*/1, + conv_window, tf_default_dnums_for_backward_input_) + .ValueOrDie())); auto module = CreateNewVerifiedModule(); HloComputation* entry_computation = @@ -465,15 +474,15 @@ TEST_F(CudnnConvRewriterTest, BackwardInputConvolveLowPaddingTooLarge) { } HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {20, 10, 10, 192}), output, reverse_kernel, - /*feature_group_count=*/1, conv_window, + /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window, tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2))); // Verify the convolution's shape is consistent with ShapeInference. 
CHECK(ShapeUtil::Compatible( - conv->shape(), - ShapeInference::InferConvolveShape( - output->shape(), reverse_kernel->shape(), /*feature_group_count=*/1, - conv_window, tf_default_dnums_for_backward_input_) - .ValueOrDie())); + conv->shape(), ShapeInference::InferConvolveShape( + output->shape(), reverse_kernel->shape(), + /*feature_group_count=*/1, /*batch_group_count=*/1, + conv_window, tf_default_dnums_for_backward_input_) + .ValueOrDie())); auto module = CreateNewVerifiedModule(); HloComputation* entry_computation = @@ -519,15 +528,15 @@ TEST_F(CudnnConvRewriterTest, BackwardInputConvolveUnevenPaddingOnActivations) { forward_conv_col_dim->set_base_dilation(2); HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {1, 1, 14, 1}), output, reverse_kernel, - /*feature_group_count=*/1, conv_window, + /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window, tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2))); // Verify the convolution's shape is consistent with ShapeInference. 
CHECK(ShapeUtil::Compatible( - conv->shape(), - ShapeInference::InferConvolveShape( - output->shape(), reverse_kernel->shape(), /*feature_group_count=*/1, - conv_window, tf_default_dnums_for_backward_input_) - .ValueOrDie())); + conv->shape(), ShapeInference::InferConvolveShape( + output->shape(), reverse_kernel->shape(), + /*feature_group_count=*/1, /*batch_group_count=*/1, + conv_window, tf_default_dnums_for_backward_input_) + .ValueOrDie())); auto module = CreateNewVerifiedModule(); const HloComputation* entry_computation = @@ -574,15 +583,15 @@ TEST_F(CudnnConvRewriterTest, forward_conv_col_dim->set_padding_high(2); HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {1, 1, 4, 1}), output, reverse_kernel, - /*feature_group_count=*/1, conv_window, + /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window, tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2))); // Verify the convolution's shape is consistent with ShapeInference. 
CHECK(ShapeUtil::Compatible( - conv->shape(), - ShapeInference::InferConvolveShape( - output->shape(), reverse_kernel->shape(), /*feature_group_count=*/1, - conv_window, tf_default_dnums_for_backward_input_) - .ValueOrDie())); + conv->shape(), ShapeInference::InferConvolveShape( + output->shape(), reverse_kernel->shape(), + /*feature_group_count=*/1, /*batch_group_count=*/1, + conv_window, tf_default_dnums_for_backward_input_) + .ValueOrDie())); auto module = CreateNewVerifiedModule(); HloComputation* entry_computation = @@ -599,7 +608,7 @@ TEST_F(CudnnConvRewriterTest, BackwardInputConvolveConstantFilter) { Array4D constant_arr(4, 4, 2, 2); constant_arr.FillIota(0); string constant_str = - LiteralUtil::CreateR4FromArray4D(constant_arr).ToString(); + LiteralUtil::CreateR4FromArray4D(constant_arr).ToStringWithoutShape(); const string module_str = absl::StrFormat(R"( HloModule test diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.cc index 3425e1b4942aaf1011ba1bf1c50dd7e79c1f9807..b628f27f4b2ba8ccf17fd531d8a0c25cb99d9396 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.cc @@ -395,32 +395,36 @@ Status RunCudnnConv(const HloCustomCallInstruction* conv, absl::Span operand_buffers, se::DeviceMemoryBase result_buffer, se::DeviceMemoryBase scratch_buf, se::Stream* stream, - se::dnn::ProfileResult* profile_result) { + RunConvOptions options) { ScratchBufAllocator scratch_allocator(scratch_buf); return RunCudnnConv(conv, operand_buffers, result_buffer, &scratch_allocator, - stream, profile_result); + stream, options); } Status RunCudnnConv(const HloCustomCallInstruction* conv, absl::Span operand_buffers, se::DeviceMemoryBase result_buffer, se::ScratchAllocator* scratch_allocator, se::Stream* stream, - se::dnn::ProfileResult* profile_result) { + RunConvOptions options) { TF_ASSIGN_OR_RETURN(CudnnConvParams params, 
GetCudnnConvParams(conv, operand_buffers, result_buffer)); + if (options.algo_override) { + params.algorithm = AlgorithmConfig(*options.algo_override); + } + PrimitiveType output_primitive_type = conv->shape().tuple_shapes(0).element_type(); switch (output_primitive_type) { case F16: return RunCudnnConvImpl(params, scratch_allocator, stream, - profile_result); + options.profile_result); case F32: return RunCudnnConvImpl(params, scratch_allocator, stream, - profile_result); + options.profile_result); case F64: return RunCudnnConvImpl(params, scratch_allocator, stream, - profile_result); + options.profile_result); default: LOG(FATAL) << ShapeUtil::HumanString(*params.output_shape); } diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.h b/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.h index edbc75a94a1238540390b93f0fa5217852c7781f..25b2461ca61251c6cb7b89b1f91da0f1636a3647 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.h @@ -28,6 +28,14 @@ limitations under the License. namespace xla { namespace gpu { +struct RunConvOptions { + // Nullable output-parameter pointer for profiling results. + se::dnn::ProfileResult* profile_result = nullptr; + + // Use this algorithm, instead of the one from the instruction. + absl::optional algo_override; +}; + // This file contains low-level routines for running cudnn convolutions. // Calls into cudnn to run the specified convolution. 
@@ -46,13 +54,13 @@ Status RunCudnnConv(const HloCustomCallInstruction* conv, absl::Span operand_buffers, se::DeviceMemoryBase result_buffer, se::DeviceMemoryBase scratch_buf, se::Stream* stream, - se::dnn::ProfileResult* profile_result = nullptr); + RunConvOptions = {}); Status RunCudnnConv(const HloCustomCallInstruction* conv, absl::Span operand_buffers, se::DeviceMemoryBase result_buffer, se::ScratchAllocator* scratch_allocator, se::Stream* stream, - se::dnn::ProfileResult* profile_result = nullptr); + RunConvOptions = {}); } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc index 6dcdaf1cfe06e446deed847aaf29088a7ed10e13..ffd4214958275dc79bbcb060328893f8b68c737a 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc @@ -161,6 +161,16 @@ StatusOr GpuElementalIrEmitter::EmitFloatBinaryOp( PrimitiveType lhs_input_type = op->operand(0)->shape().element_type(); PrimitiveType rhs_input_type = op->operand(1)->shape().element_type(); PrimitiveType output_type = op->shape().element_type(); + HloOpcode opcode = op->opcode(); + + if (hlo_module_config_.debug_options().xla_gpu_enable_fast_min_max() && + (opcode == HloOpcode::kMaximum || opcode == HloOpcode::kMinimum)) { + return llvm_ir::EmitCallToIntrinsic( + opcode == HloOpcode::kMaximum ? llvm::Intrinsic::maxnum + : llvm::Intrinsic::minnum, + {lhs_value, rhs_value}, {lhs_value->getType()}, b_); + } + switch (op->opcode()) { case HloOpcode::kRemainder: { return EmitLibdeviceMathCall("__nv_fmod", {lhs_value, rhs_value}, @@ -298,9 +308,11 @@ llvm::Value* GpuElementalIrEmitter::EmitDeviceFunctionCall( false); // No variadic arguments. // Declares the callee if it is not declared already. 
- llvm::Function* callee = llvm::cast( - b_->GetInsertBlock()->getModule()->getOrInsertFunction( - llvm_ir::AsStringRef(callee_name), callee_type)); + llvm::Function* callee = llvm::dyn_cast( + b_->GetInsertBlock() + ->getModule() + ->getOrInsertFunction(llvm_ir::AsStringRef(callee_name), callee_type) + .getCallee()); for (auto attribute : attributes) { callee->addFnAttr(attribute); diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc index 470457935acacb8940af241dadb393d770786939..91930eccdff94bb2fc85636f3a4b2d661c618d87 100644 --- a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc +++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc @@ -35,7 +35,7 @@ namespace { // Traverses users of tuple shape, adding leaf instructions to 'instructions'. void MaybeResolveTupleElements(HloInstruction* instruction, std::vector* instructions) { - if (ShapeUtil::IsTuple(instruction->shape())) { + if (instruction->shape().IsTuple()) { for (auto tuple_user : instruction->users()) { MaybeResolveTupleElements(tuple_user, instructions); } diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc index 27f07b1d58125092c1ed6734b238e4ae0f11c4aa..86c9bc6a345047fb5329af0be45c8981cc427f50 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc @@ -206,6 +206,8 @@ auto GetGemmFn(PrimitiveType type) -> decltype(&DoGemm) { return &DoGemm; case C64: return &DoGemm>; + case C128: + return &DoGemm>; default: LOG(FATAL) << "Unsupported type."; } @@ -221,6 +223,8 @@ auto GetGemmWithAlgorithmFn(PrimitiveType type) return &DoGemmWithAlgorithm; case C64: return &DoGemmWithAlgorithm>; + case C128: + return &DoGemmWithAlgorithm>; default: LOG(FATAL) << "Unsupported type."; } @@ -235,6 +239,8 @@ auto GetGemmAutotuneFn(PrimitiveType type) -> decltype(&DoGemmAutotune) { return &DoGemmAutotune; case C64: return 
&DoGemmAutotune>; + case C128: + return &DoGemmAutotune>; default: LOG(FATAL) << "Unsupported type."; } @@ -255,6 +261,8 @@ se::blas::ComputationType GetBlasComputationType(PrimitiveType type) { return se::blas::ComputationType::kF64; case C64: return se::blas::ComputationType::kComplexF32; + case C128: + return se::blas::ComputationType::kComplexF64; default: LOG(FATAL) << "Unsupported type."; } @@ -315,8 +323,7 @@ Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, DotDimensionNumbers dim_nums = GetDimensionNumbers(*hlo_instruction()); CHECK_EQ(dim_nums.lhs_batch_dimensions_size(), dim_nums.rhs_batch_dimensions_size()); - CHECK_EQ(dim_nums.lhs_batch_dimensions_size() + 2, - ShapeUtil::Rank(output_shape_)); + CHECK_EQ(dim_nums.lhs_batch_dimensions_size() + 2, output_shape_.rank()); int64 row_dim = dim_nums.lhs_batch_dimensions_size(); int64 col_dim = dim_nums.lhs_batch_dimensions_size() + 1; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc index ae2e718db29803a085401969a7d9b09abf690a6c..434060ad89dac7ad65c790c8c0a7f3d6ad62a25a 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -218,7 +218,7 @@ GpuExecutable::ResolveConstantGlobals(se::StreamExecutor* executor) { const Literal& literal = llvm_ir::LiteralForConstantAllocation(allocation); - CHECK(ShapeUtil::IsArray(literal.shape())); + CHECK(literal.shape().IsArray()); if (!ShouldEmitLiteralInLlvmIr(literal)) { VLOG(3) << "H2D memcpy for constant with shape " << ShapeUtil::HumanString(literal.shape()); @@ -310,12 +310,34 @@ StatusOr GpuExecutable::ExecuteOnStream( TF_ASSIGN_OR_RETURN( const BufferAllocation::Slice slice, this->assignment_->GetUniqueSlice(src_hlo, sources[0]->index())); - CHECK(!slice.allocation()->is_entry_computation_parameter()); se::DeviceMemoryBase src_base = buffer_allocations->GetDeviceAddress(slice.index()); 
CHECK(!src_base.is_null() || src_base.size() == 0); - *device_memory = src_base; + if (!slice.allocation()->is_entry_computation_parameter()) { + // If the buffer coming out of the result is from a parameter, it + // means the caller aliased some parameter buffer to an output one + // (via the HloInputOutputAliasConfig API). If that is the case, the + // caller will receive a partially complete scoped shaped buffer, + // which they will have to fill up on return. + // Unfortunately the interface to the execute APIs are ShapedBuffer + // pointer based, which assumes caller ownership, and hence a buffer + // coming from there cannot be part of the new ScopedShapedBuffer we + // create for the result (which assumes ownership). + *device_memory = src_base; + } else { + const HloInputOutputAliasConfig& input_output_alias = + module().input_output_alias_config(); + auto output_alias = input_output_alias.GetAliasedOutput( + slice.allocation()->parameter_number(), + slice.allocation()->param_shape_index()); + CHECK(output_alias) + << "Ouput buffer is coming from parameter " + << slice.allocation()->parameter_number() << " at index " + << slice.allocation()->param_shape_index() + << ", but no alias exists"; + CHECK_EQ(*output_alias, index); + } buffers_in_result.insert(src_base); return Status::OK(); })); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc index 452e763a8eaadc805cd3a3859a68e2a31598fd36..842ba2fdcd31a451cec1be543e102e0a46077f38 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc @@ -42,15 +42,13 @@ bool LayoutsAreReduceInputFusionFriendly(const HloInstruction& producer, int64 max_rank = -1; const Layout* max_rank_layout; for (HloInstruction* param : params) { - if (ShapeUtil::IsArray(param->shape()) && - ShapeUtil::Rank(param->shape()) > max_rank) { - max_rank = ShapeUtil::Rank(param->shape()); + if (param->shape().IsArray() 
&& param->shape().rank() > max_rank) { + max_rank = param->shape().rank(); max_rank_layout = &param->shape().layout(); } } return absl::c_all_of(params, [&](HloInstruction* param) { - return (!ShapeUtil::IsArray(param->shape())) || - (ShapeUtil::Rank(param->shape()) < max_rank) || + return (!param->shape().IsArray()) || (param->shape().rank() < max_rank) || (LayoutUtil::Equal(param->shape().layout(), *max_rank_layout)); }); } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc index f59da2caa18646676297e66dd329c66fb5fddf1b..58bdd4209a2315cdb7d29e920faded4d1a6a5876 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc @@ -196,9 +196,9 @@ Status GpuLayoutAssignment::AddBackendConstraints( CHECK_EQ(dim_nums.lhs_batch_dimensions_size(), dim_nums.rhs_batch_dimensions_size()); CHECK_EQ(dim_nums.lhs_batch_dimensions_size() + 2, - ShapeUtil::Rank(instruction->shape())); + instruction->shape().rank()); for (int64 batch_dim : dim_nums.lhs_batch_dimensions()) { - CHECK_LT(batch_dim, ShapeUtil::Rank(instruction->shape()) - 2); + CHECK_LT(batch_dim, instruction->shape().rank() - 2); } // Set both inputs and the output to default layout. @@ -215,18 +215,18 @@ Status GpuLayoutAssignment::AddBackendConstraints( TF_RETURN_IF_ERROR( constraints->SetInstructionLayout(output_shape, instruction)); } else if (instruction->opcode() == HloOpcode::kSort && - ShapeUtil::Rank(instruction->operand(0)->shape()) > 1) { + instruction->operand(0)->shape().rank() > 1) { // Make sure that all the operands and the output(s) have the same layout.
Shape keys_shape = instruction->operand(0)->shape(); Layout keys_layout = - LayoutUtil::GetDefaultLayoutForRank(ShapeUtil::Rank(keys_shape)); + LayoutUtil::GetDefaultLayoutForRank(keys_shape.rank()); for (int64 i = 0; i < instruction->operand_count(); ++i) { Shape shape = instruction->operand(i)->shape(); *shape.mutable_layout() = keys_layout; TF_RETURN_IF_ERROR( constraints->SetOperandLayout(shape, instruction, i)); const LogicalBuffer* output_buffer; - if (ShapeUtil::IsArray(instruction->shape())) { + if (instruction->shape().IsArray()) { TF_ASSIGN_OR_RETURN( output_buffer, constraints->points_to_analysis().GetBufferDefinedAt(instruction, diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc index 2ffc8bfb49b205dced0d540ba72426e72d95e596..29756d27260b0f41b2dd4b649ea9b1610ff90268 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc @@ -369,7 +369,7 @@ TEST_F(LayoutAssignmentTest, SortLayout) { const char* hlo_text = R"( HloModule SortLayout ENTRY sort { - keys = f32[3,2]{0,1} constant(f32[3,2]{0,1}{{0,1},{0,1},{0,1}}) + keys = f32[3,2]{0,1} constant({{0,1},{0,1},{0,1}}) values = f32[2,3]{1,0} parameter(0) transpose = f32[3,2]{1,0} transpose(values), dimensions={1,0} ROOT sort = (f32[3,2]{1,0}, f32[3,2]{1,0}) sort(keys, transpose), diff --git a/tensorflow/compiler/xla/service/gpu/gpu_sanitize_constant_names.cc b/tensorflow/compiler/xla/service/gpu/gpu_sanitize_constant_names.cc new file mode 100644 index 0000000000000000000000000000000000000000..5e38ceca18de30e0e1fa75a7a4bd865e000b7d22 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/gpu_sanitize_constant_names.cc @@ -0,0 +1,70 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/gpu_sanitize_constant_names.h" + +#include +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +namespace gpu { + +StatusOr GpuSanitizeConstantNames::Run(HloModule* module) { + bool changed = false; + + NameUniquer instr_name_uniquer(/*separator=*/"_"); + // Collect the names used for the non-constant HLO instructions. + for (HloComputation* computation : module->computations()) { + for (HloInstruction* instr : computation->instructions()) { + if (instr->opcode() == HloOpcode::kConstant) { + continue; + } + + const string& old_name = instr->name(); + instr->UniquifyName(&instr_name_uniquer); + CHECK_EQ(old_name, instr->name()); + } + } + + // Sanitize the names for the constant HLO instructions and make them unique. + // This is not merged into the above loop because we don't want this pass to + // change the names of non-constant instructions, that is, if a constant HLO + // conflicts with a non-constant HLO, we change the name of the constant HLO + // even though the non-constant HLO comes after in the HLO module.
+ for (HloComputation* computation : module->computations()) { + for (HloInstruction* instr : computation->instructions()) { + if (instr->opcode() != HloOpcode::kConstant) { + continue; + } + string sanitized_name = llvm_ir::SanitizeConstantName(*instr); + instr->SetAndSanitizeName(sanitized_name); + instr->UniquifyName(&instr_name_uniquer); + changed = true; + } + } + + return changed; +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_sanitize_constant_names.h b/tensorflow/compiler/xla/service/gpu/gpu_sanitize_constant_names.h new file mode 100644 index 0000000000000000000000000000000000000000..8d583d047e25698e86032020b7fc20df87f5ab68 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/gpu_sanitize_constant_names.h @@ -0,0 +1,38 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_SANITIZE_CONSTANT_NAMES_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_SANITIZE_CONSTANT_NAMES_H_ + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { +namespace gpu { + +// Sanitizes HLO instruction names for the GPU backend. Currently, it only +// replaces . and - in the HLO constant instruction names with _ to please the +// LLVM PTX backend.
+class GpuSanitizeConstantNames : public HloModulePass { + public: + absl::string_view name() const override { return "sanitize-constant-names"; } + + StatusOr Run(HloModule* module) override; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_SANITIZE_CONSTANT_NAMES_H_ diff --git a/tensorflow/compiler/xla/service/gpu/gpu_sanitize_constant_names_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_sanitize_constant_names_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..f5adee8cc61f18f356406d8c089dd43565957739 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/gpu_sanitize_constant_names_test.cc @@ -0,0 +1,82 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include "tensorflow/compiler/xla/service/gpu/gpu_sanitize_constant_names.h" + +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace gpu { +namespace { + +namespace op = xla::testing::opcode_matchers; +using SanitizeConstantNamesTest = HloTestBase; + +TEST_F(SanitizeConstantNamesTest, InstructionNameWithHyphenSanitized) { + const char *const kHloString = R"( + HloModule HyphenInInstructionName + ENTRY kernelEntry { + ROOT equal-to = s32[2]{0} constant({42, 73}) + })"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(kHloString)); + + EXPECT_TRUE(GpuSanitizeConstantNames().Run(module.get()).ValueOrDie()); + HloInstruction *root = module->entry_computation()->root_instruction(); + EXPECT_EQ(root->name(), "equal_to"); +} + +TEST_F(SanitizeConstantNamesTest, InstructionNameWithDotSanitized) { + const char *const kHloString = R"( + HloModule HyphenInInstructionName + ENTRY kernelEntry { + ROOT equal.to = s32[2]{0} constant({42, 73}) + })"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(kHloString)); + + EXPECT_TRUE(GpuSanitizeConstantNames().Run(module.get()).ValueOrDie()); + HloInstruction *root = module->entry_computation()->root_instruction(); + EXPECT_EQ(root->name(), "equal_to"); +} + +TEST_F(SanitizeConstantNamesTest, BufferSanitizedNameCollisionResolved) { + const char *const kHloString = R"( + HloModule BufferSanitizedName + ENTRY kernelEntry { + equal.to = s32[2]{0} constant({42, 73}) + equal-to = s32[2]{0} constant({67, 3}) + ROOT equal_to = s32[2]{0} add(equal.to, equal-to) + })"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(kHloString)); + + 
EXPECT_TRUE(GpuSanitizeConstantNames().Run(module.get()).ValueOrDie()); + EXPECT_THAT(FindInstruction(module.get(), "equal_to_1"), op::Constant()); + EXPECT_THAT(FindInstruction(module.get(), "equal_to_2"), op::Constant()); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc index f3c274429242d5c989146d14ea523b5910408cff..8c6a6914792a96ab517fa5f20ff2215e4785490e 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc @@ -59,7 +59,7 @@ Status GpuTransferManager::TransferLiteralToInfeed( TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( shape, [&](const Shape& literal_subshape, const ShapeIndex& index) { - if (ShapeUtil::IsArray(literal_subshape)) { + if (literal_subshape.IsArray()) { int64 tuple_element_size = GetByteSizeRequirement(literal_subshape); TF_ASSIGN_OR_RETURN( *buffer_tree.mutable_element(index), @@ -126,13 +126,12 @@ static void ShapeTreeToLiteral( ShapeTree>* shape_tree, ShapeIndex* index) { const Shape& shape = ShapeUtil::GetSubshape(shape_tree->shape(), *index); - if (ShapeUtil::IsArray(shape)) { + if (shape.IsArray()) { (*shape_tree->mutable_element(*index))->WaitUntilAvailable(); return; } - CHECK(ShapeUtil::IsTuple(shape)) - << ShapeUtil::HumanStringWithLayout(shape); + CHECK(shape.IsTuple()) << ShapeUtil::HumanStringWithLayout(shape); const int64 tuple_element_count = ShapeUtil::TupleElementCount(shape); index->push_back(0); for (int64 i = 0; i < tuple_element_count; ++i) { @@ -158,7 +157,7 @@ Status GpuTransferManager::TransferLiteralFromOutfeed( std::unique_ptr* buffer) { const Shape& shape = ShapeUtil::GetSubshape(literal_shape, index); // Do not transfer tuple index buffers. 
- if (ShapeUtil::IsTuple(shape)) { + if (shape.IsTuple()) { return; } *buffer = absl::make_unique( diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc index 51627402b45f594dab3480129ba182d54d01b811..69aaaceca112364a4fd562f6a5eff1629fd3fc54 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h" +#include "absl/container/flat_hash_set.h" #include "absl/strings/str_cat.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Function.h" @@ -45,10 +46,10 @@ void HloToIrBindings::EmitBasePointersForHlos( // An HLO can have duplicated operands. This data structure remembers which // operand HLOs are already bound to avoid rebinding the same HLO. - std::set already_bound_for_this_function; + absl::flat_hash_set already_bound_for_this_function; auto arg_iter = function->arg_begin(); for (const HloInstruction* io_hlo : io_hlos) { - if (!already_bound_for_this_function.count(io_hlo)) { + if (!already_bound_for_this_function.contains(io_hlo)) { if (!is_nested_ && io_hlo->opcode() == HloOpcode::kGetTupleElement) { BindHloToIrValue(*io_hlo, EmitGetTupleElement(io_hlo, &*arg_iter)); } else { @@ -63,7 +64,7 @@ void HloToIrBindings::EmitBasePointersForHlos( temp_buffer_base_->setName("temp_buffer"); for (const HloInstruction* non_io_hlo : non_io_hlos) { - if (already_bound_for_this_function.count(non_io_hlo)) { + if (already_bound_for_this_function.contains(non_io_hlo)) { continue; } already_bound_for_this_function.insert(non_io_hlo); @@ -280,7 +281,7 @@ string HloToIrBindings::ToString() const { StrAppend(&s, " ", instr->ToString()); const ShapeTree& shape_tree = it->second; - if (!ShapeUtil::IsTuple(instr->shape())) { + if (!instr->shape().IsTuple()) { const llvm::Value* val = shape_tree.begin()->second; 
StrAppend(&s, " -> ", llvm_ir::DumpToString(*val), "\n"); continue; diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h index c0edae530cedba45c897b07b7b9cc72eaaab397c..f57b594e9c18078a3bbbf4d2b4db7e989c4edfdd 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "absl/container/flat_hash_map.h" #include "absl/types/span.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" @@ -61,7 +62,7 @@ class HloToIrBindings { // Returns whether `hlo` is bound to an LLVM IR value. bool BoundToIrValue(const HloInstruction& hlo) const { - return base_ptrs_.count(&hlo); + return base_ptrs_.contains(&hlo); } llvm::Value* GetTempBufferBase() const { return temp_buffer_base_; } @@ -110,7 +111,8 @@ class HloToIrBindings { // For an instruction that generates multiple outputs, the root will be a // tuple shape. The IrArray for each element output is stored in the subnode // in the ShapeTree. - std::unordered_map> base_ptrs_; + absl::flat_hash_map> + base_ptrs_; // The address of the memory block that contains all temporary buffers. llvm::Value* temp_buffer_base_ = nullptr; diff --git a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc index 8c3a026740851767855beae59d6a3c92f7a0d6bd..676380c3b10f9a20c641eea0d9a948a26becaddc 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc @@ -36,6 +36,21 @@ Status InfeedThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, ShapeTree infeed_buffers = GetOrCreateInfeedManager()->BlockingGetNextDestination(); + // infeed_slices_'s shape should be a tuple of shape (buffers, token). 
+ const auto& infeed_shape = infeed_slices_.shape(); + TF_RET_CHECK(infeed_shape.IsTuple()) + << ShapeUtil::HumanStringWithLayout(infeed_shape); + TF_RET_CHECK(infeed_shape.tuple_shapes().size() == 2) + << ShapeUtil::HumanStringWithLayout(infeed_shape); + TF_RET_CHECK(infeed_shape.tuple_shapes(1).IsToken()) + << ShapeUtil::HumanStringWithLayout(infeed_shape); + TF_RET_CHECK( + ShapeUtil::Equal(infeed_buffers.shape(), infeed_shape.tuple_shapes(0))) + << "Expected infeed of shape " + << ShapeUtil::HumanStringWithLayout(infeed_shape.tuple_shapes(0)) + << " but was " + << ShapeUtil::HumanStringWithLayout(infeed_buffers.shape()); + { // The infeed buffer has an extra outer tuple with a token. Adjust the index // accordingly. @@ -45,7 +60,7 @@ Status InfeedThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, const Shape& shape = ShapeUtil::GetSubshape(infeed_buffers.shape(), ShapeIndexView(index, 1)); // For the leaf buffers of the tuple copy the elements directly. - if (ShapeUtil::IsArray(shape)) { + if (shape.IsArray()) { const BufferAllocation::Slice& tuple_element_buffer = infeed_slices_.element(index); se::DeviceMemoryBase tuple_element_address = diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc index 6151dd8ff4c92bb81bd756c68cc9377633c8c9d5..f07141029cbf8b034b74548f6fca8f1628589f0c 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc @@ -282,22 +282,7 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, bool GpuInstructionFusion::ShouldFuseIntoMultiOutput(HloInstruction* consumer, int64 operand_index) { - const HloInstruction* producer = consumer->operand(operand_index); - // The IR emitter has limited support for non-loop fusions with multi output - // at present. - // TODO(tjoerg): Relax this constraint to allow for arbitraty kinds of fusion. 
- if (consumer->opcode() == HloOpcode::kFusion && - consumer->fusion_kind() != HloInstruction::FusionKind::kLoop) { - return false; - } - // Multi-output fusion requires instructions with compatible shapes. - if (!ShapeUtil::Compatible(producer->shape(), consumer->shape())) { - return false; - } - // TODO(tjoerg): Stop calling `ShouldFuse` to relax the criteria for - // multi-output fusion. In particular, do not check whether an instruction is - // expensive to duplicate, since this doesn't matter here. - return GpuInstructionFusion::ShouldFuse(consumer, operand_index); + return false; } HloInstruction::FusionKind GpuInstructionFusion::ChooseKind( diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc index 688604cd36e5a45debf855aacd29d05ecda92341..a05ab86cf77a134a1fc387d93cb482aa1ff5345b 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc @@ -506,202 +506,11 @@ TEST_F(InstructionFusionTest, MultiOutputFusion) { })") .ValueOrDie(); - ASSERT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) - .Run(module.get()) - .ValueOrDie()); - SCOPED_TRACE(module->ToString()); - - // Expect that there is one multi-output fusion and subtract has not been - // duplicated. 
- EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1); - EXPECT_EQ(Count(*module, HloOpcode::kSubtract), 1); - TF_ASSERT_OK_AND_ASSIGN( - const HloInstruction* fusion, - FindHloInstruction(*module->entry_computation(), HloOpcode::kFusion)); - EXPECT_THAT( - fusion->fused_expression_root(), - op::Tuple(op::Add(op::Subtract(), op::Parameter()), op::Subtract())); -} - -TEST_F(InstructionFusionTest, MultiOutputFusionExpensiveOp) { - // tanh --> add --> tuple - // \---------------/ - auto module = ParseHloString(R"( - HloModule test_module - ENTRY OutputFusion { - p0 = f32[4,3]{1,0} parameter(0) - p1 = f32[4,3]{1,0} parameter(1) - tanh = f32[4,3]{1,0} tanh(p0) - add = f32[4,3]{1,0} add(tanh, p1) - ROOT tuple = (f32[4,3]{1,0}, f32[4,3]{1,0}) tuple(tanh, add) - })") - .ValueOrDie(); - - // TODO(tjoerg): Allow multi-output fusion for expensive operations like tanh. + // Multi-output fusion is disabled here and performed in the + // GpuMultiOutputFusion pass instead. ASSERT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) .Run(module.get()) - .ValueOrDie()) - << module->ToString(); -} - -TEST_F(InstructionFusionTest, MultiOutputFusion2) { - // sub --> add1 --\--------\ - // \----------> add2 --> tuple - auto module = ParseHloString(R"( - HloModule test_module - ENTRY OutputFusion { - p0 = f32[4,3]{1,0} parameter(0) - p1 = f32[4,3]{1,0} parameter(1) - p2 = f32[4,3]{1,0} parameter(2) - sub = f32[4,3]{1,0} subtract(p0, p2) - add1 = f32[4,3]{1,0} add(sub, p1) - add2 = f32[4,3]{1,0} add(sub, add1) - ROOT tuple = (f32[4,3]{1,0}) tuple(add1, add2) - })") - .ValueOrDie(); - - ASSERT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) - .Run(module.get()) - .ValueOrDie()); - SCOPED_TRACE(module->ToString()); - - // Expect that there is one multi-output fusion and subtract has not been - // duplicated. 
- EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1); - EXPECT_EQ(Count(*module, HloOpcode::kSubtract), 1); - TF_ASSERT_OK_AND_ASSIGN( - const HloInstruction* fusion, - FindHloInstruction(*module->entry_computation(), HloOpcode::kFusion)); - EXPECT_THAT(fusion->fused_expression_root(), - op::Tuple(op::Add(op::Subtract(), op::Add()), - op::Add(op::Subtract(), op::Parameter()))); -} - -TEST_F(InstructionFusionTest, MultiOutputFusion3) { - // sub --> add1 ----\--------\ - // \ --> add2 --> add3 --> tuple - auto module = ParseHloString(R"( - HloModule test_module - ENTRY OutputFusion { - p0 = f32[4,3]{1,0} parameter(0) - p1 = f32[4,3]{1,0} parameter(1) - p2 = f32[4,3]{1,0} parameter(2) - p3 = f32[4,3]{1,0} parameter(3) - sub = f32[4,3]{1,0} subtract(p0, p2) - add1 = f32[4,3]{1,0} add(sub, p1) - add2 = f32[4,3]{1,0} add(p2, sub) - add3 = f32[4,3]{1,0} add(add1, add2) - ROOT tuple = (f32[4,3]{1,0}) tuple(add3, add2) - })") - .ValueOrDie(); - - ASSERT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) - .Run(module.get()) - .ValueOrDie()); - SCOPED_TRACE(module->ToString()); - - // Expect that there is one multi-output fusion and subtract has not been - // duplicated. 
- EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1); - EXPECT_EQ(Count(*module, HloOpcode::kSubtract), 1); - TF_ASSERT_OK_AND_ASSIGN( - const HloInstruction* fusion, - FindHloInstruction(*module->entry_computation(), HloOpcode::kFusion)); - EXPECT_THAT(fusion->fused_expression_root(), - op::Tuple(op::Add(op::Add(), op::Add()), - op::Add(op::Parameter(), op::Subtract()))); -} - -TEST_F(InstructionFusionTest, NoCyclesDueToMultiOutputFusion) { - // sub --> mul ---\ - // \--> call --> add --> tuple - auto module = ParseHloString(R"( - HloModule test_module - ENTRY OutputFusion { - c = f32[] constant(42) - p0 = f32[4,3]{1,0} parameter(0) - p1 = f32[4,3]{1,0} parameter(1) - sub = f32[4,3]{1,0} subtract(p0, p1) - mul = f32[4,3]{1,0} multiply(sub, c) - call = f32[4,3]{1,0} custom-call(sub), custom_call_target="foo" - add = f32[4,3]{1,0} add(mul, call) - ROOT tuple = (f32[4,3]{1,0}) tuple(add) - })") - .ValueOrDie(); - - ASSERT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) - .Run(module.get()) - .ValueOrDie()); - // Visit instructions in post order to detect cycles. - // TODO(tjoerg): Add cycle detection to the HloVerifier. - class DummyVisitor : public DfsHloVisitorWithDefault { - public: - DummyVisitor() {} - Status DefaultAction(HloInstruction* /*hlo_instruction*/) override { - return Status::OK(); - } - } visitor; - for (const HloComputation* computation : module->MakeComputationPostOrder()) { - // Accept will return a FailedPrecondition when a cycle is detected. 
- EXPECT_TRUE(computation->root_instruction()->Accept(&visitor).ok()); - } -} - -TEST_F(InstructionFusionTest, NoMultiOutputFusionWithIncompatibleShapes) { - // sub[2,3] --> add[4,3] --> tuple([2,3], [4,3]) - // \-------------------------/ - auto module = ParseHloString(R"( - HloModule test_module - ENTRY OutputFusion { - p0 = f32[2,3]{1,0} parameter(0) - p1 = f32[4,3]{1,0} parameter(1) - p2 = f32[2,3]{1,0} parameter(2) - sub = f32[2,3]{1,0} subtract(p0, p2) - add = f32[4,3]{1,0} add(sub, p1) - ROOT tuple = (f32[2,3]{1,0}, f32[4,3]{1,0}) tuple(sub, add) - })") - .ValueOrDie(); - - // Multi-output fusion requires shapes to be compatible. Since `sub` and `add` - // have incompatible shapes, expect that no multi-output fusion happens. - ASSERT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) - .Run(module.get()) - .ValueOrDie()) - << module->ToString(); -} - -TEST_F(InstructionFusionTest, FuseIntoInputFusionInstruction) { - auto module = ParseHloString(R"( - HloModule test_module - - add_computation { - add_lhs = f32[] parameter(0) - add_rhs = f32[] parameter(1) - ROOT add_root = f32[] add(add_lhs, add_rhs) - } - - fused_computation { - p1 = f32[10] parameter(0) - zero = f32[] constant(0) - ROOT f2_root = f32[] reduce(p1, zero), dimensions={0}, - to_apply=add_computation - } - - ENTRY entry { - p0 = f32[10] parameter(0) - mul = f32[10] multiply(p0, p0) - fusion = f32[] fusion(mul), kind=kInput, calls=fused_computation - ROOT tuple = (f32[10], f32[]) tuple(fusion, mul) - })") - .ValueOrDie(); - - // Multi-output fusion is not supported for non-loop fusions at present. Since - // `fused_computation` is a input fusion, expect no multi-output fusion to - // happen. 
- ASSERT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) - .Run(module.get()) - .ValueOrDie()) - << module->ToString(); + .ValueOrDie()); } TEST_F(InstructionFusionTest, FuseScalarConstant) { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index 42fb38dffae31b0f4322216545027e067cab250d..82bdd677d96d3d0826bb4127b32d074eb632b1a3 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -40,7 +40,7 @@ namespace { // Return whether the given shape is rank 2 excluding the batch dimensions. bool IsRank2(const Shape& shape, int64 batch_dimensions_size) { - return ShapeUtil::Rank(shape) == batch_dimensions_size + 2; + return shape.rank() == batch_dimensions_size + 2; } // In a gemm operation where output = lhs * rhs, check whether the given shapes @@ -54,7 +54,8 @@ bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape, PrimitiveType output_primitive_type = output_shape.element_type(); bool type_is_allowed = (output_primitive_type == F16 || output_primitive_type == F32 || - output_primitive_type == F64 || output_primitive_type == C64); + output_primitive_type == F64 || output_primitive_type == C64 || + output_primitive_type == C128); return type_is_allowed && IsRank2(lhs_shape, batch_dimensions_size) && IsRank2(rhs_shape, batch_dimensions_size) && IsRank2(output_shape, batch_dimensions_size) && @@ -154,20 +155,17 @@ bool IsReductionToVector(const HloInstruction& reduce) { const HloInstruction* input = reduce.operand(0); std::vector dims_to_keep; for (int64 dim = 0; dim < input->shape().dimensions().size(); ++dim) { - if (!std::count(reduce.dimensions().begin(), reduce.dimensions().end(), - dim)) { + if (!absl::c_linear_search(reduce.dimensions(), dim)) { dims_to_keep.push_back(dim); } } return LayoutUtil::AreDimensionsConsecutive(input->shape().layout(), dims_to_keep) && - 
ShapeUtil::Equal(reduce.shape(), ShapeUtil::FilterDimensions( - [&dims_to_keep](int64 dim) { - return std::count( - dims_to_keep.begin(), - dims_to_keep.end(), dim); - }, - input->shape())); + ShapeUtil::Equal( + reduce.shape(), + ShapeUtil::FilterDimensions( + [&](int64 dim) { return absl::c_count(dims_to_keep, dim); }, + input->shape())); } // This emits a device-side call to @@ -268,5 +266,17 @@ string CudnnConvKindToString(CudnnConvKind kind) { } } +llvm::Value* IsBlock0Thread0(llvm::IRBuilder<>* b) { + return b->CreateAnd( + b->CreateICmpEQ( + b->getInt32(0), + llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, b)), + b->CreateICmpEQ( + b->getInt32(0), + llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {}, b))); +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h index f373d4a8393a047aba599b0fae954e98a740161e..ebf4d926b7a280e10b09a2532caba7ad6ab3ceb2 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h @@ -155,6 +155,10 @@ llvm::Value* EmitPrintf(absl::string_view fmt, llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset, llvm::IRBuilder<>* builder); +// Emits code that determines whether the current thread is thread 0 within +// block 0 of the kernel. 
+llvm::Value* IsBlock0Thread0(llvm::IRBuilder<>* b); + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index 31591914cc553f0f5ecd81cb514faa1dc56ea041..0007a9a8a3369d8ac010640127e1561615a6d813 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -63,9 +63,6 @@ IrEmitter::IrEmitter(const HloModuleConfig& hlo_module_config, &ir_emitter_context->buffer_assignment(), &b_, module_, is_nested), hlo_module_config_(hlo_module_config) { - b_.setFastMathFlags(llvm_ir::GetFastMathFlags( - /*fast_math_enabled=*/hlo_module_config.debug_options() - .xla_gpu_enable_fast_math())); } Status IrEmitter::DefaultAction(HloInstruction* hlo) { @@ -433,7 +430,7 @@ Status IrEmitter::HandleTupleSelect(HloInstruction* tuple_select) { auto on_false = tuple_select->operand(2); TF_RET_CHECK(pred->shape().element_type() == PRED); TF_RET_CHECK(ShapeUtil::IsScalar(pred->shape())); - TF_RET_CHECK(ShapeUtil::IsTuple(tuple_select->shape())); + TF_RET_CHECK(tuple_select->shape().IsTuple()); llvm_ir::EmitTupleSelect(GetIrArray(*tuple_select, *tuple_select), GetIrArray(*pred, *tuple_select), GetBasePointer(*on_true), GetBasePointer(*on_false), @@ -640,9 +637,9 @@ Status IrEmitter::HandleFft(HloInstruction* fft) { return Unimplemented("Hit a case for fft that is not implemented on GPU."); } -Status IrEmitter::HandleCrossReplicaSum(HloInstruction* crs) { +Status IrEmitter::HandleAllReduce(HloInstruction* crs) { // TODO(b/33011107): Support cross replica sum on GPU. 
- return Unimplemented("CrossReplicaSum is not implemented on GPU."); + return Unimplemented("AllReduce is not implemented on GPU."); } Status IrEmitter::HandleParameter(HloInstruction* parameter) { @@ -651,7 +648,7 @@ Status IrEmitter::HandleParameter(HloInstruction* parameter) { Status IrEmitter::HandleReduce(HloInstruction* reduce) { // TODO(b/112040122): Support variadic reduce. - if (!ShapeUtil::IsArray(reduce->shape())) { + if (!reduce->shape().IsArray()) { return Unimplemented("Variadic reduce is not supported on GPU"); } auto arg = reduce->operand(0); @@ -786,7 +783,7 @@ StatusOr IrEmitter::ComputeNestedElement( std::vector IrEmitter::ConstructIrArrayForOutputs( const HloInstruction& hlo) { std::vector output_arrays; - if (ShapeUtil::IsTuple(hlo.shape())) { + if (hlo.shape().IsTuple()) { int64 num_outputs = ShapeUtil::TupleElementCount(hlo.shape()); output_arrays.reserve(num_outputs); for (int64 i = 0; i < num_outputs; ++i) { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h index 2da46c016935d0e927879bbfb0d05cfc4899d818..f380aee9d3c06a29b503c81c7bd3846dbccf6ce5 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h @@ -81,7 +81,7 @@ class IrEmitter : public DfsHloVisitorWithDefault, Status HandleDot(HloInstruction* dot) override; Status HandleConvolution(HloInstruction* convolution) override; Status HandleFft(HloInstruction* fft) override; - Status HandleCrossReplicaSum(HloInstruction* crs) override; + Status HandleAllReduce(HloInstruction* crs) override; Status HandleInfeed(HloInstruction* infeed) override; Status HandleOutfeed(HloInstruction* outfeed) override; Status HandleSend(HloInstruction* send) override; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index bbe1583c01167b3fbb50e066ad59a48e45f5e683..294a454931b5cfa368bf094c428a1e942f4556b8 
100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -15,6 +15,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -88,6 +89,9 @@ namespace xla { namespace gpu { using llvm_ir::KernelMappingScheme; +using EmitElementFunction = + std::function; namespace { @@ -292,13 +296,12 @@ llvm::Type* GetIndexTypeForKernel(const HloInstruction* hlo, int64 launch_size, auto shape_in_range = [&](const Shape& s) { bool in_range = true; - ShapeUtil::ForEachSubshape( - s, [&](const Shape& sub_shape, const ShapeIndex& /*index*/) { - if (ShapeUtil::IsArray(sub_shape) && - !IsInt32(ShapeUtil::ElementsIn(sub_shape))) { - in_range = false; - } - }); + ShapeUtil::ForEachSubshape(s, [&](const Shape& sub_shape, + const ShapeIndex& /*index*/) { + if (sub_shape.IsArray() && !IsInt32(ShapeUtil::ElementsIn(sub_shape))) { + in_range = false; + } + }); return in_range; }; @@ -542,8 +545,7 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { // HandleFusion specializes reduction from a multi-dimensional array to // a 1D array. The specialized version requires a initializer thunk that // initializes the output array to the initial value of the reduce. - if (root->opcode() == HloOpcode::kReduce && - ShapeUtil::IsTuple(root->shape())) { + if (root->opcode() == HloOpcode::kReduce && root->shape().IsTuple()) { // TODO(b/112040122): Support variadic reduce. return Unimplemented("Variadic reduce is not supported on GPU"); } @@ -634,7 +636,7 @@ Status IrEmitterUnnested::EmitExtraOutputsForReduce( Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) { // TODO(b/112040122): Support multi-output reduce. 
- if (!ShapeUtil::IsArray(reduce->shape())) { + if (!reduce->shape().IsArray()) { return Unimplemented("Multi-output reduce is not supported on GPU"); } if (IsReductionToVector(*reduce)) { @@ -698,8 +700,8 @@ Status IrEmitterUnnested::HandleSelectAndScatter( const auto* source = select_and_scatter->operand(1); const Window& window = select_and_scatter->window(); PrimitiveType operand_element_type = operand->shape().element_type(); - const int64 rank = ShapeUtil::Rank(operand->shape()); - CHECK_EQ(rank, ShapeUtil::Rank(source->shape())); + const int64 rank = operand->shape().rank(); + CHECK_EQ(rank, source->shape().rank()); CHECK_EQ(rank, window.dimensions_size()); TF_ASSIGN_OR_RETURN(std::unique_ptr initializer_thunk, @@ -1015,7 +1017,7 @@ Status IrEmitterUnnested::EmitScatter( int64 raw_window_multidim_idx = 0; std::vector input_window_multidim; std::vector input_window_bounds; - for (int64 i = 0, e = ShapeUtil::Rank(operand->shape()); i != e; ++i) { + for (int64 i = 0, e = operand->shape().rank(); i != e; ++i) { if (absl::c_binary_search(dim_numbers.inserted_window_dims(), i)) { input_window_bounds.push_back(1); // Trivial dimension. input_window_multidim.push_back(index.GetConstantWithIndexType(0)); @@ -1027,12 +1029,11 @@ Status IrEmitterUnnested::EmitScatter( ++raw_window_multidim_idx; } } - DCHECK_EQ(input_window_multidim.size(), ShapeUtil::Rank(operand->shape())); + DCHECK_EQ(input_window_multidim.size(), operand->shape().rank()); // Insert a 1 dimension at the end if index_vector_dim requests one. 
Shape scatter_indices_shape = scatter_indices->shape(); - if (dim_numbers.index_vector_dim() == - ShapeUtil::Rank(scatter_indices_shape)) { + if (dim_numbers.index_vector_dim() == scatter_indices_shape.rank()) { scatter_indices_shape.add_dimensions(1); scatter_indices_shape.mutable_layout()->add_minor_to_major( dim_numbers.index_vector_dim()); @@ -1295,11 +1296,11 @@ Status IrEmitterUnnested::HandleTupleSelect(HloInstruction* tuple_select) { return IrEmitter::HandleTupleSelect(tuple_select); } -Status IrEmitterUnnested::HandleCrossReplicaSum(HloInstruction* crs) { +Status IrEmitterUnnested::HandleAllReduce(HloInstruction* crs) { if (hlo_module_config_.replica_count() != 1) { // TODO(b/33011107): Support nontrivial cross replica sum on GPU. return Unimplemented( - "CrossReplicaSum with >1 replica is not implemented on GPU."); + "AllReduce with >1 replica is not implemented on GPU."); } // CRS with one operand and one replica is simply the identity function. @@ -1310,8 +1311,8 @@ Status IrEmitterUnnested::HandleCrossReplicaSum(HloInstruction* crs) { // HloModuleConfig::num_replicas changes between when the module is compiled // and when it's run. 
if (crs->operand_count() == 1) { - CHECK(ShapeUtil::IsArray(crs->operand(0)->shape())) - << "Operands to cross-replica-sum must be arrays: " << crs->ToString(); + CHECK(crs->operand(0)->shape().IsArray()) + << "Operands to all-reduce must be arrays: " << crs->ToString(); AddThunkToThunkSequence(absl::make_unique( /*source_address=*/GetAllocationSlice(*crs->operand(0)), /*destination_buffer=*/GetAllocationSlice(*crs), @@ -1509,10 +1510,10 @@ std::unique_ptr IrEmitterUnnested::BuildKernelThunk( return !allocation->is_constant(); }); - std::sort(non_constant_buffers.begin(), non_constant_buffers.end(), - [](const BufferAllocation* a, const BufferAllocation* b) { - return a->index() < b->index(); - }); + absl::c_sort(non_constant_buffers, + [](const BufferAllocation* a, const BufferAllocation* b) { + return a->index() < b->index(); + }); llvm::Function* kernel = BuildKernelPrototype(*inst, non_constant_buffers); @@ -2059,8 +2060,16 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk( GetIndexTypeForKernel(&hlo, launch_dimensions.launch_bound(), &b_)); } - // For multioutput fusion, we need to emit each operand and the root. + // Emit the tuple pointers in one thread. We could do this at any point in + // the kernel, but we do it at the beginning in the hopes of reducing register + // pressure, since we touch threadIdx.x and blockIdx.x at the beginning of the + // kernel *anyway*. std::vector output_arrays = ConstructIrArrayForOutputs(hlo); + KernelSupportLibrary{&b_}.If("emit_mof_tuple", IsBlock0Thread0(&b_), [&] { + llvm_ir::EmitTuple(GetIrArray(hlo, hlo), output_arrays, &b_, module_); + }); + + // For multioutput fusion, we need to emit each operand and the root. 
TF_RETURN_IF_ERROR( ParallelLoopEmitter(element_generator, output_arrays, launch_dimensions, &b_, unroll_factor) @@ -2069,17 +2078,39 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk( &hlo, launch_dimensions.launch_bound(), &b_))); b_.SetInsertPoint(b_.GetInsertBlock()->getTerminator()); - llvm_ir::EmitTuple(GetIrArray(hlo, hlo), output_arrays, &b_, module_); - return Status::OK(); } +namespace { + +// Returns true if the fusion contains any instruction that is likely +// translated to complex LLVM IR, such as loops, and prevent vectorization. +bool MayPreventVectorization(const HloInstruction& fusion_hlo) { + CHECK_EQ(fusion_hlo.opcode(), HloOpcode::kFusion); + return absl::c_any_of( + fusion_hlo.fused_instructions_computation()->instructions(), + [&](const HloInstruction* instr) { + switch (instr->opcode()) { + case HloOpcode::kReduce: + case HloOpcode::kReduceWindow: + case HloOpcode::kSort: + case HloOpcode::kDot: + return true; + default: + return false; + } + }); +} + +} // namespace + Status IrEmitterUnnested::EmitTargetElementLoop( const HloInstruction& hlo, const llvm_ir::ElementGenerator& element_generator) { int unroll_factor = 1; // Unfused elementwise operations are usually memory bound, unroll them. - if (hlo.IsElementwise() || hlo.opcode() == HloOpcode::kFusion) { + if (hlo.IsElementwise() || + (hlo.opcode() == HloOpcode::kFusion && !MayPreventVectorization(hlo))) { unroll_factor = ComputeMaxUnrollFactor(&hlo); } @@ -2130,84 +2161,88 @@ int IrEmitterUnnested::ConstructInputReducedShapeAndCastInputIrArrayToShape( namespace { -// Reads thread_idx.x and converts it to a (y,x) coordinate, assuming that the -// thread lives within a square tile of size tile_size (so thread blocks are of -// size tile_size * tile_size). 
-std::tuple CalculateYXCoordinateWithinTile( - llvm::IRBuilder<>* builder, llvm::Value* tile_size, - int64 threads_per_tile) { - // Calculate the starting element coordinate within a tile for the current - // thread, (y, x) from thread_id. - llvm::Value* thread_id = llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, builder); - llvm_ir::AddRangeMetadata(0, threads_per_tile, - llvm::cast(thread_id)); - thread_id = builder->CreateIntCast(thread_id, tile_size->getType(), - /*isSigned=*/true, "thread.id.x"); - auto x = builder->CreateURem(thread_id, tile_size); - auto y = builder->CreateUDiv(thread_id, tile_size); - return std::make_tuple(y, x); -} - -// Reads block_idx.x, casts it to type index_ty, and adds the assumption that -// it's in the range [0, num_blocks]. -llvm::Value* GetBlockIdx(llvm::IRBuilder<>* builder, llvm::Type* index_ty, - int64 num_blocks) { - llvm::Value* block_id = llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {}, builder); - llvm_ir::AddRangeMetadata(0, num_blocks, - llvm::cast(block_id)); - return builder->CreateIntCast(block_id, index_ty, /*isSigned=*/true, - "block.id.x"); -} - -void EmitFullTile(const KernelMappingScheme* mapping_scheme, - const IrArray::Index& tile_origin_index, - llvm::IRBuilder<>* builder, llvm::Value* y, llvm::Value* x, - llvm::Type* index_ty, - const std::function& emit_elem_function) { +std::tuple GetStartOffsetAndStepForX( + int64 tile_size_x, int64 num_threads_x, + const KernelMappingScheme* mapping_scheme, llvm::IRBuilder<>* builder, + llvm::Value* x, llvm::Type* index_ty) { + llvm::Value* start_offset_x; + int64 step_x; + if (mapping_scheme->DilatedX()) { + start_offset_x = x; + step_x = num_threads_x; + } else { + start_offset_x = builder->CreateMul( + x, llvm::ConstantInt::get(index_ty, tile_size_x / num_threads_x)); + step_x = 1; + } + return std::make_tuple(start_offset_x, step_x); +} + +void EmitFullElementalTile(const KernelMappingScheme* 
mapping_scheme, + const IrArray::Index& tile_origin_index, + const string& loop_name, KernelSupportLibrary* ksl, + llvm::IRBuilder<>* builder, llvm::Value* y, + llvm::Value* x, llvm::Type* index_ty, + const EmitElementFunction& emit_elem_function) { int64 num_threads_x = mapping_scheme->GetNumberOfThreadsForDimensionX(); int64 num_threads_y = mapping_scheme->GetNumberOfThreadsForDimensionY(); int64 tile_size_x = mapping_scheme->GetTileSizeForDimensionX(); int64 tile_size_y = mapping_scheme->GetTileSizeForDimensionY(); - for (int64 i = 0; i < tile_size_y; i += num_threads_y) { - IrArray::Index source_idx_y = - tile_origin_index.AddOffsetToDim(llvm::ConstantInt::get(index_ty, i), - KernelMappingScheme::DimY, builder); - llvm::Value* y_loc = - builder->CreateAdd(llvm::ConstantInt::get(index_ty, i), y); - for (int64 j = 0; j < tile_size_x; j += num_threads_x) { - IrArray::Index source_idx = - source_idx_y.AddOffsetToDim(llvm::ConstantInt::get(index_ty, j), - KernelMappingScheme::DimX, builder); - llvm::Value* x_loc = - builder->CreateAdd(llvm::ConstantInt::get(index_ty, j), x); - emit_elem_function(source_idx, y_loc, x_loc); - } - } -} -void EmitPartialTile( - const KernelMappingScheme* mapping_scheme, - const IrArray::Index& tile_origin_index, const string& loop_name, - KernelSupportLibrary* ksl, llvm::IRBuilder<>* builder, llvm::Value* y, - llvm::Value* x, llvm::Value* tile_height, llvm::Value* tile_width, - llvm::Type* index_ty, - const std::function& emit_elem_function) { + llvm::Value* start_offset_x; + int64 step_x; + std::tie(start_offset_x, step_x) = GetStartOffsetAndStepForX( + tile_size_x, num_threads_x, mapping_scheme, builder, x, index_ty); + IrArray::Index source_idx = + tile_origin_index.AddOffsetToDim(y, KernelMappingScheme::DimY, builder) + .AddOffsetToDim(start_offset_x, KernelMappingScheme::DimX, builder); + ksl->For(loop_name + "_y", /*start=*/llvm::ConstantInt::get(index_ty, 0), + /*end=*/llvm::ConstantInt::get(index_ty, tile_size_y), + 
/*step=*/llvm::ConstantInt::get(index_ty, num_threads_y), + [&](llvm::Value* y_indvar) { + IrArray::Index source_idx_y = source_idx.AddOffsetToDim( + y_indvar, KernelMappingScheme::DimY, builder); + llvm::Value* y_loc = builder->CreateAdd(y_indvar, y); + + for (int64 j = 0; j < tile_size_x / num_threads_x; j++) { + IrArray::Index source_idx_y_x = source_idx_y.AddOffsetToDim( + llvm::ConstantInt::get(index_ty, j * step_x), + KernelMappingScheme::DimX, builder); + llvm::Value* x_loc = builder->CreateAdd( + llvm::ConstantInt::get(index_ty, j * step_x), + start_offset_x); + emit_elem_function(source_idx_y_x, y_loc, x_loc, j); + } + }); +} + +void EmitPartialElementalTile(const KernelMappingScheme* mapping_scheme, + const IrArray::Index& tile_origin_index, + const string& loop_name, + KernelSupportLibrary* ksl, + llvm::IRBuilder<>* builder, llvm::Value* y, + llvm::Value* x, llvm::Value* tile_height, + llvm::Value* tile_width, llvm::Type* index_ty, + const EmitElementFunction& emit_elem_function) { int64 num_threads_x = mapping_scheme->GetNumberOfThreadsForDimensionX(); int64 num_threads_y = mapping_scheme->GetNumberOfThreadsForDimensionY(); int64 tile_size_x = mapping_scheme->GetTileSizeForDimensionX(); - for (int64 j = 0; j < tile_size_x; j += num_threads_x) { - IrArray::Index source_idx = - tile_origin_index.AddOffsetToDim(llvm::ConstantInt::get(index_ty, j), - KernelMappingScheme::DimX, builder); - llvm::Value* x_loc = - builder->CreateAdd(llvm::ConstantInt::get(index_ty, j), x); - - ksl->IfReturnVoid( + llvm::Value* start_offset_x; + int64 step_x; + std::tie(start_offset_x, step_x) = GetStartOffsetAndStepForX( + tile_size_x, num_threads_x, mapping_scheme, builder, x, index_ty); + IrArray::Index source_idx = + tile_origin_index.AddOffsetToDim(y, KernelMappingScheme::DimY, builder) + .AddOffsetToDim(start_offset_x, KernelMappingScheme::DimX, builder); + for (int64 j = 0; j < tile_size_x / num_threads_x; j++) { + IrArray::Index source_idx_x = + 
source_idx.AddOffsetToDim(llvm::ConstantInt::get(index_ty, j * step_x), + KernelMappingScheme::DimX, builder); + llvm::Value* x_loc = builder->CreateAdd( + llvm::ConstantInt::get(index_ty, j * step_x), start_offset_x); + + ksl->If( loop_name + "_x_in_tile", builder->CreateICmpULT(x_loc, tile_width), [&] { // tile_height_bound = @@ -2219,20 +2254,19 @@ void EmitPartialTile( llvm::Value* tile_height_bound = builder->CreateMul( ceiling_of_ratio, llvm::ConstantInt::get(index_ty, num_threads_y)); - ksl->ForReturnVoid( + ksl->For( loop_name, /*start=*/llvm::ConstantInt::get(index_ty, 0), /*end=*/tile_height_bound, /*step=*/llvm::ConstantInt::get(index_ty, num_threads_y), [&](llvm::Value* y_indvar) { llvm::Value* y_loc = builder->CreateAdd(y_indvar, y); - ksl->IfReturnVoid( - loop_name + "_y_in_tile", - builder->CreateICmpULT(y_loc, tile_height), [&] { - emit_elem_function( - source_idx.AddOffsetToDim( - y_indvar, KernelMappingScheme::DimY, builder), - y_loc, x_loc); - }); + ksl->If(loop_name + "_y_in_tile", + builder->CreateICmpULT(y_loc, tile_height), [&] { + emit_elem_function( + source_idx_x.AddOffsetToDim( + y_indvar, KernelMappingScheme::DimY, builder), + y_loc, x_loc, j); + }); }); }); } @@ -2251,13 +2285,12 @@ void EmitTiledElementalCodeWithBoundsCheck( const IrArray::Index& tile_origin_index, const string& loop_name, KernelSupportLibrary* ksl, llvm::IRBuilder<>* builder, llvm::Value* y, llvm::Value* x, llvm::Value* tile_height, llvm::Value* tile_width, - const std::function& emit_elem_function) { + const EmitElementFunction& emit_elem_function) { int64 tile_size_x = mapping_scheme->GetTileSizeForDimensionX(); int64 tile_size_y = mapping_scheme->GetTileSizeForDimensionY(); llvm::Type* index_ty = tile_width->getType(); - ksl->IfReturnVoid( + ksl->If( loop_name + "_full_tile", builder->CreateAnd( builder->CreateICmpEQ(llvm::ConstantInt::get(index_ty, tile_size_x), @@ -2265,13 +2298,13 @@ void EmitTiledElementalCodeWithBoundsCheck( 
builder->CreateICmpEQ(llvm::ConstantInt::get(index_ty, tile_size_y), tile_height)), [&] { - EmitFullTile(mapping_scheme, tile_origin_index, builder, y, x, index_ty, - emit_elem_function); + EmitFullElementalTile(mapping_scheme, tile_origin_index, loop_name, ksl, + builder, y, x, index_ty, emit_elem_function); }, [&] { - EmitPartialTile(mapping_scheme, tile_origin_index, loop_name, ksl, - builder, y, x, tile_height, tile_width, index_ty, - emit_elem_function); + EmitPartialElementalTile(mapping_scheme, tile_origin_index, loop_name, + ksl, builder, y, x, tile_height, tile_width, + index_ty, emit_elem_function); }); } } // namespace @@ -2288,7 +2321,7 @@ void EmitTiledElementalCodeWithBoundsCheck( void IrEmitterUnnested::EmitTileElementForCopy( HloInstruction* hlo, const llvm_ir::IrArray::Index& index, const KernelCodegenInfo* kernel_info, llvm::Value* y_loc, - llvm::Value* x_loc) { + llvm::Value* x_loc, int64 /*x_iter_num*/) { llvm_ir::TiledParameterInfo* tiled_param_info = kernel_info->GetTiledParameterInfo(); // TODO(jlebar): Add AA metadata to this load. 
@@ -2318,7 +2351,7 @@ void IrEmitterUnnested::EmitTileElementForCopy( void IrEmitterUnnested::EmitTileElementForFusion( HloInstruction* hlo, const llvm_ir::IrArray::Index& index, const KernelCodegenInfo* kernel_info, llvm::Value* y_loc, - llvm::Value* x_loc) { + llvm::Value* x_loc, int64 /*x_iter_num*/) { llvm_ir::TiledParameterInfo* tiled_param_info = kernel_info->GetTiledParameterInfo(); std::vector output_arrays = ConstructIrArrayForOutputs(*hlo); @@ -2381,14 +2414,14 @@ class ReductionCodegenInfo : public IrEmitterUnnested::KernelCodegenInfo { AddressVector* GetMutablePartialResultAddresses() { return &partial_result_addresses_; } - const AddressVector& GetPartialResultAddresses() const { + absl::Span GetPartialResultAddresses() const { return partial_result_addresses_; } AddressVector* GetMutableReductionInputAddresses() { return &reduction_input_addresses_; } - const AddressVector& GetReductionInputAddresses() const { + absl::Span GetReductionInputAddresses() const { return reduction_input_addresses_; } @@ -2401,7 +2434,7 @@ class ReductionCodegenInfo : public IrEmitterUnnested::KernelCodegenInfo { InlinedVector* GetMutableReductionOutputShapeIndices() { return &reduction_output_shape_indices_; } - const InlinedVector& GetReductionOutputShapeIndices() const { + absl::Span GetReductionOutputShapeIndices() const { return reduction_output_shape_indices_; } @@ -2419,6 +2452,23 @@ class ReductionCodegenInfo : public IrEmitterUnnested::KernelCodegenInfo { : llvm_ir::KernelMappingScheme::DimX; } + int GetNumberOfPartialResults() const { + if (IsRowReduction()) { + return 1; + } + int64 num_thread = mapping_scheme_->GetNumberOfThreadsForDimensionX(); + int64 tile_size = mapping_scheme_->GetTileSizeForDimensionX(); + CHECK_EQ(tile_size % num_thread, 0); + return tile_size / num_thread; + } + + int GetPartialResultIndex(int64 x_iter_num) const { + if (IsRowReduction()) { + return 0; + } + return x_iter_num; + } + private: AddressVector partial_result_addresses_; 
AddressVector reduction_input_addresses_; @@ -2478,10 +2528,11 @@ void IrEmitterUnnested::EmitPrologueForOneReduction( llvm::AllocaInst* reduction_input_address = Alloca(element_type); reduction_input_addresses->push_back(reduction_input_address); + int num_partial_results = reduction_info->GetNumberOfPartialResults(); AddressVector* partial_result_addresses = reduction_info->GetMutablePartialResultAddresses(); llvm::AllocaInst* partial_result_address = - Alloca(element_type, /*ArraySize=*/nullptr, + Alloca(element_type, /*ArraySize=*/b_.getInt32(num_partial_results), "partial_reduction_result." + llvm::Twine(reduce_idx)); partial_result_addresses->push_back(partial_result_address); @@ -2504,7 +2555,9 @@ void IrEmitterUnnested::EmitPrologueForOneReduction( .EmitReadArrayElement(IrArray::Index(b_.getInt32Ty()), &b_); } - Store(init_ir_value, partial_result_address); + for (int i = 0; i < num_partial_results; ++i) { + Store(init_ir_value, InBoundsGEP(partial_result_address, {b_.getInt32(i)})); + } } void IrEmitterUnnested::EmitPrologueForReduction( @@ -2542,10 +2595,14 @@ void IrEmitterUnnested::EmitPrologueForReduction( std::move(output_shape_index)); } - // Allocate stack storage to store the current output linear index and record - // the address of the storage. + int num_partial_results = reduction_info->GetNumberOfPartialResults(); + + // Allocate stack storage to store the linear indices for the current output, + // and record the address of the storage. 
reduction_info->SetCurrentOutputLinearIndexAddress( - Alloca(reduction_info->GetIndexType())); + Alloca(reduction_info->GetIndexType(), + /*ArraySize=*/b_.getInt32(num_partial_results), + "current_output_linear_index_address")); if (!reduction_info->IsRowReduction()) { llvm::Type* bool_ty = b_.getInt1Ty(); @@ -2556,8 +2613,8 @@ void IrEmitterUnnested::EmitPrologueForReduction( } void IrEmitterUnnested::EmitFullWarpShuffleDownLoopForAllReduces( - const InlinedVector& reducers, - const AddressVector& partial_result_addresses) { + absl::Span reducers, + absl::Span partial_result_addresses) { for (int distance = 16; distance >= 1; distance /= 2) { for (int i = 0; i != reducers.size(); ++i) { llvm::Type* element_type = @@ -2589,11 +2646,11 @@ void IrEmitterUnnested::EmitEpilogueForReduction( ReductionCodegenInfo* reduction_info = static_cast(kernel_info); int num_reduces = reduction_info->GetNumberOfReduces(); - const AddressVector& partial_result_addresses = + absl::Span partial_result_addresses = reduction_info->GetPartialResultAddresses(); const InlinedVector& reducers = reduction_info->GetReducers(); - const InlinedVector& reduction_output_shape_indices = + absl::Span reduction_output_shape_indices = reduction_info->GetReductionOutputShapeIndices(); if (reduction_info->IsRowReduction()) { @@ -2615,36 +2672,45 @@ void IrEmitterUnnested::EmitEpilogueForReduction( llvm_ir::SetToFirstInsertPoint(if_output_inbound_data.true_block, &b_); } + int num_partial_results = reduction_info->GetNumberOfPartialResults(); + // Emit an atomic operation that accumulates the partial reduction to the // output element. For row reduction, this is only for lane 0 due to the // if-statement emitted above. 
for (int i = 0; i != num_reduces; ++i) { - IrArray::Index element_index( - /*linear=*/Load(reduction_info->GetCurrentOutputLinearIndexAddress(), - "output_linear_addr"), - ShapeUtil::GetSubshape(unnested_hlo->shape(), - reduction_output_shape_indices[i]), - &b_); - llvm::Value* output_address = - GetIrArray(*unnested_hlo, *unnested_hlo, - reduction_output_shape_indices[i]) - .EmitArrayElementAddress(element_index, &b_, - "output_element_address"); - // Do not emit atomic operations if each element in the reduction result is - // computed by one block, that is the dimension being reduced has only one - // block. - const llvm_ir::KernelMappingScheme* mapping_scheme = - reduction_info->GetKernelMappingScheme(); - if (mapping_scheme->GetTileBlockSizeForDimension( - llvm_ir::KernelMappingScheme::DimZ) == 1 && - mapping_scheme->GetTileBlockSizeForDimension( - reduction_info->GetReducedDimensionEnum()) == 1) { - TF_CHECK_OK(EmitCallToNestedComputation( - *reducers[i], {output_address, partial_result_addresses[i]}, - output_address)); - } else { - TF_CHECK_OK(EmitAtomicOperationForNestedComputation( - *reducers[i], output_address, partial_result_addresses[i])); + for (int j = 0; j < num_partial_results; ++j) { + IrArray::Index element_index( + /*linear=*/Load( + InBoundsGEP(reduction_info->GetCurrentOutputLinearIndexAddress(), + {b_.getInt32(j)}), + "output_linear_addr"), + ShapeUtil::GetSubshape(unnested_hlo->shape(), + reduction_output_shape_indices[i]), + &b_); + llvm::Value* output_address = + GetIrArray(*unnested_hlo, *unnested_hlo, + reduction_output_shape_indices[i]) + .EmitArrayElementAddress(element_index, &b_, + "output_element_address"); + // Do not emit atomic operations if each element in the reduction result + // is computed by one block, that is the dimension being reduced has only + // one block. 
+ const llvm_ir::KernelMappingScheme* mapping_scheme = + reduction_info->GetKernelMappingScheme(); + if (mapping_scheme->GetTileBlockSizeForDimension( + llvm_ir::KernelMappingScheme::DimZ) == 1 && + mapping_scheme->GetTileBlockSizeForDimension( + reduction_info->GetReducedDimensionEnum()) == 1) { + TF_CHECK_OK(EmitCallToNestedComputation( + *reducers[i], + {output_address, + InBoundsGEP(partial_result_addresses[i], {b_.getInt32(j)})}, + output_address)); + } else { + TF_CHECK_OK(EmitAtomicOperationForNestedComputation( + *reducers[i], output_address, + InBoundsGEP(partial_result_addresses[i], {b_.getInt32(j)}))); + } } } } @@ -2652,7 +2718,7 @@ void IrEmitterUnnested::EmitEpilogueForReduction( void IrEmitterUnnested::EmitTileElementForReduction( HloInstruction* unnested_hlo, const llvm_ir::IrArray::Index& index, const KernelCodegenInfo* kernel_info, llvm::Value* y_loc, - llvm::Value* x_loc) { + llvm::Value* x_loc, int64 x_iter_num) { VLOG(10) << "Emit tile element for reduce " << unnested_hlo->ToString(); HloInstruction* reduce_or_tuple = unnested_hlo->opcode() == HloOpcode::kFusion ? unnested_hlo->fused_expression_root() @@ -2665,8 +2731,11 @@ void IrEmitterUnnested::EmitTileElementForReduction( // Record the linear address for the current reduction. const ReductionCodegenInfo* reduction_info = dynamic_cast(kernel_info); + int partial_result_index = reduction_info->IsRowReduction() ? 
0 : x_iter_num; + Store(index[reduction_info->GetKeptDimensionEnum()], - reduction_info->GetCurrentOutputLinearIndexAddress()); + InBoundsGEP(reduction_info->GetCurrentOutputLinearIndexAddress(), + {b_.getInt32(partial_result_index)})); if (!reduction_info->IsRowReduction()) { llvm::Type* bool_ty = b_.getInt1Ty(); llvm::AllocaInst* output_inbound_addr = @@ -2713,9 +2782,16 @@ void IrEmitterUnnested::EmitTileElementForReduction( reduction_info->GetKernelMappingScheme()->GetUnnormalizedIndex( index, GetFirstReduceInstruction(output_instructions)->operand(0)->shape()); - const AddressVector& partial_reduction_result_addresses = + int num_partial_results = reduction_info->GetNumberOfPartialResults(); + if (num_partial_results > 1) { + // Clear the linear index field of the IrArray::Index to enable the use of + // GetElementPointer with array types. This enables the vectorization of + // the computation for different partial results. + input_index.ClearLinearIndex(); + } + absl::Span partial_reduction_result_addresses = reduction_info->GetPartialResultAddresses(); - const AddressVector& reduction_input_addresses = + absl::Span reduction_input_addresses = reduction_info->GetReductionInputAddresses(); const InlinedVector& reducers = reduction_info->GetReducers(); @@ -2725,10 +2801,12 @@ void IrEmitterUnnested::EmitTileElementForReduction( for (int i = 0; i != reducers.size(); ++i) { llvm::Value* const input_ir_value = input_gens[i](input_index).ValueOrDie(); Store(input_ir_value, reduction_input_addresses[i]); + llvm::Value* partial_result_address = + InBoundsGEP(partial_reduction_result_addresses[i], + {b_.getInt32(partial_result_index)}); TF_CHECK_OK(EmitCallToNestedComputation( - *reducers[i], - {partial_reduction_result_addresses[i], reduction_input_addresses[i]}, - partial_reduction_result_addresses[i])); + *reducers[i], {partial_result_address, reduction_input_addresses[i]}, + partial_result_address)); } // Emit code to generate the output for the non-reduction 
instructions in the @@ -2739,8 +2817,8 @@ void IrEmitterUnnested::EmitTileElementForReduction( // Emits a kernel for the hlo instruction using the given tiling scheme. void IrEmitterUnnested::EmitBlock(const TileGenerator& emit_one_tile, - const KernelCodegenInfo* kernel_info, - KernelSupportLibrary& ksl, + KernelCodegenInfo* kernel_info, + KernelSupportLibrary* ksl, llvm::Type* index_ty) { KernelMappingScheme* mapping_scheme = kernel_info->GetKernelMappingScheme(); absl::Span dims_in_tile = mapping_scheme->GetDimensionsInTiles(); @@ -2773,16 +2851,14 @@ void IrEmitterUnnested::EmitBlock(const TileGenerator& emit_one_tile, llvm::Value* num_tiles_in_block = Select(ICmpEQ(last_block_for_dim, block_id_for_dim), last_block_size_for_dim, block_size_for_dim); - - ksl.ForReturnVoid( - loop_name, - /*start=*/index_typed_constant(0), - /*end=*/num_tiles_in_block, - /*step=*/1, [&](llvm::Value* block_dim_induction_var) { - IrArray::Index tile_index = starting_tile.AddOffsetToDim( - block_dim_induction_var, dim_id, &b_); - emit_next_block_dim(tile_index); - }); + ksl->For(loop_name, + /*start=*/index_typed_constant(0), + /*end=*/num_tiles_in_block, + /*step=*/1, [&](llvm::Value* block_dim_induction_var) { + IrArray::Index tile_index = starting_tile.AddOffsetToDim( + block_dim_induction_var, dim_id, &b_); + emit_next_block_dim(tile_index); + }); } }; @@ -2837,7 +2913,8 @@ void IrEmitterUnnested::EmitBlock(const TileGenerator& emit_one_tile, // unnested_hlo: The unnested hlo instruction for which the kernel is generated. // Currently, these hlo instructions are supported: kLoop fusion, kCopy. // tiled_param_ids: The IDs for the parameters that are 0-2-1 transpose of -// other tensors with the same dimensions and need to be tiled and tranposed. +// other tensors with the same dimensions and are safe to be tranposed via +// the shared memory tranpose implementation. // mapping_scheme: The tiling scheme to use. 
// kernel_generator: Contains function objects for code generation, such as // element generator, block prologue and epilogue generators. @@ -2864,14 +2941,40 @@ LaunchDimensions IrEmitterUnnested::EmitKernel( << llvm_ir::DumpToString(*param_shmem_buffers[id]); } - LaunchDimensions launch_dimensions = LaunchDimensions( - mapping_scheme->GetNumberOfBlocks(), mapping_scheme->GetThreadsPerTile()); - llvm::Type* index_ty = GetIndexTypeForKernel( - unnested_hlo, launch_dimensions.launch_bound(), &b_); + const ReductionCodegenInfo* reduction_info = + dynamic_cast(kernel_info); + bool is_column_reduction = + (reduction_info && !reduction_info->IsRowReduction()); + + LaunchDimensions launch_dimensions = + LaunchDimensions(mapping_scheme->GetNumberOfBlocks(), + mapping_scheme->GetThreadsPerBlock()); + + // TODO(b/110211620): Enable int32 index type for column reduction. + llvm::Type* index_ty = + is_column_reduction + ? b_.getInt64Ty() + : GetIndexTypeForKernel(unnested_hlo, + launch_dimensions.launch_bound(), &b_); + auto index_typed_constant = [&](uint64 c) -> llvm::Constant* { return llvm::ConstantInt::get(index_ty, c); }; + // For multioutput fusion, one thread needs to output a tuple with pointers to + // all the individual outputs. We could do this at any point in the kernel, + // but we do it at the beginning in the hopes of reducing register pressure, + // since we touch threadIdx.x and blockIdx.x at the beginning of the kernel + // *anyway*. + if (!reduction_info && unnested_hlo->IsMultiOutputFusion()) { + KernelSupportLibrary{&b_}.If( + "emit_mof_tuple", IsBlock0Thread0(&b_), [&] { + llvm_ir::EmitTuple(GetIrArray(*unnested_hlo, *unnested_hlo), + ConstructIrArrayForOutputs(*unnested_hlo), &b_, + module_); + }); + } + // For each tiled parameter, cast its input IrArray to the corresponding // reduced shape and keep the reduced shape live during IR emission. 
std::vector param_in_reduced_shape_arrays; @@ -2899,8 +3002,7 @@ LaunchDimensions IrEmitterUnnested::EmitKernel( auto emit_tiled_elemental_code_with_bounds_check = [&](const IrArray::Index& index, const string& loop_name, llvm::Value* tile_height, llvm::Value* tile_width, - const std::function& emit_elem_function) { + const EmitElementFunction& emit_elem_function) { EmitTiledElementalCodeWithBoundsCheck(mapping_scheme, index, loop_name, &ksl, &b_, y, x, tile_height, tile_width, emit_elem_function); @@ -2913,10 +3015,6 @@ LaunchDimensions IrEmitterUnnested::EmitKernel( const IrArray::Index input_tile_origin( Permute({0, 2, 1}, output_tile_origin.multidim())); - const IrArray::Index input_index = - input_tile_origin.AddOffsetToDim(x, KernelMappingScheme::DimX, &b_) - .AddOffsetToDim(y, KernelMappingScheme::DimY, &b_); - // If shared memory transpose is needed, wait for all threads to reach this // point, lest we copy a value from tile to output before the other thread // copies it from input to tile. This is `__syncthreads` in CUDA. @@ -2926,9 +3024,10 @@ LaunchDimensions IrEmitterUnnested::EmitKernel( // Note that tile_width and tile_height are flipped here because we are // reading a transposed tile. 
emit_tiled_elemental_code_with_bounds_check( - input_index, "input", output_tile_bounds[2], output_tile_bounds[1], + input_tile_origin, "input", output_tile_bounds[2], + output_tile_bounds[1], [&](const IrArray::Index& index, llvm::Value* y_loc, - llvm::Value* x_loc) { + llvm::Value* x_loc, int64 /*x_iter_num*/) { for (int64 id : tiled_param_ids) { IrArray& input_in_logical_shape = param_in_reduced_shape_arrays[id]; @@ -2948,18 +3047,15 @@ LaunchDimensions IrEmitterUnnested::EmitKernel( llvm_ir::TiledParameterInfo tiled_param_info(param_shmem_buffers, y, x); kernel_info->SetTiledParamInfo(&tiled_param_info); - const IrArray::Index output_index = - output_tile_origin.AddOffsetToDim(x, KernelMappingScheme::DimX, &b_) - .AddOffsetToDim(y, KernelMappingScheme::DimY, &b_); - // Write to output[index] by emitting code like normal, except that values // for the tiled parameters are read from the shmem buffers. emit_tiled_elemental_code_with_bounds_check( - output_index, "output", output_tile_bounds[1], output_tile_bounds[2], - [&](const IrArray::Index& index, llvm::Value* y_loc, - llvm::Value* x_loc) { - kernel_generator.GetTileElementGenerator()(unnested_hlo, index, - kernel_info, y_loc, x_loc); + output_tile_origin, "output", output_tile_bounds[1], + output_tile_bounds[2], + [&](const IrArray::Index& index, llvm::Value* y_loc, llvm::Value* x_loc, + int64 x_iter_num) { + kernel_generator.GetTileElementGenerator()( + unnested_hlo, index, kernel_info, y_loc, x_loc, x_iter_num); }); // If a tile block contains multiple tiles and shared memory buffers are @@ -2977,7 +3073,7 @@ LaunchDimensions IrEmitterUnnested::EmitKernel( block_prologue_generator(unnested_hlo, kernel_info); } - EmitBlock(std::move(emit_one_tile), kernel_info, ksl, index_ty); + EmitBlock(std::move(emit_one_tile), kernel_info, &ksl, index_ty); const BlockEpilogueGenerator& block_epilogue_generator = kernel_generator.GetBlockEpilogueGenerator(); @@ -2985,21 +3081,15 @@ LaunchDimensions 
IrEmitterUnnested::EmitKernel( block_epilogue_generator(unnested_hlo, kernel_info); } - // For multioutput fusion, emit a tuple with pointers to all the individual - // outputs. - if (unnested_hlo->IsMultiOutputFusion()) { - std::vector output_arrays = - ConstructIrArrayForOutputs(*unnested_hlo); - llvm_ir::EmitTuple(GetIrArray(*unnested_hlo, *unnested_hlo), output_arrays, - &b_, module_); - } - return launch_dimensions; } // Emits a kernel for the given hlo instruction using a tiled 0-2-1 transpose // algorithm to improve the memory access patterns for the input parameters -// with a shape that is a 0-2-1 transpose of the output tensor shape. +// with a shape that is a 0-2-1 transpose of the output tensor shape. The caller +// is responsible for making sure that it is safe to apply the shared memory +// tranpose on the input parameters. +// // // For the purpose of tiling, the output tensors have a logical shape of three // components 0-2-1 while the relevant input parameters have a logical shape @@ -3032,17 +3122,19 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( element_generator = [&](HloInstruction* hlo, const llvm_ir::IrArray::Index& index, const KernelCodegenInfo* kernel_info, - llvm::Value* y_loc, llvm::Value* x_loc) { - EmitTileElementForCopy(hlo, index, kernel_info, y_loc, x_loc); + llvm::Value* y_loc, llvm::Value* x_loc, + int64 x_iter_num) { + EmitTileElementForCopy(hlo, index, kernel_info, y_loc, x_loc, x_iter_num); }; } else { DCHECK_EQ(hlo->opcode(), HloOpcode::kFusion); - element_generator = [&](HloInstruction* hlo, - const llvm_ir::IrArray::Index& index, - const KernelCodegenInfo* kernel_info, - llvm::Value* y_loc, llvm::Value* x_loc) { - EmitTileElementForFusion(hlo, index, kernel_info, y_loc, x_loc); - }; + element_generator = + [&](HloInstruction* hlo, const llvm_ir::IrArray::Index& index, + const KernelCodegenInfo* kernel_info, llvm::Value* y_loc, + llvm::Value* x_loc, int64 x_iter_num) { + EmitTileElementForFusion(hlo, index, 
kernel_info, y_loc, x_loc, + x_iter_num); + }; } KernelCodegenInfo kernel_info(&mapping_scheme); KernelCodeGenerator kernel_generator(std::move(element_generator)); @@ -3050,26 +3142,99 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( } namespace { -// Returns true to indicate it is safe to use the tile based shared memory -// transpose implementation to implement the kernel for the instruction. +// A recursive function to inspect the users of a parameter to determine +// whether it's safe for a parameter to participate in a shared-memory +// transpose. // -// An instruction is not safe for such an implementation if it can change the -// element order of a tensor without changing the dimension of the tensor, and -// the instruction has a corresponding elemental_ir_emitter. -bool IsInstructionSafeForTileBasedTranspose(const HloInstruction* hlo) { - auto is_safe_for_tile_based_transpose = [&](const HloInstruction* instr) { - HloOpcode opcode = instr->opcode(); - CHECK_NE(opcode, HloOpcode::kFusion); - return (opcode != HloOpcode::kReverse && opcode != HloOpcode::kGather); - }; +// Consider a fusion parameter P for which we might want to use a shmem +// transpose. If we do, we use a GPU thread block to preload a tile of P with +// indices [z, y..y+31, x..x+31] to compute an output tile with the same indices +// cooperatively, where z, y, x are the indices for the normalized input/output +// tensor (see the document for FindTranspose021 for the definition of +// normalized tensor for 0-2-1 transpose). This shmem transpose implementation +// requires that the computation of the output tile only read elements within +// the preload tile. If this is not true, we can't use a shmem transpose for P. +// +// If the computation of output element [z, y, x] only requires the element of +// P with the same indices, the shmem tranpose implementation can be applied +// to P safely. This is a sufficient but not necessary condition. 
We check all +// the transitive users of P to see if we can find a user that may cause an +// exception to the situation. If such a user is not found, we conclude that P +// is safe for shmem transpose. +// +// This is trivially true for elementwise operations and some "data-movement" +// ops like kTuple. However, it's not true for operations that can change the +// dimensions of the inputs (e.g. pad, slice) and bitcast operation. +// For example: +// +// fused_computation { +// param_0 = f32[64,64]{1,0} parameter(0) +// ROOT bitcast = f32[64,64]{0,1} bitcast(param_0) +// } +// The output element at logical address [0, 63] depends on the input element +// at logical address [63, 0], which would not be within the shared-memory +// block. +// +// TODO(bixia): In order to extend this for kInput fusion, that is reduction +// with tranpose, we only need to end the use-chain checking with the input of +// a reduce operations. In this case, the above description on "output" apply +// to the result of such a use-chain, which provides the input to the reduce +// operation. +bool IsInstructionSafeForShmemTranspose(const HloInstruction* hlo) { + if (hlo->IsElementwise()) { + return absl::c_all_of(hlo->users(), [&](const HloInstruction* user) { + return IsInstructionSafeForShmemTranspose(user); + }); + } + + switch (hlo->opcode()) { + // Non-elementwise instructions that don't cause the shmem transpose + // to be unsafe, including the instructions that don't currently fuse. + case HloOpcode::kGetDimensionSize: + // The result of the operation doesn't rely on the content of the + // tensor. As such, there is no need to further inspect its users. 
+ return true; + case HloOpcode::kGetTupleElement: + case HloOpcode::kMap: + case HloOpcode::kParameter: + case HloOpcode::kTuple: + case HloOpcode::kTupleSelect: + return absl::c_all_of(hlo->users(), [&](const HloInstruction* user) { + return IsInstructionSafeForShmemTranspose(user); + }); - if (hlo->opcode() == HloOpcode::kFusion) { - return absl::c_all_of(hlo->fused_instructions_computation()->instructions(), - is_safe_for_tile_based_transpose); + default: + return false; } +} - return is_safe_for_tile_based_transpose(hlo); +// Given a group of input parameters that are 0-2-1 tranpose of the outputs of +// a fusion kernel, returns the input parameters that are safe for the shared +// memory tranpose implementation. +// +// When a tile based shared memory transpose is used to implement an input with +// 0-2-1 transpose, we preload a tile of the input elements +// [z, y..y+31, x..x+31] to compute the output tile elements of the same +// indices. Preloading the input tile this way is only safe when the computation +// of the output tile elements do not need any input element outside the +// preloaded tile. We inspect all the transitive users of the input parameter +// up to the fusion root instruction to see if we can find any instruction +// that can make preloading the input tile unsafe. 
+std::vector FilterInputsForShmemTranspose(const HloInstruction* fusion, + std::vector input_ids) { + std::vector filtered_input_ids; + for (int64 i = 0; i < input_ids.size(); ++i) { + const HloInstruction* input = fusion->fused_parameter(input_ids[i]); + if (IsInstructionSafeForShmemTranspose(input)) { + filtered_input_ids.push_back(input_ids[i]); + } else { + VLOG(10) << "Input not safe for shmem transpose " << input->ToString() + << "\n"; + } + } + return filtered_input_ids; } + } // namespace bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) { @@ -3116,8 +3281,11 @@ bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) { return false; } - if (!IsInstructionSafeForTileBasedTranspose(hlo)) { - return false; + if (opcode == HloOpcode::kFusion) { + params_012 = FilterInputsForShmemTranspose(hlo, params_012); + if (params_012.empty()) { + return false; + } } // Each of our shared memory tiles has 32*33 elements (so ~4kb, if the @@ -3198,7 +3366,7 @@ Status AreFusedReductionOutputsConsistent( // dimensions from minor to major. 
DimensionVector GetDimensionsToKeepMinorToMajor( const Shape& input_shape, absl::Span dims_to_reduce) { - DimensionVector input_dims(ShapeUtil::Rank(input_shape), 0); + DimensionVector input_dims(input_shape.rank(), 0); absl::c_iota(input_dims, 0); DimensionVector input_dims_to_keep; for (int input_dim : input_dims) { @@ -3238,7 +3406,7 @@ std::tuple GetReductionToVectorDimensions( if (input_dims_to_keep_minor_to_major.empty()) { return std::make_tuple(num_reduced_major, num_kept, num_reduced_minor); } - DimensionVector input_dims(ShapeUtil::Rank(input_shape), 0); + DimensionVector input_dims(input_shape.rank(), 0); absl::c_iota(input_dims, 0); absl::Span minor_to_major = LayoutUtil::MinorToMajor(input_shape); @@ -3260,15 +3428,107 @@ std::tuple GetReductionToVectorDimensions( return std::make_tuple(num_reduced_major, num_kept, num_reduced_minor); } -std::tuple ComputeMappingSchemeAndReductionKind( - const HloInstruction* first_reduce, llvm::IRBuilder<>* b) { +// Returns true if all the transitive users of hlo before hitting users in +// use_chain_endings are elementwise operations. +bool AreUsersElementwise(const HloInstruction* hlo, + const ConstHloInstructionSet& use_chain_endings) { + return absl::c_all_of(hlo->users(), [&](const HloInstruction* user) { + return use_chain_endings.count(user) || + (user->IsElementwise() && + AreUsersElementwise(user, use_chain_endings)); + }); +} + +// Returns the number of fusion inputs that have the same dimension as the +// given shape, and involve in only elementwise operations. 
+int64 NumInputsInvolveInOnlyElementwiseOps( + const HloInstruction* unnested_hlo, const Shape& op_shape, + const ConstHloInstructionSet& use_chain_endings) { + return absl::c_count_if( + unnested_hlo->fused_parameters(), [&](const HloInstruction* parameter) { + const Shape& parameter_shape = parameter->shape(); + return ShapeUtil::SameDimensions(op_shape, parameter_shape) && + AreUsersElementwise(parameter, use_chain_endings); + }); +} + +// Returns the number of fusion inputs that have more elements than the given +// shape. +int64 NumInputsWithMoreElementsThan(const HloInstruction* unnested_hlo, + const Shape& shape) { + int64 num_elements = ShapeUtil::ElementsIn(shape); + return absl::c_count_if( + unnested_hlo->fused_parameters(), [&](const HloInstruction* parameter) { + return ShapeUtil::ElementsIn(parameter->shape()) > num_elements; + }); +} + +// The benefit of unrolling a kInput fusion that is a column reduction comes +// from the vectorization of non-reduction fusion outputs and fusion inputs. +// On the other hand, unrolling can also introduce factors that can cause +// the kernel to run slower. This routine uses a simple heuristic to estimate +// the benefit as well as the overhead of unrolling in order to decide whether +// unrolling is beneficial for the given kInput fusion. +bool IsUnrollingColumnReductionBeneficial(const HloInstruction* unnested_hlo, + const Shape& input_shape, + int64 num_kept) { + // TODO(b/122468062): Need further investigate to see whether we can + // remove the constraint on IsPowerOfTwo. 
+ if (!IsPowerOfTwo(static_cast(num_kept))) { + return false; + } + + if (unnested_hlo->opcode() == HloOpcode::kReduce) { + return true; + } + + CHECK_EQ(unnested_hlo->opcode(), HloOpcode::kFusion); + int64 can_be_vectorized = 0; + int64 cannot_be_vectorized = 0; + const HloInstruction* fused_root = unnested_hlo->fused_expression_root(); + ConstHloInstructionSet use_chain_endings; + if (fused_root->opcode() == HloOpcode::kReduce) { + use_chain_endings.insert(fused_root); + // Atomic.add of the reduction result can't be vectorized. + cannot_be_vectorized++; + } else { + CHECK_EQ(fused_root->opcode(), HloOpcode::kTuple); + for (const HloInstruction* instr : fused_root->operands()) { + if (instr->opcode() == HloOpcode::kReduce) { + // Atomic.add of the reduction result can't be vectorized. + cannot_be_vectorized++; + } else { + // Write of the non-reduction result can be vectorized. + can_be_vectorized++; + } + use_chain_endings.insert(instr); + } + } + // Fusion inputs that have the same dimension as the reduce input and + // only involve in elementwise operations can be vectorized. + can_be_vectorized += NumInputsInvolveInOnlyElementwiseOps( + unnested_hlo, input_shape, use_chain_endings); + // Fusion inputs with more elements than the reduce op input must participate + // in non-elementwise operations and we assume that they are not vectorizable + // for the purpose of estimating the benefit of unrolling. If the kernel is + // unrolled even with such an assumption, and the accesses to those inputs + // turn out to be vectorizable, the compiler will still vectorize them. 
+ cannot_be_vectorized += + NumInputsWithMoreElementsThan(unnested_hlo, input_shape); + return can_be_vectorized >= cannot_be_vectorized; +} + +} // namespace + +std::tuple +IrEmitterUnnested::ComputeMappingSchemeAndReductionKind( + const HloInstruction* unnested_hlo, const HloInstruction* first_reduce) { int64 depth = 1; int64 height = 1; int64 width = 1; bool is_row_reduction = true; int64 tile_size_x = 1; int64 tile_size_y = 1; - int64 block_size_y = 1; int64 block_size_z = 1; int64 num_threads_x = 1; int64 num_threads_y = 1; @@ -3279,6 +3539,7 @@ std::tuple ComputeMappingSchemeAndReductionKind( std::tie(num_reduced_major, num_kept, num_reduced_minor) = GetReductionToVectorDimensions(input_shape, first_reduce->dimensions()); CHECK_EQ(num_output_elems, num_kept); + bool dilated_x = true; if (num_kept == 1) { // Scalar reduction is a special row reduction with depth = height = 1. @@ -3291,14 +3552,25 @@ std::tuple ComputeMappingSchemeAndReductionKind( height = num_reduced_major; width = num_kept; is_row_reduction = false; - tile_size_x = std::min(kWarpSize, num_kept); - // The old Column reduction algorithm uses kTileHeight = 128. We choose - // tile_size_y * block_size_y = 128 to match the value of kTileHeight. Using - // a non-trivial block_size_y here is a way to avoid unrolling all the 128 - // iterations. - tile_size_y = 32; - block_size_y = 4; - num_threads_x = tile_size_x; + // Column reduction without transpose doesn't require communication among + // threads processing elements in the same tile. The current implementation + // only support the use of one hardware thread block to process one block of + // tiles in the KernelMappingScheme. We try to use one thread to compute + // the partial results for two tensor elements and to maximize the values of + // num_threads_x and tile_size_x to allow a bigger hardware thread block. 
+ int64 hw_threads_per_block_limit = + ThreadsPerBlockLimit(ir_emitter_context_->device_description()); + if (IsUnrollingColumnReductionBeneficial(unnested_hlo, input_shape, + num_kept)) { + tile_size_x = std::min(2 * hw_threads_per_block_limit, num_kept); + num_threads_x = tile_size_x / 2; + dilated_x = false; + } else { + tile_size_x = std::min(hw_threads_per_block_limit, num_kept); + num_threads_x = tile_size_x; + } + int64 kNumElementsPerPartialSum = 128; + tile_size_y = kNumElementsPerPartialSum; } else { // Row reduction reduces inputs with dimension [depth, height, width], // where width is the most minor dimension, to dimension [height] . @@ -3321,15 +3593,14 @@ std::tuple ComputeMappingSchemeAndReductionKind( << " " << width; DimensionVector dims_in_elem{depth, height, width}; - DimensionVector req_block_sizes{block_size_z, block_size_y, 1}; - llvm_ir::KernelMappingScheme mapping_scheme(dims_in_elem, tile_size_y, - tile_size_x, req_block_sizes, - num_threads_y, num_threads_x, b); + DimensionVector req_block_sizes{block_size_z, 1, 1}; + llvm_ir::KernelMappingScheme mapping_scheme( + dims_in_elem, tile_size_y, tile_size_x, req_block_sizes, num_threads_y, + num_threads_x, &b_); + mapping_scheme.SetDilatedX(dilated_x); return std::make_tuple(mapping_scheme, is_row_reduction); } -} // namespace - Status IrEmitterUnnested::EmitReductionToVector(HloInstruction* unnested_hlo) { VLOG(10) << "Emitting reduction to vector " << unnested_hlo->ToString(); @@ -3375,14 +3646,15 @@ Status IrEmitterUnnested::EmitReductionToVector(HloInstruction* unnested_hlo) { bool is_row_reduction; llvm_ir::KernelMappingScheme mapping_scheme; std::tie(mapping_scheme, is_row_reduction) = - ComputeMappingSchemeAndReductionKind(first_reduce, &b_); + ComputeMappingSchemeAndReductionKind(unnested_hlo, first_reduce); ReductionCodegenInfo reduction_info(&mapping_scheme, is_row_reduction); KernelCodeGenerator kernel_generator( /*tile_element_generator=*/ [&](HloInstruction* hlo, const 
llvm_ir::IrArray::Index& index, const KernelCodegenInfo* kernel_info, llvm::Value* y_loc, - llvm::Value* x_loc) { - EmitTileElementForReduction(hlo, index, kernel_info, y_loc, x_loc); + llvm::Value* x_loc, int64 x_iter_num) { + EmitTileElementForReduction(hlo, index, kernel_info, y_loc, x_loc, + x_iter_num); }, /*block_prologue_generator=*/ [&](HloInstruction* hlo, KernelCodegenInfo* kernel_info) { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index 85a0e5328c4e436d4522593b38421efc87c42d32..21b842bb2cd63ac454f85556df20ae5877cecbe1 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -76,7 +76,6 @@ class IrEmitterUnnested : public IrEmitter { void SetLaneId(llvm::Value* v) { lane_id_ = v; } void SetIndexType(llvm::Type* t) { index_ty_ = t; } void SetTiledParamInfo(llvm_ir::TiledParameterInfo* tiled_param_info) { - CHECK_EQ(tiled_param_info_, nullptr); tiled_param_info_ = tiled_param_info; } @@ -89,7 +88,7 @@ class IrEmitterUnnested : public IrEmitter { } llvm::Type* GetIndexType() const { return index_ty_; } - private: + protected: llvm_ir::KernelMappingScheme* mapping_scheme_; llvm_ir::TiledParameterInfo* tiled_param_info_; llvm::Value* lane_id_; @@ -109,10 +108,12 @@ class IrEmitterUnnested : public IrEmitter { // y_loc: The y coordinate within a tile. // x_loc: The x coordinate within a tile. // kernel_info: Other information to support the kernel code generation. + // x_iter_num: When a thread process N elements in the X dimension, x_iter_num + // has a value of 0..N-1 to identify the element being process. using TileElementGenerator = std::function; + llvm::Value* x_loc, int64 x_iter_num)>; // KernelCodeGenerator records the code generator objects that generate code // for tile elements or tile block prologue/epilogue. 
@@ -176,7 +177,7 @@ class IrEmitterUnnested : public IrEmitter { Status HandleSelect(HloInstruction* select) override; Status HandleSort(HloInstruction* sort) override; Status HandleTupleSelect(HloInstruction* tuple_select) override; - Status HandleCrossReplicaSum(HloInstruction* crs) override; + Status HandleAllReduce(HloInstruction* crs) override; Status HandleAfterAll(HloInstruction* after_all) override; Status EmitTargetElementLoop( @@ -215,6 +216,15 @@ class IrEmitterUnnested : public IrEmitter { // Prerequisite: `IsReductionToVector(*unnested_hlo)` Status EmitReductionToVector(HloInstruction* unnested_hlo); + // Computes the KernelMappingScheme for the reduce HLO and indicates whether + // the reduction is a row reduction. For an un-fused reduce op, unnested_hlo + // and first_reduce are the same instruction. For a kInput fusion, + // unnested_hlo is the fusion instruction while first_reduce is the first + // reduce op. + std::tuple + ComputeMappingSchemeAndReductionKind(const HloInstruction* unnested_hlo, + const HloInstruction* first_reduce); + // Emits code for an in-place scatter, modifying `thunk`s launch dimensions in // the process. `scatter` may be fused, scatter indices are taken from // `scatter_indices_gen`, updates from`updates_gen`. The output buffer is @@ -238,26 +248,29 @@ class IrEmitterUnnested : public IrEmitter { const KernelCodeGenerator& kernel_generator, KernelCodegenInfo* kernel_info); void EmitBlock(const TileGenerator& emit_one_tile, - const KernelCodegenInfo* kernel_info, - KernelSupportLibrary& ksl, llvm::Type* index_ty); + KernelCodegenInfo* kernel_info, KernelSupportLibrary* ksl, + llvm::Type* index_ty); // Emits code to process a tensor element in a tile for the given kCopy HLO // that performs a 0-2-1 transpose. 
void EmitTileElementForCopy(HloInstruction* hlo, const llvm_ir::IrArray::Index& index, const KernelCodegenInfo* kernel_info, - llvm::Value* y_loc, llvm::Value* x_loc); + llvm::Value* y_loc, llvm::Value* x_loc, + int64 x_iter_num); // Emits code to process a tensor element in a tile for the given kLoop fusion // HLO containing parameters that are 0-2-1 transpose of its outputs. void EmitTileElementForFusion(HloInstruction* hlo, const llvm_ir::IrArray::Index& index, const KernelCodegenInfo* kernel_info, - llvm::Value* y_loc, llvm::Value* x_loc); + llvm::Value* y_loc, llvm::Value* x_loc, + int64 x_iter_num); // Emits code to process a tensor element in a tile for the given input hlo // that is either a unnested kReduce or a kInput fusion. void EmitTileElementForReduction(HloInstruction* unnested_hlo, const llvm_ir::IrArray::Index& index, const KernelCodegenInfo* kernel_info, - llvm::Value* y_loc, llvm::Value* x_loc); + llvm::Value* y_loc, llvm::Value* x_loc, + int64 x_iter_num); // Prepares for the code generation for a tile block of a reduction kernel. void EmitPrologueForReduction(HloInstruction* unnested_hlo, KernelCodegenInfo* kernel_info); @@ -272,9 +285,8 @@ class IrEmitterUnnested : public IrEmitter { // For each reducer, emits the shuffle-down loop to accumulate the partial // result to the global result. void EmitFullWarpShuffleDownLoopForAllReduces( - const absl::InlinedVector& reducers, - const absl::InlinedVector& - partial_result_addresses); + absl::Span reducers, + absl::Span partial_result_addresses); // Generates the IrArray for each input of an hlo and returns a vector that // constains such IrArrays. 
diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc index 364f69a69d47644b383af9cf6865c93360b82bab..153aab97d9eb971734c5ea95564895631bc2a9fa 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc @@ -110,11 +110,9 @@ static string GetLibdeviceFilename(const string& libdevice_dir_path, } // Gets the GPU name as it's known to LLVM for a given compute capability. If -// we see an unrecognized compute capability, we return "sm_30". +// we see an unrecognized compute capability, we return "sm_35". static string GetSmName(std::pair compute_capability) { static auto* m = new std::map, int>({ - {{3, 0}, 30}, - {{3, 2}, 32}, {{3, 5}, 35}, {{3, 7}, 37}, {{5, 0}, 50}, @@ -125,8 +123,9 @@ static string GetSmName(std::pair compute_capability) { {{6, 2}, 62}, {{7, 0}, 70}, {{7, 2}, 72}, + {{7, 5}, 75}, }); - int sm_version = 30; + int sm_version = 35; auto it = m->find(compute_capability); if (it != m->end()) { sm_version = it->second; @@ -177,13 +176,6 @@ std::unique_ptr GetTargetMachine( } TargetOptions target_options = InitTargetOptionsFromCodeGenFlags(); - llvm_ir::SetTargetOptions( - /*fast_math_enabled=*/hlo_module_config.debug_options() - .xla_gpu_enable_fast_math(), - &target_options); - - // Enable FMA synthesis. - target_options.AllowFPOpFusion = FPOpFusion::Fast; // Set the verbose assembly options. 
target_options.MCOptions.AsmVerbose = false; @@ -206,8 +198,7 @@ std::unique_ptr GetTargetMachine( } return absl::WrapUnique(target->createTargetMachine( triple.str(), llvm_ir::AsStringRef(cpu_name), "+ptx60", target_options, - Optional(RelocModel), Optional(CMModel), - codegen_opt_level)); + getRelocModel(), getCodeModel(), codegen_opt_level)); } // Adds the standard LLVM optimization passes, based on the speed optimization @@ -401,8 +392,16 @@ StatusOr CompileModuleToPtx(llvm::Module* module, int32 opt_level = hlo_module_config.debug_options().xla_backend_optimization_level(); - CHECK_GE(opt_level, 2) - << "The XLA GPU backend doesn't support unoptimized code generation"; + if (opt_level < 2) { + LOG(ERROR) << std::string(80, '*'); + LOG(ERROR) << "The XLA GPU backend doesn't support unoptimized code " + "generation but "; + LOG(ERROR) << "--xla_backend_optimization_level is set to " << opt_level + << "!"; + LOG(ERROR) << "(Supported configuration is " + "--xla_backend_optimization_level >= 2.)"; + LOG(ERROR) << std::string(80, '*'); + } AddOptimizationPasses(opt_level, /*size_level=*/0, target_machine.get(), &module_passes, @@ -465,6 +464,9 @@ void GPUBackendInit(const HloModuleConfig& hlo_module_config) { // between those loads. FeedLLVMWithFlags({"-memdep-block-scan-limit=500"}); + // Use div.approx -- it matters for some float-division heavy benchmarks. 
+ FeedLLVMWithFlags({"-nvptx-prec-divf32=0"}); + llvm_ir::InitializeLLVMCommandLineOptions(hlo_module_config); // Initialize the NVPTX target; it's the only target we link with, so call its diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc index 01fddcede64d1bb02ab89db5fc9524893c2d47a4..02e1207f377b8c28bf2566bee8cf3bcbc66794fb 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc @@ -67,7 +67,7 @@ int64 GpuMultiOutputFusion::GetProfit(HloInstruction* instr1, } int64 profit = 0; for (auto instr : instr2->operands()) { - if (!IsProfitableOperand(instr) || in_list.count(instr) == 0) { + if (!IsProfitableOperand(instr) || !in_list.contains(instr)) { continue; } profit += ShapeUtil::ByteSizeOf(instr->shape()); diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc index d16c87ba5c63aa582753fe949e9e39ee2d8b81e5..40b87b16a195564c9b98497f79a70f1db0539d87 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc @@ -628,8 +628,7 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionDUS) { p.1 = s32[1]{0} parameter(1) p.2 = f16[1,96,1024]{2,1,0} parameter(2) c.0 = s32[] constant(0) - pad = s32[3]{0} pad(p.1, c.0), padding=0_2 - ROOT %dynamic-update-slice = f16[50,96,1024]{2,1,0} dynamic-update-slice(p.0, p.2, pad) + ROOT %dynamic-update-slice = f16[50,96,1024]{2,1,0} dynamic-update-slice(p.0, p.2, p.1, c.0, c.0) } fusion.2 { @@ -638,7 +637,7 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionDUS) { p.2 = f16[1,96,1024]{2,1,0} parameter(2) c.0 = s32[] constant(0) pad = s32[3]{0} pad(p.1, c.0), padding=0_2 - ROOT %dynamic-update-slice = f16[50,96,1024]{2,1,0} dynamic-update-slice(p.0, p.2, pad) + ROOT %dynamic-update-slice = 
f16[50,96,1024]{2,1,0} dynamic-update-slice(p.0, p.2, p.1, c.0, c.0) } ENTRY entry { diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc index 637b861f70235f17e8e739907a3f262b7004ee7c..1f4f1766618c71c9ef8705f3038676a0518b3ddd 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc @@ -36,6 +36,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/buffer_liveness.h" #include "tensorflow/compiler/xla/service/call_inliner.h" #include "tensorflow/compiler/xla/service/conditional_simplifier.h" +#include "tensorflow/compiler/xla/service/convolution_group_converter.h" +#include "tensorflow/compiler/xla/service/dot_decomposer.h" +#include "tensorflow/compiler/xla/service/dynamic_index_splitter.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h" @@ -50,6 +53,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h" #include "tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h" #include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_sanitize_constant_names.h" #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h" @@ -77,6 +81,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/reduce_precision_insertion.h" #include "tensorflow/compiler/xla/service/reshape_mover.h" +#include "tensorflow/compiler/xla/service/sort_simplifier.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" #include "tensorflow/compiler/xla/service/tuple_simplifier.h" #include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h" @@ -108,27 +113,33 @@ namespace { namespace tracing = tensorflow::tracing; -// Returns the directory containing nvvm libdevice files. config_cuda_data_dir -// should be equal to config().debug_options().xla_gpu_cuda_data_dir() of the -// HloModule being compiled. -string GetLibdeviceDir(const string& config_cuda_data_dir) { - std::vector potential_libdevice_dirs; - if (!config_cuda_data_dir.empty()) { - potential_libdevice_dirs.push_back(config_cuda_data_dir); - } - potential_libdevice_dirs.push_back(tensorflow::LibdeviceRoot()); - - // Tries all potential libdevice directories in the order they are inserted. - // Returns the first directory that exists in the file system. - for (const string& potential_libdevice_dir : potential_libdevice_dirs) { - if (tensorflow::Env::Default()->IsDirectory(potential_libdevice_dir).ok()) { - VLOG(2) << "Found libdevice dir " << potential_libdevice_dir; - return potential_libdevice_dir; - } - VLOG(2) << "Unable to find potential libdevice dir " - << potential_libdevice_dir; +// Returns a vector of potential locations of the CUDA root directory. +std::vector GetCudaRootCandidates( + const HloModuleConfig& hlo_module_config) { + std::vector potential_cuda_roots = tensorflow::CandidateCudaRoots(); + + // CUDA location explicitly specified by user via --xla_gpu_cuda_data_dir has + // highest priority. 
+ string xla_gpu_cuda_data_dir = + hlo_module_config.debug_options().xla_gpu_cuda_data_dir(); + if (!xla_gpu_cuda_data_dir.empty()) { + potential_cuda_roots.insert(potential_cuda_roots.begin(), + xla_gpu_cuda_data_dir); } + return potential_cuda_roots; +} +// Returns the directory containing nvvm libdevice files. +string GetLibdeviceDir(const HloModuleConfig& hlo_module_config) { + for (const string& cuda_root : GetCudaRootCandidates(hlo_module_config)) { + string libdevice_dir = + tensorflow::io::JoinPath(cuda_root, "nvvm", "libdevice"); + VLOG(2) << "Looking for libdevice at " << libdevice_dir; + if (tensorflow::Env::Default()->IsDirectory(libdevice_dir).ok()) { + VLOG(2) << "Found libdevice dir " << libdevice_dir; + return libdevice_dir; + } + } LOG(WARNING) << "Unable to find libdevice dir. Using '.'"; // Last resort: maybe in the current folder. return "."; @@ -143,9 +154,9 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, Compiler* compiler) { { HloPassPipeline pipeline("optimization"); - pipeline.AddPass(); pipeline.AddInvariantChecker(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false); + pipeline.AddPass(); pipeline.AddPass(); ReducePrecisionInsertion::AddPasses( &pipeline, hlo_module->config().debug_options(), @@ -153,6 +164,14 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, // TODO(b/64094172): make Call work on GPU instead of inlining. pipeline.AddPass(); + auto cost_model = [](HloInstruction* conv) { + // We need a cost model for GPUs. Currently, do nothing. + return false; + }; + pipeline.AddPass(false); + pipeline.AddPass( + cost_model, + /*convert_batch_groups_only=*/true); // Convert BF16 operations to F32 operations so that the GPU backend can // support BF16 operations without directly implementing a BF16 lowering for // most ops. 
@@ -175,14 +194,15 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, /*rewrite_inference_op=*/true, /*rewrite_grad_op=*/true); + pipeline.AddPass(); + // BatchNormExpander can create zero-sized ops, so zero-sized HLO // elimination has to come after that pass. pipeline.AddPass(); - AlgebraicSimplifierOptions options( - [](const Shape&, const Shape&) { return false; }); - options.set_enable_permutation_sort_replacement(true); + AlgebraicSimplifierOptions options; pass.AddPass(options); + pass.AddPass(); pass.AddPass(); pass.AddPass(); pass.AddPass(); @@ -251,12 +271,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, // The LayoutAssignment pass may leave behind kCopy instructions which are // duplicate or NOPs, so remove them with algebraic simplification and CSE. - AlgebraicSimplifierOptions options( - /*valid_bitcast_callback=*/[](const Shape&, const Shape&) { - return true; - }); + AlgebraicSimplifierOptions options; options.set_is_layout_sensitive(true); - options.set_enable_permutation_sort_replacement(true); pipeline.AddPass>(options); // Choose the fastest algorithm for each conv. @@ -360,6 +376,7 @@ Status PrepareHloModuleForIrEmitting(HloModule* hlo_module) { pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); + pipeline.AddPass(); return pipeline.Run(hlo_module).status(); } @@ -477,13 +494,19 @@ void WarnIfBadDriverJITVersion() { // Compiles the given PTX string using ptxas and returns the resulting machine // code (i.e. a cubin) as a byte array. 
-StatusOr> CompilePtx(const string& ptx, int cc_major, - int cc_minor) { +StatusOr> CompilePtx( + const string& ptx, int cc_major, int cc_minor, + const HloModuleConfig& hlo_module_config) { tracing::ScopedActivity activity("Compile PTX", /*is_expensive=*/true); - const string ptxas_path = - tensorflow::io::JoinPath(tensorflow::CudaRoot(), "bin", "ptxas"); - VLOG(2) << "Checking ptxas at " << ptxas_path; auto env = tensorflow::Env::Default(); + string ptxas_path; + for (const string& cuda_root : GetCudaRootCandidates(hlo_module_config)) { + ptxas_path = tensorflow::io::JoinPath(cuda_root, "bin", "ptxas"); + VLOG(2) << "Looking for ptxas at " << ptxas_path; + if (env->FileExists(ptxas_path).ok()) { + break; + } + } TF_RETURN_IF_ERROR(env->FileExists(ptxas_path)); VLOG(2) << "Using ptxas at " << ptxas_path; @@ -518,6 +541,9 @@ StatusOr> CompilePtx(const string& ptx, int cc_major, if (VLOG_IS_ON(2)) { ptxas_args.push_back("-v"); } + if (hlo_module_config.debug_options().xla_gpu_disable_ptxas_optimizations()) { + ptxas_args.push_back("-O0"); + } ptxas_info_dumper.SetProgram(ptxas_path, ptxas_args); ptxas_info_dumper.SetChannelAction(tensorflow::CHAN_STDERR, tensorflow::ACTION_PIPE); @@ -680,12 +706,8 @@ StatusOr> NVPTXCompiler::RunBackend( // Find the directory containing libdevice. To avoid searching for it every // time, we have a one-element cache, keyed on the module's config's // cuda_data_dir. 
- const auto& config_cuda_data_dir = - module->config().debug_options().xla_gpu_cuda_data_dir(); - if (cached_libdevice_dir_.empty() || - cached_cuda_data_dir_ != config_cuda_data_dir) { - cached_cuda_data_dir_ = config_cuda_data_dir; - cached_libdevice_dir_ = GetLibdeviceDir(config_cuda_data_dir); + if (cached_libdevice_dir_.empty()) { + cached_libdevice_dir_ = GetLibdeviceDir(module->config()); } libdevice_dir = cached_libdevice_dir_; } @@ -739,7 +761,7 @@ StatusOr> NVPTXCompiler::RunBackend( } const std::vector cubin = - CompilePtxOrGetCachedResult(ptx, cc_major, cc_minor); + CompilePtxOrGetCachedResult(ptx, cc_major, cc_minor, module->config()); auto thunk_schedule = absl::make_unique( ir_emitter.ConsumeThunkSequence(), std::move(stream_assignment), @@ -771,9 +793,9 @@ StatusOr> NVPTXCompiler::RunBackend( return std::unique_ptr(gpu_executable); } -std::vector NVPTXCompiler::CompilePtxOrGetCachedResult(const string& ptx, - int cc_major, - int cc_minor) { +std::vector NVPTXCompiler::CompilePtxOrGetCachedResult( + const string& ptx, int cc_major, int cc_minor, + const HloModuleConfig& hlo_module_config) { XLA_SCOPED_LOGGING_TIMER("NVPTXCompiler::CompilePtxOrGetCachedResult"); tracing::ScopedActivity activity("PTX->CUBIN", /*is_expensive=*/true); bool inserted; @@ -802,7 +824,7 @@ std::vector NVPTXCompiler::CompilePtxOrGetCachedResult(const string& ptx, CHECK(!cache_value->compilation_done); if (!ptx.empty()) { StatusOr> maybe_cubin = - CompilePtx(*cache_ptx, cc_major, cc_minor); + CompilePtx(*cache_ptx, cc_major, cc_minor, hlo_module_config); if (maybe_cubin.ok()) { cache_value->cubin_data = std::move(maybe_cubin).ValueOrDie(); VLOG(2) << "Compiled PTX size:" << ptx.size() diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h index f79ae2990ae7d6e6985b15727a72358289121aa9..b2077f42fd097330703fde063d80a20704fa48e2 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h +++ 
b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h @@ -97,8 +97,9 @@ class NVPTXCompiler : public LLVMCompiler { // Tries to compile the given ptx string to cubin. Returns a vector with the // compiled cubin. If compilation was unsuccessful, returns an empty vector. - std::vector CompilePtxOrGetCachedResult(const string& ptx, - int cc_major, int cc_minor); + std::vector CompilePtxOrGetCachedResult( + const string& ptx, int cc_major, int cc_minor, + const HloModuleConfig& hlo_module_config); // The compilation_cache_ map is a cache from {ptx string, cc_major, cc_minor} // -> cubin so we don't recompile the same ptx twice. This is important for diff --git a/tensorflow/compiler/xla/service/gpu/partition_assignment.cc b/tensorflow/compiler/xla/service/gpu/partition_assignment.cc index 375f68a15957936151aee068582a714b62694af2..bfed4f5230dfe37bca48560ce83a2dd82c8950a4 100644 --- a/tensorflow/compiler/xla/service/gpu/partition_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/partition_assignment.cc @@ -39,6 +39,25 @@ std::ostream& operator<<(std::ostream& out, return out; } +int64 ThreadsPerBlockLimit(const se::DeviceDescription& device_desc) { + int64 threads_per_block = device_desc.threads_per_block_limit(); + if (threads_per_block == 0) { + static std::atomic log_count{0}; + if (log_count.fetch_add(1) < 8) { + LOG(WARNING) << "Attempting to calculate launch dimensions for GPU " + "without full information about its capabilities. " + "StreamExecutor's PopulateDeviceDescription should be " + "updated for this device."; + } + threads_per_block = device_desc.threads_per_warp(); + if (threads_per_block == 0) { + // Fall back to *something* if we can't even get num threads per warp. + threads_per_block = 32; + } + } + return threads_per_block; +} + // Calculates the launch dimensions used to invoke `hlo`. 
LaunchDimensions CalculateLaunchDimensions( const Shape& shape, const se::DeviceDescription& device_desc, @@ -62,21 +81,7 @@ LaunchDimensions CalculateLaunchDimensions( // // * = - int64 threads_per_block = device_desc.threads_per_block_limit(); - if (threads_per_block == 0) { - static std::atomic log_count{0}; - if (log_count.fetch_add(1) < 8) { - LOG(WARNING) << "Attempting to calculate launch dimensions for GPU " - "without full information about its capabilities. " - "StreamExecutor's PopulateDeviceDescription should be " - "updated for this device."; - } - threads_per_block = device_desc.threads_per_warp(); - if (threads_per_block == 0) { - // Fall back to *something* if we can't even get num threads per warp. - threads_per_block = 32; - } - } + int64 threads_per_block = ThreadsPerBlockLimit(device_desc); if (num_elements < threads_per_block) { threads_per_block = num_elements; diff --git a/tensorflow/compiler/xla/service/gpu/partition_assignment.h b/tensorflow/compiler/xla/service/gpu/partition_assignment.h index 02471129e004b4876ce20a62cade34060c65b478..eb41dcccb938ccc088c2371def96ca73276771ab 100644 --- a/tensorflow/compiler/xla/service/gpu/partition_assignment.h +++ b/tensorflow/compiler/xla/service/gpu/partition_assignment.h @@ -57,6 +57,9 @@ class LaunchDimensions { std::ostream& operator<<(std::ostream& out, const LaunchDimensions& launch_dims); +// Returns the maximum number of threads per block allowed by the device. 
+int64 ThreadsPerBlockLimit(const se::DeviceDescription& device_desc); + LaunchDimensions CalculateLaunchDimensions( const Shape& shape, const se::DeviceDescription& device_desc, int unroll_factor = 1); diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment.cc index 4775baf44aecfe6adaf2bf0d2791595436635b16..1dedbd3befce6e2ceb06126d83a061207a90dd8f 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_assignment.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/stream_assignment.h" +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" @@ -25,7 +26,7 @@ namespace xla { namespace gpu { bool StreamAssignment::HasStreamAssigned(const HloInstruction& hlo) const { - return hlo_to_stream_number_.count(&hlo); + return hlo_to_stream_number_.contains(&hlo); } int StreamAssignment::StreamNumberForHlo(const HloInstruction& hlo) const { @@ -98,10 +99,10 @@ int ComputeStreamToAssign( // greedy approach. First, we compute as forbidden_stream_numbers the // streams assigned to GEMMs that are concurrent with `hlo`. Then, we assign // `hlo` a different stream. 
- std::set forbidden_stream_numbers; + absl::flat_hash_set forbidden_stream_numbers; for (const auto* seen_gemm : seen_gemms) { int stream_num = stream_assignment.StreamNumberForHlo(*seen_gemm); - if (!forbidden_stream_numbers.count(stream_num) && + if (!forbidden_stream_numbers.contains(stream_num) && CanRunConcurrently(*seen_gemm, hlo, reachability)) { forbidden_stream_numbers.insert(stream_num); } @@ -109,7 +110,7 @@ int ComputeStreamToAssign( for (int stream_num = 0; stream_num < stream_assignment.StreamCount(); ++stream_num) { - if (!forbidden_stream_numbers.count(stream_num)) { + if (!forbidden_stream_numbers.contains(stream_num)) { return stream_num; } } diff --git a/tensorflow/compiler/xla/service/gpu/stream_executor_util.h b/tensorflow/compiler/xla/service/gpu/stream_executor_util.h index 1fc46bafa10e7ba6c896f081d5c836bd400886c9..92e4d6dbbc1bd564657f8a5de09d23d5ae81a93e 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_executor_util.h +++ b/tensorflow/compiler/xla/service/gpu/stream_executor_util.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_EXECUTOR_UTIL_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_EXECUTOR_UTIL_H_ +#include "tensorflow/compiler/xla/layout.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_ftz_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_ftz_test.cc index d0ccd8619bde9ddd560989380b403efed5c5f42c..5e524faab18947f5793dc2ae34e9329a446d4235 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_ftz_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_ftz_test.cc @@ -75,16 +75,16 @@ class GpuFtzDisabledTest : public GpuFtzTest { // Check that we emit mul.ftz.f32 when in ftz mode, and plain mul.f32 otherwise. 
TEST_F(GpuFtzEnabledTest, MultiplyFtz) { CompileAndVerifyPtx(CreateBinaryOpModule(HloOpcode::kMultiply), R"( - CHECK-NOT: mul.f32 - CHECK: mul.ftz.f32 - CHECK-NOT: mul.f32 + CHECK-NOT: mul.rn.f32 + CHECK: mul.rn.ftz.f32 + CHECK-NOT: mul.rn.f32 )"); } TEST_F(GpuFtzDisabledTest, MultiplyFtz) { CompileAndVerifyPtx(CreateBinaryOpModule(HloOpcode::kMultiply), R"( - CHECK-NOT: mul.ftz.f32 - CHECK: mul.f32 - CHECK-NOT: mul.ftz.f32 + CHECK-NOT: mul.rn.ftz.f32 + CHECK: mul.rn.f32 + CHECK-NOT: mul.rn.ftz.f32 )"); } diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc index a302b582ede3723acd118d2e4a4bb3efdf7a4d0b..869724db601b2d5e4ed6d3c7bf3e10a748433146 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc @@ -65,7 +65,7 @@ TEST_F(GpuKernelTilingTest, UnnestedTransposeWithProperDimensionsTiled) { CompileAndVerifyIr(std::move(hlo_module), R"( ; CHECK-LABEL: define void @copy -; CHECK: tail call void @llvm.nvvm.barrier0() +; CHECK: call void @llvm.nvvm.barrier0() ; CHECK: } )", /*match_optimized_ir=*/true); @@ -91,7 +91,7 @@ TEST_F(GpuKernelTilingTest, UnnestedTransposeWithSmallDimensionsNotTiled) { CompileAndVerifyIr(std::move(hlo_module), R"( ; CHECK-LABEL: define void @copy -; CHECK-NOT: tail call void @llvm.nvvm.barrier0() +; CHECK-NOT: call void @llvm.nvvm.barrier0() ; CHECK: } )", /*match_optimized_ir=*/true); @@ -118,7 +118,7 @@ TEST_F(GpuKernelTilingTest, SimpleFusionWithTransposeTiled) { CompileAndVerifyIr(std::move(hlo_module), R"( ; CHECK-LABEL: define void @fusion -; CHECK: tail call void @llvm.nvvm.barrier0() +; CHECK: call void @llvm.nvvm.barrier0() ; CHECK: } )", /*match_optimized_ir=*/true); @@ -152,7 +152,7 @@ TEST_F(GpuKernelTilingTest, MultipleOutputFusionWithOnePossibleTransposeTiled) { CompileAndVerifyIr(std::move(hlo_module), R"( ; CHECK-LABEL: define void 
@fusion -; CHECK: tail call void @llvm.nvvm.barrier0() +; CHECK: call void @llvm.nvvm.barrier0() ; CHECK: } )", /*match_optimized_ir=*/true); @@ -187,13 +187,13 @@ TEST_F(GpuKernelTilingTest, CompileAndVerifyIr(std::move(hlo_module), R"( ; CHECK-LABEL: define void @fusion -; CHECK-NOT: tail call void @llvm.nvvm.barrier0() +; CHECK-NOT: call void @llvm.nvvm.barrier0() ; CHECK: } )", /*match_optimized_ir=*/true); } -TEST_F(GpuKernelTilingTest, FusionTransposeWithReverseNotTiled) { +TEST_F(GpuKernelTilingTest, TransposedInputWithUserReverseNotTiled) { const char *const kHloString = R"( HloModule FusionTransposeWithReverseNotTiled fused_computation.1 { @@ -214,12 +214,203 @@ TEST_F(GpuKernelTilingTest, FusionTransposeWithReverseNotTiled) { CompileAndVerifyIr(std::move(hlo_module), R"( ; CHECK-LABEL: define void @fusion -; CHECK-NOT: tail call void @llvm.nvvm.barrier0() +; CHECK-NOT: call void @llvm.nvvm.barrier0() ; CHECK: } )", /*match_optimized_ir=*/true); } +TEST_F(GpuKernelTilingTest, TransposedInputWithUserBitcastNotTiled) { + const char *const kHloString = R"( + HloModule TransposedInputWithUserBitcast + + fused_computation { + param_0 = f32[20,20]{1,0} parameter(0) + ROOT bitcast = f32[20,20]{0,1} bitcast(param_0) + } + + ENTRY kernel_entry { + parameter.0 = f32[20,20]{1,0} parameter(0) + ROOT fusion = f32[20,20]{0,1} fusion(parameter.0), + kind=kLoop, calls=fused_computation + })"; + + // Check that a call to llvm.nvvm.barrier0 is not generated. + auto hlo_module = + ParseHloString(kHloString, ConfigWithoutLayoutAssignment()).ValueOrDie(); + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK-LABEL: define void @fusion +; CHECK-NOT: call void @llvm.nvvm.barrier0() +; CHECK: } +)", + /*match_optimized_ir=*/true); + + // Check that the kernel runs correctly. 
+ EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.0})); +} + +TEST_F(GpuKernelTilingTest, TransposedInputWithoutUnsafeUseTiled) { + const char *const kHloString = R"( + HloModule TwoTransposedInputs + + fused_computation { + param_0 = f32[64,64]{1,0} parameter(0) + param_1 = f32[64,64]{1,0} parameter(1) + bitcast = f32[64,64]{0,1} bitcast(param_0) + copy = f32[64,64]{0,1} copy(param_1) + ROOT tuple = (f32[64,64]{0,1}, f32[64,64]{0,1}) tuple(bitcast, copy) + } + + ENTRY kernel_entry { + parameter.0 = f32[64,64]{1,0} parameter(0) + parameter.1 = f32[64,64]{1,0} parameter(1) + ROOT fusion = (f32[64,64]{0,1}, f32[64,64]{0,1}) + fusion(parameter.0, parameter.1), + kind=kLoop, calls=fused_computation + })"; + + // Check that a call to llvm.nvvm.barrier0 is generated. + auto hlo_module = + ParseHloString(kHloString, ConfigWithoutLayoutAssignment()).ValueOrDie(); + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK-LABEL: define void @fusion +; CHECK: call void @llvm.nvvm.barrier0() +; CHECK: } +)", + /*match_optimized_ir=*/true); + // Check that the kernel runs correctly. + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.0})); +} + +TEST_F(GpuKernelTilingTest, ColumnReductionWithPowerOf2OutputElementsUnrolled) { + const char *const kHloString = R"( + HloModule column_reduce_powerof2 + + reduction { + x = f32[] parameter(0) + y = f32[] parameter(1) + ROOT add = f32[] add(x, y) + } + + ENTRY kernel_entry { + constant0 = f32[] constant(0) + arg1 = f16[1024,512]{1,0} parameter(0) + arg1_conv = f32[1024,512]{1,0} convert(arg1) + ROOT reduce = f32[512]{0} reduce(arg1_conv, constant0), dimensions={0}, to_apply=reduction + })"; + + // Check that two calls to llvm.nvvm.atomic are generated. 
+ auto hlo_module = + ParseHloString(kHloString, ConfigWithoutLayoutAssignment()).ValueOrDie(); + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK-LABEL: define void @fusion +; CHECK: call float @llvm.nvvm.atomic.load.add.f32.p0f32 +; CHECK: call float @llvm.nvvm.atomic.load.add.f32.p0f32 +; CHECK-NOT: call float @llvm.nvvm.atomic.load.add.f32.p0f32 +; CHECK: } +)", + /*match_optimized_ir=*/true); + // Check that the kernel runs correctly. + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1.0e-5, 1.0e-5})); +} + +TEST_F(GpuKernelTilingTest, + ColumnReductionWithInputLargerThenReduceInputNotUnrolled) { + const char *const kHloString = R"( + HloModule larger_than_reduce_input_parameter + + reduction22 { + x = f32[] parameter(0) + y = f32[] parameter(1) + ROOT add = f32[] add(x, y) + } + + fused_computation { + constant0 = f32[] constant(0) + arg.1 = f16[1024,512]{1,0} parameter(0) + arg.2 = f16[1027,513]{1,0} parameter(1) + arg1.conv = f32[1024,512]{1,0} convert(arg.1) + arg2.conv = f32[1027,513]{1,0} convert(arg.2) + slice2 = f32[1024,512]{1,0} slice(arg2.conv), slice={[2:1026], [1:513]} + add2 = f32[1024,512]{1,0} add(arg1.conv, slice2) + ROOT reduce = f32[512]{0} reduce(add2, constant0), dimensions={0}, + to_apply=reduction22 + } + + ENTRY kernel_entry { + arg1 = f16[1024,512]{1,0} parameter(0) + arg2 = f16[1027,513]{1,0} parameter(1) + ROOT fusion = f32[512]{0} fusion(arg1, arg2), kind=kInput, + calls=fused_computation + })"; + + // Check that one call to llvm.nvvm.atomic is generated. + auto hlo_module = + ParseHloString(kHloString, ConfigWithoutLayoutAssignment()).ValueOrDie(); + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK-LABEL: define void @fusion +; CHECK: call float @llvm.nvvm.atomic.load.add.f32.p0f32 +; CHECK-NOT: call float @llvm.nvvm.atomic.load.add.f32.p0f32 +; CHECK: } +)", + /*match_optimized_ir=*/true); + // Check that the kernel runs correctly. 
+ EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1.0e-5, 1.0e-5})); +} + +TEST_F(GpuKernelTilingTest, ColumnReductionMOFUnrolled) { + const char *const kHloString = R"( + HloModule column_reduce_powerof2_mof + + reduction22 { + x = f32[] parameter(0) + y = f32[] parameter(1) + ROOT add = f32[] add(x, y) + } + + fused_computation { + constant0 = f32[] constant(0) + arg.1 = f16[1024,512]{1,0} parameter(0) + arg.2 = f16[1024,512]{1,0} parameter(1) + arg1.conv = f32[1024,512]{1,0} convert(arg.1) + arg2.conv = f32[1024,512]{1,0} convert(arg.2) + reduce1 = f32[512]{0} reduce(arg1.conv, constant0), dimensions={0}, + to_apply=reduction22 + reduce2 = f32[512]{0} reduce(arg2.conv, constant0), dimensions={0}, + to_apply=reduction22 + add = f32[1024,512]{1,0} add(arg1.conv, arg2.conv) + ROOT tuple = (f32[512]{0}, f32[512]{0}, f32[1024,512]{1,0}) + tuple(reduce1, reduce2, add) + } + + ENTRY kernel_entry { + arg1 = f16[1024,512]{1,0} parameter(0) + arg2 = f16[1024,512]{1,0} parameter(1) + ROOT fusion = (f32[512]{0}, f32[512]{0}, f32[1024,512]{1,0}) + fusion(arg1, arg2), kind=kInput, calls=fused_computation + })"; + + // Check that four calls to llvm.nvvm.atomic are generated. + auto hlo_module = + ParseHloString(kHloString, ConfigWithoutLayoutAssignment()).ValueOrDie(); + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK-LABEL: define void @fusion +; CHECK: call float @llvm.nvvm.atomic.load.add.f32.p0f32 +; CHECK: call float @llvm.nvvm.atomic.load.add.f32.p0f32 +; CHECK: call float @llvm.nvvm.atomic.load.add.f32.p0f32 +; CHECK: call float @llvm.nvvm.atomic.load.add.f32.p0f32 +; CHECK-NOT: call float @llvm.nvvm.atomic.load.add.f32.p0f32 +; CHECK: } +)", + /*match_optimized_ir=*/true); + // Check that the kernel runs correctly. 
+ EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1.0e-5, 1.0e-5})); +} } // namespace } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc b/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc index f8120a5fa00ce38644cd85c54d5ef65701be1eda..f91a22d482bc8bc046977870a7a4d18ca1acde68 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc @@ -43,7 +43,7 @@ class InfeedTest : public ClientLibraryTestBase { ASSERT_IS_OK(client_->TransferToInfeed(literal)); XlaBuilder builder(TestName()); Infeed(&builder, literal.shape()); - if (ShapeUtil::IsTuple(literal.shape())) { + if (literal.shape().IsTuple()) { // TODO(b/30609564): Use ComputeAndCompareLiteral instead. ComputeAndCompareTuple(&builder, literal, {}); } else { diff --git a/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc b/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc index 6b2d76764a077dc6cfa3f9ddc6e525ab330323be..25bad67bab9375559c431466571c62acd0452b01 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc +++ b/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc @@ -14,17 +14,19 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/xla/service/gpu/thunk_schedule.h" +#include "absl/container/flat_hash_map.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/gtl/map_util.h" namespace xla { namespace gpu { void ThunkSchedule::AddDependenciesOnTransitiveOperands( const Thunk& thunk, const HloInstruction& operand, - const std::unordered_map& hlo_to_thunk) { - if (hlo_to_thunk.count(&operand)) { + const absl::flat_hash_map& hlo_to_thunk) { + if (hlo_to_thunk.contains(&operand)) { // If `operand` is mapped to a thunk, adds `operand` to `thunk`'s dependency // list if `operand` is assigned to a different stream. As an optimization, // we skip `operand`'s operands because `operand` depends on them already. @@ -48,14 +50,14 @@ ThunkSchedule::ThunkSchedule( const std::vector& hlo_total_order) : thunks_(std::move(thunks)), stream_assignment_(std::move(stream_assignment)) { - std::unordered_map hlo_to_thunk; + absl::flat_hash_map hlo_to_thunk; for (const auto& thunk : *thunks_) { InsertOrDie(&hlo_to_thunk, thunk->hlo_instruction(), thunk.get()); } for (HloInstruction* hlo : hlo_total_order) { - if (hlo_to_thunk.count(hlo)) { - thunk_total_order_.push_back(FindOrDie(hlo_to_thunk, hlo)); + if (Thunk** thunk = tensorflow::gtl::FindOrNull(hlo_to_thunk, hlo)) { + thunk_total_order_.push_back(*thunk); } } @@ -106,7 +108,7 @@ void ThunkSchedule::RemoveRedundantDependencyEdges() { // redundant dependency edge. 
Array2D last_dependency(stream_count, stream_count, -1); for (const Thunk* dst : thunk_total_order_) { - if (!depends_on_.count(dst)) { + if (!depends_on_.contains(dst)) { continue; } @@ -134,7 +136,7 @@ void ThunkSchedule::RemoveRedundantDependencyEdges() { const std::list& ThunkSchedule::DependsOn( const Thunk* thunk) const { - if (depends_on_.count(thunk)) { + if (depends_on_.contains(thunk)) { return FindOrDie(depends_on_, thunk); } else { return empty_thunk_list_; diff --git a/tensorflow/compiler/xla/service/gpu/thunk_schedule.h b/tensorflow/compiler/xla/service/gpu/thunk_schedule.h index 43b628a1baf0e79a3197f3cfad3547991642eaed..549378debd52417252724a5d8a6f4d24f2ad0369 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk_schedule.h +++ b/tensorflow/compiler/xla/service/gpu/thunk_schedule.h @@ -21,6 +21,8 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/gpu/stream_assignment.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -54,7 +56,9 @@ class ThunkSchedule { // Thunks that `thunk` depends on. const std::list& DependsOn(const Thunk* thunk) const; // Whether `thunk` is depended by another thunk. - bool Depended(const Thunk* thunk) const { return depended_by_.count(thunk); } + bool Depended(const Thunk* thunk) const { + return depended_by_.contains(thunk); + } // Delegates to StreamAssignment. int StreamCount() const { return stream_assignment_->StreamCount(); } @@ -75,13 +79,13 @@ class ThunkSchedule { // thunk.hlo_instruction(). 
void AddDependenciesOnTransitiveOperands( const Thunk& thunk, const HloInstruction& operand, - const std::unordered_map& hlo_to_thunk); + const absl::flat_hash_map& hlo_to_thunk); std::unique_ptr thunks_; std::vector thunk_total_order_; - std::unordered_map> depends_on_; - std::set depended_by_; + absl::flat_hash_map> depends_on_; + absl::flat_hash_set depended_by_; std::list empty_thunk_list_; std::unique_ptr stream_assignment_; diff --git a/tensorflow/compiler/xla/service/gpu/xfeed_queue.h b/tensorflow/compiler/xla/service/gpu/xfeed_queue.h index dd46ff433ba0ad6bfa3999b96845fdaebe148aca..167c038420a64d9fa29746ed3fe349620e08e6ff 100644 --- a/tensorflow/compiler/xla/service/gpu/xfeed_queue.h +++ b/tensorflow/compiler/xla/service/gpu/xfeed_queue.h @@ -47,6 +47,10 @@ class XfeedQueue { // Blocks until the queue is non-empty, then returns the buffer at the head of // the queue. BufferType BlockingGetNextDestination() { + for (const auto& callback : before_get_next_dest_callbacks_) { + callback(); + } + bool became_empty; BufferType current_buffer; { @@ -69,6 +73,10 @@ class XfeedQueue { void RegisterOnEmptyCallback(std::function callback) { on_empty_callbacks_.push_back(std::move(callback)); } + void RegisterBeforeGetNextDestinationCallback( + std::function callback) { + before_get_next_dest_callbacks_.push_back(std::move(callback)); + } private: tensorflow::mutex mu_; @@ -82,6 +90,11 @@ class XfeedQueue { // List of callbacks which will be called when 'enqueued_buffers_' becomes // empty. std::vector> on_empty_callbacks_; + + // List of callbacks which will be called before BlockingGetNextDestination() + // is called. This lets you e.g. call EnqueueDestination() for each call to + // BlockingGetNextDestination(). 
+ std::vector> before_get_next_dest_callbacks_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc index 9220865867b770eebfb1ada8f31a5d24693a4b8d..4fca981c6a59cdb91a997e6a887fd26472c1a10a 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.cc +++ b/tensorflow/compiler/xla/service/heap_simulator.cc @@ -199,7 +199,7 @@ Status HeapSimulator::RunComputation( // If the buffer has no users and isn't an entry parameter or output, it // must be a dead value. - if (live_buffers.count(buffer) == 0) { + if (!live_buffers.contains(buffer)) { dead_buffers_to_free.push_back(buffer); } } @@ -225,10 +225,10 @@ Status HeapSimulator::RunComputation( } } // Sort to get a deterministic iteration order. - std::sort(operand_buffers_to_free.begin(), operand_buffers_to_free.end(), - [](const BufferValue* x, const BufferValue* y) { - return x->id() < y->id(); - }); + absl::c_sort(operand_buffers_to_free, + [](const BufferValue* x, const BufferValue* y) { + return x->id() < y->id(); + }); // Allocate buffers defined by this instruction. This is the latest point // that we can allocate; right before the buffer is first used. 
This must @@ -253,7 +253,7 @@ Status HeapSimulator::RunComputation( bool shared = false; if (options_.may_reuse_operand_buffers) { for (const BufferValue* operand_buffer : operand_buffers_to_free) { - if (reused_buffers.count(operand_buffer) != 0) { + if (reused_buffers.contains(operand_buffer)) { continue; } if (buffer->instruction()->IsUserOf(operand_buffer->instruction()) && @@ -335,10 +335,9 @@ Status HeapSimulator::RunComputation( to_free.push_back(buffer); } - std::sort(to_free.begin(), to_free.end(), - [](const BufferValue* x, const BufferValue* y) { - return x->id() < y->id(); - }); + absl::c_sort(to_free, [](const BufferValue* x, const BufferValue* y) { + return x->id() < y->id(); + }); for (const BufferValue* buffer : to_free) { VLOG(3) << "Freeing pending: " << buffer->ToString(); Free(buffer, root); @@ -374,15 +373,15 @@ bool HeapSimulator::IgnoreBuffer(const BufferValue* buffer) const { return true; } return options_.buffers_to_assign != nullptr && - options_.buffers_to_assign->count(buffer) == 0; + !options_.buffers_to_assign->contains(buffer); } // Alloc always calls the underlying heap algorithm. 
void HeapSimulator::Alloc(const BufferValue* buffer, const HloInstruction* instruction) { - CHECK(allocated_buffers_.count(buffer) == 0) + CHECK(!allocated_buffers_.contains(buffer)) << "Alloc called on allocated buffer: " << *buffer; - CHECK(freed_buffers_.count(buffer) == 0) + CHECK(!freed_buffers_.contains(buffer)) << "Alloc called on freed buffer: " << *buffer; allocated_buffers_.insert(buffer); @@ -411,9 +410,9 @@ void HeapSimulator::Free(const BufferValue* buffer, buffer = group->canonical; } - CHECK(allocated_buffers_.count(buffer) > 0) + CHECK(allocated_buffers_.contains(buffer)) << "Free called on non-allocated buffer: " << *buffer; - CHECK(freed_buffers_.count(buffer) == 0) + CHECK(!freed_buffers_.contains(buffer)) << "Free called on freed buffer: " << *buffer; freed_buffers_.insert(buffer); @@ -433,11 +432,11 @@ void HeapSimulator::ShareBuffer(const BufferValue* buffer, const HloInstruction* instruction) { CHECK_LE(size_fn_(*buffer), size_fn_(*shared)) << "ShareBuffer oversized buffer" << *buffer << " shared: " << *shared; - CHECK(allocated_buffers_.count(buffer) == 0) + CHECK(!allocated_buffers_.contains(buffer)) << "ShareBuffer called on allocated buffer: " << *buffer; - CHECK(freed_buffers_.count(buffer) == 0) + CHECK(!freed_buffers_.contains(buffer)) << "ShareBuffer called on freed buffer: " << *buffer; - CHECK(freed_buffers_.count(shared) == 0) + CHECK(!freed_buffers_.contains(shared)) << "ShareBuffer called on freed shared buffer: " << *shared; const BufferValue* canonical = nullptr; @@ -452,7 +451,7 @@ void HeapSimulator::ShareBuffer(const BufferValue* buffer, } else { // The 'shared' buffer doesn't have a group; it must be the canonical. Add // both 'buffer' and 'shared' to a new group. 
- CHECK(allocated_buffers_.count(shared) > 0) + CHECK(allocated_buffers_.contains(shared)) << "ShareBuffer called on non-allocated shared buffer: " << *shared; auto group = std::make_shared(); canonical = shared; @@ -596,7 +595,7 @@ void DecreasingSizeRunsHeap::CallAndDrainRun() { } // Call ops in the run sorted by decreasing size, breaking ties by buffer id. - std::sort(run_.begin(), run_.end(), [](const Op& a, const Op& b) { + absl::c_sort(run_, [](const Op& a, const Op& b) { if (a.size != b.size) { return a.size > b.size; } @@ -866,23 +865,23 @@ HeapSimulator::Result GlobalDecreasingSizeBestFitHeap::Finish() { for (auto& entry : buffer_intervals_) { sorted_buffer_intervals.push_back(entry.second); } - std::sort(sorted_buffer_intervals.begin(), sorted_buffer_intervals.end(), - [](const BufferInterval& x, const BufferInterval& y) { - if (x.size != y.size) { - return x.size > y.size; - } - if (x.end - x.start != y.end - y.start) { - return x.end - x.start > y.end - y.start; - } - return x.buffer->id() < y.buffer->id(); - }); + absl::c_sort(sorted_buffer_intervals, + [](const BufferInterval& x, const BufferInterval& y) { + if (x.size != y.size) { + return x.size > y.size; + } + if (x.end - x.start != y.end - y.start) { + return x.end - x.start > y.end - y.start; + } + return x.buffer->id() < y.buffer->id(); + }); BufferIntervalTree interval_tree(sorted_buffer_intervals.size()); for (auto& buffer_interval : sorted_buffer_intervals) { auto chunks_overlapping_in_time = interval_tree.ChunksOverlappingInTime( buffer_interval.start, buffer_interval.end); - std::sort( - chunks_overlapping_in_time.begin(), chunks_overlapping_in_time.end(), + absl::c_sort( + chunks_overlapping_in_time, [](const Chunk& x, const Chunk& y) { return x.offset < y.offset; }); // Find the minimum free chunk that can hold this buffer. 
diff --git a/tensorflow/compiler/xla/service/heap_simulator.h b/tensorflow/compiler/xla/service/heap_simulator.h index dbbf43082f2c1d21f5ef42f53804bf0969903a58..3e0631aeb4aa374cb5748650e1c7529e26e10b34 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.h +++ b/tensorflow/compiler/xla/service/heap_simulator.h @@ -158,7 +158,7 @@ class HeapSimulator { void FillDebugTrace(HeapSimulatorTrace::Event::Kind kind, const BufferValue* buffer, const HloInstruction* instruction, - const BufferValue* shared_with_canonical); + const BufferValue* share_with_canonical); // Counterintuitive: the algorithm_ itself can be a NoFragmentationStatsHeap, // in which case we are calculating the same allocs/frees twice in the diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index 414c63271245315f037d04924c9291a9cd5b7a77..263b42a29dbb0dbc0fb6eca7968674ff242f45ed 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -34,7 +34,7 @@ import "tensorflow/compiler/xla/xla_data.proto"; option cc_enable_arenas = true; // Serialization of HloInstruction. -// Next ID: 58 +// Next ID: 59 message HloInstructionProto { reserved 10; reserved "parameter_name"; @@ -82,6 +82,8 @@ message HloInstructionProto { // it will use a default value of 1. int64 feature_group_count = 50; + int64 batch_group_count = 58; + // Describes the [begin, end) index range and stride for slices. message SliceDimensions { int64 start = 1; @@ -166,7 +168,7 @@ message HloInstructionProto { // Cross replica op fields. repeated ReplicaGroup replica_groups = 49; int64 all_reduce_id = 45; - string cross_replica_sum_barrier = 46; + string all_reduce_barrier = 46; // Whether this Send/Recv instruction transfers data to/from the host. 
Only // present for Send and Recv instructions and their SendDone and RecvDone @@ -227,6 +229,18 @@ message HloScheduleProto { } message HloInputOutputAliasProto { + enum Kind { + // Define a UNDEFINED_ALIAS equal to zero to get around the default-0 proto3 + // behavior and missing has_*() APIs. + UNDEFINED_ALIAS = 0; + // An alias setup by the user as must alias. A use setting USER_ALIAS is + // expecting the designed output to be dropped over the given input + // parameter number+index. + USER_ALIAS = 1; + // An alias setup by the compiler as part of its optimizations. + SYSTEM_ALIAS = 2; + } + // The following proto describes a pair of aliased an input // (described by parameter number and a ShapeIndex of the parameter) // and an output (described by a ShapeIndex of the root @@ -247,6 +261,8 @@ message HloInputOutputAliasProto { int64 parameter_number = 2; // ShapeIndex of the parameter instruction. repeated int64 parameter_shape_index = 3; + // The kind of alias to be setup. + Kind kind = 4; } repeated AliasEntryProto entries = 1; diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc index cf8e6594cbe5ffd28ca75dd5006e8817f1e8581c..e511f1951c5dd07ebb64fa38fd5b7f6a0e87b429 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc @@ -117,7 +117,7 @@ class BufferValueMap { for (const auto& pair : buffers_) { buffer_numbers.push_back(pair.first); } - std::sort(buffer_numbers.begin(), buffer_numbers.end()); + absl::c_sort(buffer_numbers); return buffer_numbers; } @@ -176,13 +176,12 @@ class BufferValueMap { const HloValue& value, std::vector* aliased_buffers) { // Get parameter value from an aliased_input object. 
const auto get_parameter_value = - [this](const std::pair& aliased_input) + [this](const HloInputOutputAliasConfig::Alias& aliased_input) -> const HloValue& { - int64 param_number = aliased_input.first; - const ShapeIndex& param_index = aliased_input.second; return dataflow_.GetUniqueValueAt( - module_->entry_computation()->parameter_instruction(param_number), - param_index); + module_->entry_computation()->parameter_instruction( + aliased_input.parameter_number), + aliased_input.parameter_index); }; // If the value shows up in a root instruction, alias it with parameter @@ -319,7 +318,7 @@ class BufferValueMap { ComputeWhileAliasedBuffers(value, &aliased_buffers); ComputeConditionalAliasedBuffers(value, &aliased_buffers); // Uniquify aliased buffers. - std::sort(aliased_buffers.begin(), aliased_buffers.end()); + absl::c_sort(aliased_buffers); aliased_buffers.erase( std::unique(aliased_buffers.begin(), aliased_buffers.end()), aliased_buffers.end()); @@ -367,7 +366,7 @@ std::vector HloAliasAnalysis::ComputeBuffersAt( } // Sort and uniquify vector before returning. 
- std::sort(buffers.begin(), buffers.end(), HloBuffer::IdLessThan); + absl::c_sort(buffers, HloBuffer::IdLessThan); buffers.erase(std::unique(buffers.begin(), buffers.end()), buffers.end()); return buffers; @@ -430,8 +429,7 @@ Status HloAliasAnalysis::Verify() const { for (const auto& pair : value_to_buffer_) { const HloValue* value = pair.first; const HloBuffer& buffer = *pair.second; - TF_RET_CHECK(std::find(buffer.values().begin(), buffer.values().end(), - value) != buffer.values().end()); + TF_RET_CHECK(absl::c_linear_search(buffer.values(), value)); } for (HloBuffer::Id id = 0; id < buffers_.size(); ++id) { @@ -457,7 +455,7 @@ string HloAliasAnalysis::ToString() const { for (const HloComputation* computation : module_->computations()) { for (const HloInstruction* instruction : computation->instructions()) { StrAppend(&out, " ", instruction->name(), ":\n"); - if (ShapeUtil::IsTuple(instruction->shape())) { + if (instruction->shape().IsTuple()) { ShapeUtil::ForEachSubshape( instruction->shape(), [&out, &instruction, this](const Shape&, const ShapeIndex& index) { @@ -515,7 +513,7 @@ StatusOr> HloAliasAnalysis::Run( auto& value_set = buffer_map.GetValuesInBuffer(buffer_number); std::vector sorted_values(value_set.begin(), value_set.end()); - std::sort(sorted_values.begin(), sorted_values.end(), HloValue::IdLessThan); + absl::c_sort(sorted_values, HloValue::IdLessThan); alias_analysis->buffers_.emplace_back(next_id++, sorted_values); for (const HloValue* value : sorted_values) { alias_analysis->value_to_buffer_[value] = @@ -533,11 +531,11 @@ bool HloAliasAnalysis::HasLiveRangeInterference( const HloOrdering& ordering) const { for (const HloBuffer& buffer : buffers()) { CHECK(!buffer.values().empty()); - if (ShapeUtil::IsToken(buffer.values().front()->shape())) { + if (buffer.values().front()->shape().IsToken()) { // Tokens have no on-device representation and cannot interfere. 
for (const HloValue* value : buffer.values()) { // If one of the values is a token, all values must be a token. - DCHECK(ShapeUtil::IsToken(value->shape())); + DCHECK(value->shape().IsToken()); } continue; } @@ -547,16 +545,15 @@ bool HloAliasAnalysis::HasLiveRangeInterference( // tie-break using value ID. The tie-break is necessary because we need a // strict weak order for std::sort. std::vector values = buffer.values(); - std::sort(values.begin(), values.end(), - [&ordering](const HloValue* a, const HloValue* b) { - if (ordering.IsDefinedBefore(*a, *b)) { - return true; - } else if (ordering.IsDefinedBefore(*b, *a)) { - return false; - } else { - return a->id() < b->id(); - } - }); + absl::c_sort(values, [&ordering](const HloValue* a, const HloValue* b) { + if (ordering.IsDefinedBefore(*a, *b)) { + return true; + } else if (ordering.IsDefinedBefore(*b, *a)) { + return false; + } else { + return a->id() < b->id(); + } + }); // Walk through the ordered vector of values. First verify that the values // are totally ordered with respect to 'ordering', then check that no diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc index 7e6150e94153cd15463725e862ce1b8593f2c991..b6dbf07959c541bceaa8eda5a0101503970ee832 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc @@ -238,13 +238,16 @@ TEST_F(HloAliasAnalysisTest, ParametersWithAliasing) { builder.AddInstruction(HloInstruction::CreateTuple({negate0, negate1})); module_->AddEntryComputation(builder.Build()); TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias( - /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias( - /*output_index=*/{1}, 
/*param_number=*/0, /*param_index=*/{1})); + /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); // Cannot alias an output twice. ASSERT_IS_NOT_OK(module_->input_output_alias_config().SetUpAlias( - /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{0})); + /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{0}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); const HloAliasAnalysis& analysis = RunAnalysis(); @@ -279,13 +282,16 @@ TEST_F(HloAliasAnalysisTest, ParametersWithCrossAliasing) { builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1})); module_->AddEntryComputation(builder.Build()); TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias( - /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{1})); + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{1}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias( - /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{0})); + /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{0}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); // Cannot alias an output twice. 
ASSERT_IS_NOT_OK(module_->input_output_alias_config().SetUpAlias( - /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1})); + /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); const HloAliasAnalysis& analysis = RunAnalysis(); @@ -365,9 +371,11 @@ TEST_F(HloAliasAnalysisTest, InputOutputAliasingWithWhile) { builder.AddInstruction(HloInstruction::CreateTuple({negate_1, negate_2})); module_->AddEntryComputation(builder.Build()); TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias( - /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias( - /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1})); + /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); const HloAliasAnalysis& analysis = RunAnalysis(); diff --git a/tensorflow/compiler/xla/service/hlo_buffer.cc b/tensorflow/compiler/xla/service/hlo_buffer.cc index 9c3aa0e64d119c2560f4955d0bcb492519fa52a2..32e48651b30bace4723169935d1f10dd7d7bfec3 100644 --- a/tensorflow/compiler/xla/service/hlo_buffer.cc +++ b/tensorflow/compiler/xla/service/hlo_buffer.cc @@ -49,7 +49,7 @@ std::vector HloBuffer::ComputePositions() const { value->positions().end()); } // Remove duplicates and sort positions. 
- std::sort(positions.begin(), positions.end()); + absl::c_sort(positions); positions.erase(std::unique(positions.begin(), positions.end()), positions.end()); return positions; diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index ff122b529bdcdcc69d2245136e19101902dbf957..40fe91398be33f5681e1389e1b6fadcbd87487bb 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -15,8 +15,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" -#include #include +#include #include #include #include @@ -207,14 +207,14 @@ Status HloComputation::RemoveInstructionAndUnusedOperands( TF_RET_CHECK(instruction->user_count() == 0); TF_RET_CHECK(IsRemovable(instruction)) << "Cannot remove instruction: " << instruction->ToString(); - std::unordered_set removed; + absl::flat_hash_set removed; std::queue worklist; worklist.push(instruction); while (!worklist.empty()) { HloInstruction* item = worklist.front(); worklist.pop(); - if (removed.count(item) != 0 || item->user_count() != 0 || + if (removed.contains(item) || item->user_count() != 0 || item == root_instruction() || !IsRemovable(item) || (item->HasSideEffect() && item != instruction)) { continue; @@ -332,7 +332,7 @@ void HloComputation::ComputeInstructionPostOrder( dfs_stack.emplace_back(op); } - // Add inputs for send->recv_done dependencies and cross-replica-sum + // Add inputs for send->recv_done dependencies and all-reduce // dependencies. 
switch (current->opcode()) { case HloOpcode::kRecvDone: { @@ -344,7 +344,7 @@ void HloComputation::ComputeInstructionPostOrder( } break; } - case HloOpcode::kCrossReplicaSum: { + case HloOpcode::kAllReduce: { auto all_reduce_id = current->all_reduce_id(); if (all_reduce_id) { auto it = channel_dependency_map.find(all_reduce_id.value()); @@ -372,7 +372,7 @@ HloComputation::ComputeChannelDependencies() const { instruction.get()); break; } - case HloOpcode::kCrossReplicaSum: { + case HloOpcode::kAllReduce: { auto all_reduce_id = instruction->all_reduce_id(); if (all_reduce_id) { auto& dependencies = channel_dependency_map[all_reduce_id.value()]; @@ -396,6 +396,7 @@ std::vector HloComputation::MakeInstructionPostOrder() const { post_order.reserve(instruction_count()); std::vector trace_instructions; absl::flat_hash_map visited; + visited.reserve(instruction_count()); for (auto& instruction : instructions_) { if (instruction->opcode() == HloOpcode::kTrace) { // Trace instructions aren't handled by the DFS visitor. Add trace @@ -530,11 +531,10 @@ HloComputation::CreateFromProto( HloInstruction* root = instruction_map.at(proto.root_id()); // Sort the instructions in the proto id's order. 
- std::sort(instructions.begin(), instructions.end(), - [&](const std::unique_ptr& a, - const std::unique_ptr& b) { - return to_proto_id[a.get()] < to_proto_id[b.get()]; - }); + absl::c_sort(instructions, [&](const std::unique_ptr& a, + const std::unique_ptr& b) { + return to_proto_id[a.get()] < to_proto_id[b.get()]; + }); TF_RETURN_IF_ERROR([&]() -> Status { std::vector parameters_seen(parameter_count); @@ -599,7 +599,7 @@ StatusOr HloComputation::DeepCopyHelper( const std::function< HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index, HloComputation* computation)>& copy_leaf) { - if (ShapeUtil::IsTuple(instruction->shape())) { + if (instruction->shape().IsTuple()) { std::vector elements; for (int64 i = 0; i < ShapeUtil::TupleElementCount(instruction->shape()); i++) { @@ -616,14 +616,14 @@ StatusOr HloComputation::DeepCopyHelper( } return AddInstruction(HloInstruction::CreateTuple(elements)); } - if (ShapeUtil::IsToken(instruction->shape())) { + if (instruction->shape().IsToken()) { // Tokens have no on-device representation and cannot be copied. Pass // through transparently. return instruction; } // Array shape. - TF_RET_CHECK(ShapeUtil::IsArray(instruction->shape())); + TF_RET_CHECK(instruction->shape().IsArray()); return copy_leaf(instruction, *index, this); } @@ -693,25 +693,37 @@ bool HloComputation::operator==(const HloComputation& other) const { if (this == &other) { return true; } - std::set> visited; - std::function eq = - [&visited, &eq](const HloInstruction* a, const HloInstruction* b) { - // If are visited but not identical, the recursion should have - // been aborted. So, if are visited at this point, they must be - // identical. 
- if (visited.count(std::make_pair(a, b)) > 0) { - return true; - } - visited.emplace(a, b); - return a->Identical( - *b, eq, [](const HloComputation* a, const HloComputation* b) { - return *a == *b; - }); - }; - return eq(root_instruction(), other.root_instruction()); -} + absl::flat_hash_set> + visited; + std::vector> worklist; -uint64 HloComputation::Hash() const { return root_instruction()->Hash(); } + worklist.push_back({root_instruction(), other.root_instruction()}); + + while (!worklist.empty()) { + auto pair = worklist.back(); + worklist.pop_back(); + + if (visited.contains(pair)) { + continue; + } + visited.emplace(pair); + // TODO(b/123082518): Avoid recursively invoking == because it may + // cause a stack overflow with deeply nested subcomputations. + bool identical_ignoring_operands = pair.first->Identical( + *pair.second, + [](const HloInstruction*, const HloInstruction*) { return true; }, + [](const HloComputation* a, const HloComputation* b) { + return *a == *b; + }); + if (!identical_ignoring_operands) { + return false; + } + for (size_t i = 0; i < pair.first->operands().size(); ++i) { + worklist.push_back({pair.first->operand(i), pair.second->operand(i)}); + } + } + return true; +} Status HloComputation::ReplaceWithNewInstruction( HloInstruction* old_instruction, @@ -797,20 +809,19 @@ Status HloComputation::AcceptWithOperandOrder( template Status HloComputation::AcceptOrdered( DfsHloVisitorBase* visitor, - const std::vector& order) const { + absl::Span order) const { VLOG(3) << "Accepting visitor with order."; for (HloInstruction* root : CollectUnreachableRoots()) { - TF_RET_CHECK(std::find(order.begin(), order.end(), root) != order.end()) - << root->ToString(); + TF_RET_CHECK(absl::c_linear_search(order, root)) << root->ToString(); } TF_RET_CHECK(order.size() == instruction_count()); - std::unordered_set visited; + absl::flat_hash_set visited; for (const HloInstruction* instruction : order) { VLOG(3) << "Visiting ordered: " << 
instruction->ToString(); - TF_RET_CHECK(instruction_iterators_.count(instruction) == 1) + TF_RET_CHECK(instruction_iterators_.contains(instruction)) << "Instruction " << instruction->name() << " is not in computation " << name(); - TF_RET_CHECK(visited.count(instruction) == 0) + TF_RET_CHECK(!visited.contains(instruction)) << "Instruction " << instruction->name() << " appears more than once in order"; HloInstruction* mutable_instruction = @@ -827,9 +838,9 @@ Status HloComputation::AcceptOrdered( // Explicit instantiations. template Status HloComputation::AcceptOrdered( - DfsHloVisitor*, const std::vector&) const; + DfsHloVisitor*, absl::Span) const; template Status HloComputation::AcceptOrdered( - ConstDfsHloVisitor*, const std::vector&) const; + ConstDfsHloVisitor*, absl::Span) const; Status HloComputation::Accept( const std::function& visitor_func) { @@ -846,29 +857,31 @@ Status HloComputation::Accept( std::unique_ptr HloComputation::Clone( const string& suffix, HloCloneContext* context) { return CloneWithReplacements( - /*replacements=*/std::unordered_map>(), - context, suffix); + /*replacements=*/absl::flat_hash_map>(), + /*extra_parameters=*/{}, context, suffix); } std::unique_ptr HloComputation::CloneWithReplacementPairs( std::pair> r1, HloCloneContext* context, const string& suffix) { - std::unordered_map> + absl::flat_hash_map> replacements; replacements.emplace(std::move(r1)); - return CloneWithReplacements(std::move(replacements), context, suffix); + return CloneWithReplacements(std::move(replacements), /*extra_parameters=*/{}, + context, suffix); } std::unique_ptr HloComputation::CloneWithReplacementPairs( std::pair> r1, std::pair> r2, HloCloneContext* context, const string& suffix) { - std::unordered_map> + absl::flat_hash_map> replacements; replacements.emplace(std::move(r1)); replacements.emplace(std::move(r2)); - return CloneWithReplacements(std::move(replacements), context, suffix); + return CloneWithReplacements(std::move(replacements), 
/*extra_parameters=*/{}, + context, suffix); } std::unique_ptr HloComputation::CloneWithReplacementPairs( @@ -876,17 +889,19 @@ std::unique_ptr HloComputation::CloneWithReplacementPairs( std::pair> r2, std::pair> r3, HloCloneContext* context, const string& suffix) { - std::unordered_map> + absl::flat_hash_map> replacements; replacements.emplace(std::move(r1)); replacements.emplace(std::move(r2)); replacements.emplace(std::move(r3)); - return CloneWithReplacements(std::move(replacements), context, suffix); + return CloneWithReplacements(std::move(replacements), /*extra_parameters=*/{}, + context, suffix); } std::unique_ptr HloComputation::CloneWithReplacements( - std::unordered_map> + absl::flat_hash_map> replacements, + absl::Span extra_parameters, HloCloneContext* context, const string& suffix) { std::unique_ptr context_ptr; if (context == nullptr) { @@ -952,6 +967,12 @@ std::unique_ptr HloComputation::CloneWithReplacements( } std::vector> instructions; + // First add the extra parameters to 'instructions'. + for (const auto& instr : extra_parameters) { + CHECK_EQ(instr->opcode(), HloOpcode::kParameter) + << "Only parameter instructions are allowed in 'extra_parameters'"; + instructions.emplace_back(instr->Clone()); + } for (auto instr : postorder) { std::vector new_operands; for (auto operand : instr->operands()) { diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index c584e4c7ca5770533f28352b0df9dadd9dbe1860..0cb9caddd089011f3e9a4473995847dc966dd402 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -20,7 +20,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -264,12 +263,6 @@ class HloComputation { // Return whether `*this` and `other` are functionally equivalent. bool operator==(const HloComputation& other) const; - // Generates a hash value of an HLO computation. 
Hash considers - // information on opcode, shape, operands, and typically a root instruction. - // This function returns the same hash value for equivalent HLO computations, - // with respect to HloInstruction::Identical() method. - uint64 Hash() const; - // Replaces old instruction with newly created instruction. Removes old // instruction from computation. Updates uses and root instruction. Status ReplaceWithNewInstruction( @@ -307,7 +300,7 @@ class HloComputation { // be a topological sort of all instructions in the computation. template Status AcceptOrdered(DfsHloVisitorBase* visitor, - const std::vector& order) const; + absl::Span order) const; // Same as Accept() above, but the visitor is given as a function. Status Accept(const std::function& visitor_func); @@ -329,11 +322,16 @@ class HloComputation { // that's not already in the computation, it's cloned and added to the new // computation. // + // 'extra_parameters' allows to specify additional parameters that should be + // added to the computation. + // // All relevant instructions are cloned, *including* unique_ptr in the // `replacements` map. std::unique_ptr CloneWithReplacements( - std::unordered_map> + absl::flat_hash_map> replacements, + absl::Span extra_parameters = {}, HloCloneContext* context = nullptr, const string& suffix = "clone"); // Convenience overloads for CloneWithReplacements. You want to do @@ -373,7 +371,7 @@ class HloComputation { // Returns a map from channel-id to directed dependencies of the channel // instructions. For send&recv pairs it means the send instruction and for - // cross-replica-sum the union of the dependencies for all participating + // all-reduce the union of the dependencies for all participating // instructions. 
using ChannelDependencyMap = absl::flat_hash_map>; diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc index 8b50cfa9aed90091cfbedc1df902440ec9bf2a80..3b88e9745c27d6e1f2a46e5c83ac2e8bd8d05150 100644 --- a/tensorflow/compiler/xla/service/hlo_computation_test.cc +++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc @@ -15,24 +15,28 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include #include +#include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -namespace op = xla::testing::opcode_matchers; - namespace xla { namespace { +namespace m = match; using ::testing::ElementsAre; using ::testing::UnorderedElementsAre; @@ -226,7 +230,7 @@ TEST_F(HloComputationTest, VisitWithMultipleRoots) { : computation_(computation) {} Status DefaultAction(HloInstruction* hlo_instruction) override { - EXPECT_EQ(0, visited_set_.count(hlo_instruction)); + EXPECT_FALSE(visited_set_.contains(hlo_instruction)); visited_set_.insert(hlo_instruction); last_visited_ = hlo_instruction; return Status::OK(); @@ -239,7 +243,7 @@ TEST_F(HloComputationTest, VisitWithMultipleRoots) { } HloComputation* computation_; - std::set visited_set_; + absl::flat_hash_set visited_set_; int64 finish_visit_calls_ = 0; HloInstruction* 
last_visited_ = nullptr; }; @@ -261,7 +265,7 @@ TEST_F(HloComputationTest, DeepCopyArray) { auto computation = module->AddEntryComputation(builder.Build()); auto copy = computation->DeepCopyInstruction(constant).ValueOrDie(); - EXPECT_THAT(copy, op::Copy(constant)); + EXPECT_THAT(copy, GmockMatch(m::Copy(m::Op().Is(constant)))); } TEST_F(HloComputationTest, DeepCopyTuple) { @@ -278,8 +282,9 @@ TEST_F(HloComputationTest, DeepCopyTuple) { auto computation = module->AddEntryComputation(builder.Build()); auto tuple_copy = computation->DeepCopyInstruction(tuple).ValueOrDie(); - EXPECT_THAT(tuple_copy, op::Tuple(op::Copy(op::GetTupleElement(tuple)), - op::Copy(op::GetTupleElement(tuple)))); + EXPECT_THAT(tuple_copy, GmockMatch(m::Tuple( + m::Copy(m::GetTupleElement(m::Op().Is(tuple))), + m::Copy(m::GetTupleElement(m::Op().Is(tuple)))))); EXPECT_EQ(0, tuple_copy->operand(0)->operand(0)->tuple_index()); EXPECT_EQ(1, tuple_copy->operand(1)->operand(0)->tuple_index()); } @@ -297,7 +302,7 @@ TEST_F(HloComputationTest, DeepCopyArrayAtIndices) { ShapeTree indices_to_copy(constant->shape(), /*init_value=*/true); EXPECT_THAT(computation->DeepCopyInstruction(constant, &indices_to_copy) .ValueOrDie(), - op::Copy(constant)); + GmockMatch(m::Copy(m::Op().Is(constant)))); } { @@ -330,10 +335,11 @@ TEST_F(HloComputationTest, DeepCopyTupleAtIndices) { computation->DeepCopyInstruction(tuple, &indices_to_copy, &copies_added) .ValueOrDie(); - EXPECT_THAT(deep_copy, op::Tuple(op::Copy(op::GetTupleElement(tuple)), - op::Copy(op::GetTupleElement(tuple)))); - EXPECT_THAT(deep_copy, op::Tuple(copies_added.element({0}), - copies_added.element({1}))); + EXPECT_THAT(deep_copy, GmockMatch(m::Tuple( + m::Copy(m::GetTupleElement(m::Op().Is(tuple))) + .Is(copies_added.element({0})), + m::Copy(m::GetTupleElement(m::Op().Is(tuple))) + .Is(copies_added.element({1}))))); } { @@ -346,8 +352,9 @@ TEST_F(HloComputationTest, DeepCopyTupleAtIndices) { computation->DeepCopyInstruction(tuple, &indices_to_copy, 
&copies_added) .ValueOrDie(); - EXPECT_THAT(deep_copy, op::Tuple(op::GetTupleElement(tuple), - op::GetTupleElement(tuple))); + EXPECT_THAT(deep_copy, + GmockMatch(m::Tuple(m::GetTupleElement(m::Op().Is(tuple)), + m::GetTupleElement(m::Op().Is(tuple))))); EXPECT_TRUE(copies_added.element({}) == nullptr); EXPECT_TRUE(copies_added.element({0}) == nullptr); EXPECT_TRUE(copies_added.element({1}) == nullptr); @@ -363,8 +370,9 @@ TEST_F(HloComputationTest, DeepCopyTupleAtIndices) { computation->DeepCopyInstruction(tuple, &indices_to_copy, &copies_added) .ValueOrDie(); - EXPECT_THAT(deep_copy, op::Tuple(op::Copy(op::GetTupleElement(tuple)), - op::GetTupleElement(tuple))); + EXPECT_THAT(deep_copy, GmockMatch(m::Tuple( + m::Copy(m::GetTupleElement(m::Op().Is(tuple))), + m::GetTupleElement(m::Op().Is(tuple))))); EXPECT_TRUE(copies_added.element({}) == nullptr); EXPECT_TRUE(copies_added.element({0}) != nullptr); EXPECT_TRUE(copies_added.element({1}) == nullptr); @@ -381,7 +389,7 @@ TEST_F(HloComputationTest, DeepCopyToken) { auto copy = computation->DeepCopyInstruction(token).ValueOrDie(); // No copy should be added. - EXPECT_THAT(copy, op::AfterAll()); + EXPECT_THAT(copy, GmockMatch(m::AfterAll())); } TEST_F(HloComputationTest, DeepCopyTokenTuple) { @@ -399,8 +407,9 @@ TEST_F(HloComputationTest, DeepCopyTokenTuple) { // Only the array (second tuple element) should be copied. The token is passed // through transparently. 
- EXPECT_THAT(copy, op::Tuple(op::GetTupleElement(tuple), - op::Copy(op::GetTupleElement(tuple)))); + EXPECT_THAT(copy, GmockMatch(m::Tuple( + m::GetTupleElement(m::Op().Is(tuple)), + m::Copy(m::GetTupleElement(m::Op().Is(tuple)))))); } TEST_F(HloComputationTest, CycleDetection) { @@ -443,13 +452,15 @@ TEST_F(HloComputationTest, RemoveInstructionWithDuplicateOperand) { auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(4, computation->instruction_count()); - EXPECT_THAT(computation->root_instruction(), op::Negate(constant)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Negate(m::Op().Is(constant)))); EXPECT_EQ(negate, computation->root_instruction()); ASSERT_IS_OK(computation->RemoveInstructionAndUnusedOperands(dead_add)); EXPECT_EQ(2, computation->instruction_count()); - EXPECT_THAT(computation->root_instruction(), op::Negate(constant)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Negate(m::Op().Is(constant)))); EXPECT_EQ(negate, computation->root_instruction()); } @@ -484,6 +495,41 @@ TEST_F(HloComputationTest, CloneWithControlDependency) { EXPECT_THAT(successors, ::testing::ElementsAre(cloned_add)); } +TEST_F(HloComputationTest, CloneWithReplacements) { + auto builder = HloComputation::Builder(TestName()); + Shape r0s64 = ShapeUtil::MakeShape(S64, {}); + Shape r0s32 = ShapeUtil::MakeShape(S32, {}); + Shape r0u32 = ShapeUtil::MakeShape(U32, {}); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32_, "p.0.lhs")); + auto param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, r0f32_, "p.0.rhs")); + auto param2 = + builder.AddInstruction(HloInstruction::CreateParameter(2, r0s64, "p.1")); + auto lt = builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, param0, param1)); + auto module = CreateNewVerifiedModule(); + auto computation = + 
module->AddEntryComputation(builder.Build(/*root_instruction=*/lt)); + absl::flat_hash_map> + replacements; + replacements.emplace(param2, + HloInstruction::CreateParameter(2, r0s32, "p.1")); + auto param3 = HloInstruction::CreateParameter(3, r0u32, "p.2"); + std::vector extra_parameters{param3.get()}; + auto clone = computation->CloneWithReplacements(std::move(replacements), + extra_parameters); + ASSERT_EQ(clone->num_parameters(), 4); + EXPECT_TRUE( + ShapeUtil::Equal(clone->parameter_instruction(0)->shape(), r0f32_)); + EXPECT_TRUE( + ShapeUtil::Equal(clone->parameter_instruction(1)->shape(), r0f32_)); + EXPECT_TRUE( + ShapeUtil::Equal(clone->parameter_instruction(2)->shape(), r0s32)); + EXPECT_TRUE( + ShapeUtil::Equal(clone->parameter_instruction(3)->shape(), r0u32)); +} + TEST_F(HloComputationTest, Stringification) { const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10}); const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10}); @@ -599,5 +645,28 @@ TEST_F(HloComputationTest, StringificationCanonical) { EXPECT_EQ(computation->ToString(options), expected_computation2); } +std::unique_ptr MakeAddNComputation(int n) { + auto builder = HloComputation::Builder("add_n"); + auto result = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {}), "x_value")); + auto one = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + for (int i = 0; i < n; ++i) { + result = builder.AddInstruction(HloInstruction::CreateBinary( + one->shape(), HloOpcode::kAdd, result, one)); + } + return builder.Build(); +} + +TEST_F(HloComputationTest, DeepEquality) { + auto computation_a = MakeAddNComputation(200000); + auto computation_b = MakeAddNComputation(200000); + EXPECT_TRUE(*computation_a == *computation_b); + + auto computation_c = MakeAddNComputation(199999); + EXPECT_FALSE(*computation_a == *computation_c); + EXPECT_FALSE(*computation_c == *computation_b); +} + } // namespace } // namespace xla diff --git 
a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc index 5e37883d3d8d5067bab873ac6b5f732e7360c5fa..e7ed858e8c5af83d08863d64a0aba162c75ed5cb 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc @@ -35,6 +35,34 @@ limitations under the License. namespace xla { +// Checks whether instr is or transitively contains an instruction that we +// shouldn't fold. +// +// Specifically, we don't fold kRng or kAfterAll instructions: +// +// - kRng is already marked as side-effecting and so is skipped elsewhere, but +// we check for it here. Even if kRng weren't side-effecting and took an +// explicit seed, we *still* wouldn't want to constant-fold it, because the +// evaluator's handling of rng is not guaranteed to be identical to any +// particular backend's rng. +// +// - kAfterAll needs to be skipped because a kAfterAll op with no args can +// currently materialize a token "out of thin air". TODO(b/110532604): +// Remove this check once AfterAll requires at least one operand, in which +// case constant folding will be impossible. +static bool IsOrContainsIllegalInstr(const HloInstruction* instr) { + if (instr->opcode() == HloOpcode::kAfterAll || + instr->opcode() == HloOpcode::kRng) { + return true; + } + for (const HloComputation* c : instr->called_computations()) { + if (absl::c_any_of(c->instructions(), IsOrContainsIllegalInstr)) { + return true; + } + } + return false; +} + StatusOr HloConstantFolding::Run(HloModule* module) { // Limit the constant folding to 0 iterations to skip folding loops. This // retains the behavior from before while loop support in HloEvaluator and may @@ -52,25 +80,24 @@ StatusOr HloConstantFolding::Run(HloModule* module) { computation->root_instruction() != instruction) { continue; } - // Skip Constant, Parameter, Tuple, AfterAll operation. 
- // Tuple constants are not directly supported by any backends, hence - // folding Tuple is not useful and would in fact be expanded back into - // kTuple by Algebraic Simplifier. - // TODO(b/110532604): Enable AfterAll once AfterAll requires at least one - // operand in which case constant folding will be impossible and this - // special case is not necessary. - if (instruction->opcode() == HloOpcode::kParameter || - instruction->opcode() == HloOpcode::kConstant || - instruction->opcode() == HloOpcode::kTuple || - instruction->opcode() == HloOpcode::kAfterAll) { - continue; - } // Skip instructions with non-constant operands. if (!hlo_query::AllOperandsAreConstants(*instruction)) { continue; } + // Don't fold Constant, Parameter, and Tuple instructions. Tuple + // constants are not directly supported by any backends, hence folding + // Tuple is not useful and would in fact be expanded back into kTuple by + // Algebraic Simplifier. + // + // (We do allow folding subcomputations that contain these instructions.) + if (instruction->opcode() == HloOpcode::kParameter || + instruction->opcode() == HloOpcode::kConstant || + instruction->opcode() == HloOpcode::kTuple) { + continue; + } + // Broadcasts dramatically increase the size of constants, which is often // detrimental to performance and memory capacity, so do not fold // broadcasts. @@ -79,12 +106,23 @@ StatusOr HloConstantFolding::Run(HloModule* module) { continue; } + // Check for instructions that we can't fold even if they appear inside of + // a subcomputation (e.g. a kCall). + if (IsOrContainsIllegalInstr(instruction)) { + continue; + } + + // Don't constant-fold side-effecting instructions or instructions which + // contain side-effecting instructions. + if (instruction->HasSideEffect()) { + continue; + } + // Don't constant fold unless it's a net positive or the output is small. 
- if (ShapeUtil::IsArray(instruction->shape())) { + if (instruction->shape().IsArray()) { int64 elements_in_removed_operands = 0; for (HloInstruction* operand : instruction->operands()) { - if (operand->user_count() == 1 && - ShapeUtil::IsArray(operand->shape())) { + if (operand->user_count() == 1 && operand->shape().IsArray()) { elements_in_removed_operands += ShapeUtil::ElementsIn(operand->shape()); } diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc index 4f81dc94e577a63c09ae4019e5e8158252c712ce..4bdc980c9ac4fb79cde0242f407aea7057474b27 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc @@ -252,7 +252,7 @@ const char* const kConstantFoldLargePad = R"( HloModule ConstantFoldLargePad ENTRY r { - a = f32[1,1,1] constant(f32[1,1,1]{{{7}}}) + a = f32[1,1,1] constant({{{7}}}) b = f32[] constant(42) ROOT pad = f32[2048,2048,128] pad(a, b), padding=1024_1023x1024_1023x64_63 })"; @@ -268,5 +268,51 @@ TEST_F(HloConstantFoldingTest, DoesNotFoldLargePad) { GmockMatch(m::Pad(m::Constant(), m::Constant()))); } +TEST_F(HloConstantFoldingTest, DontFoldSubcomputationContainingAfterAll) { + const char* const kModuleStr = R"( + HloModule test + + Fn { + tok = token[] after-all() + ROOT root = f32[10] iota(), iota_dimension=0 + } + + ENTRY entry { + ROOT call = f32[10] call(), to_apply=Fn + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + HloConstantFolding constant_folding; + TF_ASSERT_OK_AND_ASSIGN(bool result, + RunHloPass(&constant_folding, module.get())); + EXPECT_FALSE(result); +} + +TEST_F(HloConstantFoldingTest, + DontFoldSubcomputationTransitivelyContainingRng) { + const char* const kModuleStr = R"( + HloModule test + + InnerFn { + c0 = f32[] constant(0) + c1 = f32[] constant(1) + ROOT rng = f32[10] rng(c0, c1), distribution=rng_uniform + } + + Fn { + ROOT 
fusion = f32[10] fusion(), kind=kLoop, calls=InnerFn + } + + ENTRY entry { + ROOT call = f32[10] call(), to_apply=Fn + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + HloConstantFolding constant_folding; + TF_ASSERT_OK_AND_ASSIGN(bool result, + RunHloPass(&constant_folding, module.get())); + EXPECT_FALSE(result); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index df7d3826dbad1f264a5dc53312c062900155b0f6..1ee958114ebfa976cea72e901432575b7dc58321 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -237,24 +237,17 @@ Status HloCostAnalysis::HandleDomain(const HloInstruction* domain) { Status HloCostAnalysis::HandleDot(const HloInstruction* dot) { const Shape& lhs_shape = dot->operand(0)->shape(); - const Shape& rhs_shape = dot->operand(1)->shape(); + const Shape& dot_shape = dot->shape(); const DotDimensionNumbers& dnums = dot->dot_dimension_numbers(); // Count of elements along the reduction dimension (last dimension for the // rhs). - int64 reduction_width = - lhs_shape.dimensions(dnums.lhs_contracting_dimensions(0)); - // First divide by reduction width before multiplying by rhs elements to avoid - // overflow. - int64 fma_count; - if (reduction_width == 0) { - fma_count = 0; - } else { - fma_count = (ShapeUtil::ElementsIn(lhs_shape) / reduction_width) * - ShapeUtil::ElementsIn(rhs_shape); + int64 reduction_width = 1; + for (auto dim : dnums.lhs_contracting_dimensions()) { + reduction_width *= lhs_shape.dimensions(dim); } - - // We count an FMA operation as 2 floating point operations. - current_properties_[kFlopsKey] = kFmaFlops * fma_count; + // Each output element requires reduction_width FMA operations. 
+ current_properties_[kFlopsKey] = + kFmaFlops * ShapeUtil::ElementsIn(dot_shape) * reduction_width; return Status::OK(); } @@ -292,7 +285,7 @@ Status HloCostAnalysis::HandleReduce(const HloInstruction* reduce) { // does not need to be multiplied by the number of input tensors - that's // already "priced in" by the sub-computation doing more work. auto arg = reduce->operand(0); - auto output_shape = ShapeUtil::IsArray(reduce->shape()) + auto output_shape = reduce->shape().IsArray() ? reduce->shape() : reduce->shape().tuple_shapes(0); int64 reduction_count = @@ -531,7 +524,8 @@ Status HloCostAnalysis::HandleConvolution(const HloInstruction* convolution) { } const int64 fma_count = (input_feature / convolution->feature_group_count()) * - output_feature * batch * + output_feature * + (batch / convolution->batch_group_count()) * Product(valid_position_counts); current_properties_[kFlopsKey] = fma_count * kFmaFlops; return Status::OK(); @@ -539,7 +533,7 @@ Status HloCostAnalysis::HandleConvolution(const HloInstruction* convolution) { Status HloCostAnalysis::HandleFft(const HloInstruction* fft) { auto real_shape = - ShapeUtil::IsTuple(fft->operand(0)->shape()) + fft->operand(0)->shape().IsTuple() ? ShapeUtil::GetTupleElementShape(fft->operand(0)->shape(), 0) : fft->operand(0)->shape(); constexpr int kFmaPerComplexMul = 4; @@ -552,7 +546,7 @@ Status HloCostAnalysis::HandleFft(const HloInstruction* fft) { return Status::OK(); } -Status HloCostAnalysis::HandleCrossReplicaSum(const HloInstruction* crs) { +Status HloCostAnalysis::HandleAllReduce(const HloInstruction* crs) { // We assume 2 replicas, so that each output element is the sum of two input // elements. 
// @@ -561,7 +555,7 @@ Status HloCostAnalysis::HandleCrossReplicaSum(const HloInstruction* crs) { double flops = 0.0; ShapeUtil::ForEachSubshape(crs->shape(), [&](const Shape& subshape, const ShapeIndex&) { - if (ShapeUtil::IsArray(subshape)) { + if (subshape.IsArray()) { flops += ShapeUtil::ElementsIn(subshape); } }); @@ -577,6 +571,10 @@ Status HloCostAnalysis::HandleCollectivePermute(const HloInstruction* /*hlo*/) { return Status::OK(); } +Status HloCostAnalysis::HandleReplicaId(const HloInstruction* /*hlo*/) { + return Status::OK(); +} + Status HloCostAnalysis::HandleRng(const HloInstruction* random) { // TODO(b/26346211): Implement better estimates for the RNG cost, since the // cost changes with the implementation and the distribution. For now, assume diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h index 33983119c9b00a248c0e8dcc5815c6367192dca3..421786e20a3d9528ea76a44b3087ab2aed81d2b5 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -71,9 +71,10 @@ class HloCostAnalysis : public ConstDfsHloVisitor { Status HandleDot(const HloInstruction* dot) override; Status HandleConvolution(const HloInstruction* convolution) override; Status HandleFft(const HloInstruction* fft) override; - Status HandleCrossReplicaSum(const HloInstruction* crs) override; + Status HandleAllReduce(const HloInstruction* crs) override; Status HandleAllToAll(const HloInstruction* hlo) override; Status HandleCollectivePermute(const HloInstruction* hlo) override; + Status HandleReplicaId(const HloInstruction* hlo) override; Status HandleInfeed(const HloInstruction* infeed) override; Status HandleOutfeed(const HloInstruction* outfeed) override; Status HandleRng(const HloInstruction* random) override; diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc index 
ff32faf298dd1f04c5b769f2a88f76a7a1e18ae7..4d42770ba784ba15fae9518b40a75d8a2f038e66 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/service.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/compiler/xla/statusor.h" @@ -157,6 +158,87 @@ TEST_F(HloCostAnalysisTest, MatrixMultiply) { sizeof(float) * (10 * 5 + 5 * 30 + 10 * 30)); } +TEST_F(HloCostAnalysisTest, DotGeneral) { + XlaBuilder builder("matrix_multiply"); + auto lhs = + Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {10, 5, 5}), "lhs"); + auto rhs = + Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {5, 5, 30}), "rhs"); + DotDimensionNumbers dnums; + dnums.add_lhs_contracting_dimensions(1); + dnums.add_lhs_contracting_dimensions(2); + dnums.add_rhs_contracting_dimensions(0); + dnums.add_rhs_contracting_dimensions(1); + DotGeneral(lhs, rhs, dnums); + + // Run HLO cost analysis. + auto hlo_module = BuildHloGraph(&builder); + HloCostAnalysis analysis(ShapeSize); + ASSERT_IS_OK( + hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); + + // Check the number of computations returned from the analysis (7500 FMAs). + EXPECT_EQ(analysis.flop_count(), 2 * 10 * 30 * 5 * 5); + + EXPECT_EQ(analysis.transcendental_count(), 0); + + // Bytes accessed is sum of inputs and output. 
+ EXPECT_EQ(analysis.bytes_accessed(), + sizeof(float) * (10 * 5 * 5 + 5 * 5 * 30 + 10 * 30)); +} + +TEST_F(HloCostAnalysisTest, DotGeneral2) { + XlaBuilder builder("matrix_multiply"); + auto lhs = + Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {10, 5, 5}), "lhs"); + auto rhs = + Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {5, 5, 30}), "rhs"); + DotDimensionNumbers dnums; + dnums.add_lhs_contracting_dimensions(1); + dnums.add_lhs_batch_dimensions(2); + dnums.add_rhs_contracting_dimensions(0); + dnums.add_rhs_batch_dimensions(1); + DotGeneral(lhs, rhs, dnums); + + // Run HLO cost analysis. + auto hlo_module = BuildHloGraph(&builder); + HloCostAnalysis analysis(ShapeSize); + ASSERT_IS_OK( + hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); + + // Check the number of computations returned from the analysis (7500 FMAs). + EXPECT_EQ(analysis.flop_count(), 2 * 10 * 30 * 5 * 5); + + EXPECT_EQ(analysis.transcendental_count(), 0); + + // Bytes accessed is sum of inputs and output. + EXPECT_EQ(analysis.bytes_accessed(), + sizeof(float) * (10 * 5 * 5 + 5 * 5 * 30 + 5 * 10 * 30)); +} + +TEST_F(HloCostAnalysisTest, DotGeneral3) { + XlaBuilder builder("matrix_multiply"); + auto lhs = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {10, 5}), "lhs"); + auto rhs = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {5, 30}), "rhs"); + DotDimensionNumbers dnums; + DotGeneral(lhs, rhs, dnums); + + // Run HLO cost analysis. + auto hlo_module = BuildHloGraph(&builder); + HloCostAnalysis analysis(ShapeSize); + ASSERT_IS_OK( + hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); + + // Check the number of computations returned from the analysis (7500 FMAs). + EXPECT_EQ(analysis.flop_count(), 2 * 10 * 30 * 5 * 5); + + EXPECT_EQ(analysis.transcendental_count(), 0); + + // Bytes accessed is sum of inputs and output. 
+ EXPECT_EQ(analysis.bytes_accessed(), + sizeof(float) * (10 * 5 + 5 * 30 + 5 * 5 * 10 * 30)); +} + TEST_F(HloCostAnalysisTest, Map) { XlaBuilder builder("map"); auto input = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {10}), "in"); @@ -529,7 +611,8 @@ TEST_F(HloCostAnalysisTest, DynamicSlice) { // Test the analysis on a slice. XlaBuilder builder("dynamic-slice"); auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2}), "x"); - DynamicSlice(x, ConstantR1(&builder, {1}), {1}); + DynamicSlice(x, absl::Span({ConstantR0(&builder, 1)}), + {1}); auto hlo_module = BuildHloGraph(&builder); // Run HLO cost analysis. @@ -545,7 +628,7 @@ TEST_F(HloCostAnalysisTest, DynamicUpdateSlice) { XlaBuilder builder("dynamic-update-slice"); auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2}), "x"); DynamicUpdateSlice(x, ConstantR1(&builder, {1.0}), - ConstantR1(&builder, {1})); + absl::Span({ConstantR0(&builder, 1)})); auto hlo_module = BuildHloGraph(&builder); // Run HLO cost analysis. diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc index b2005d3c210d4ae7e3702cb9624c3ad98056984c..d56f673455f9129b72e9d85eaf8cbf03cfee4302 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/util.h" @@ -69,11 +70,11 @@ StatusOr MakeConvolveHlo( CHECK_EQ(computation, rhs->parent()); TF_ASSIGN_OR_RETURN(Shape convolve_shape, ShapeInference::InferConvolveShape( - lhs->shape(), rhs->shape(), feature_group_count, + lhs->shape(), rhs->shape(), feature_group_count, 1, window, dimension_numbers)); return computation->AddInstruction(HloInstruction::CreateConvolve( - convolve_shape, lhs, rhs, feature_group_count, window, dimension_numbers, - precision_config)); + convolve_shape, lhs, rhs, feature_group_count, 1, window, + dimension_numbers, precision_config)); } StatusOr MakeTransposeHlo(HloInstruction* operand, @@ -105,12 +106,26 @@ StatusOr MakeDynamicSliceHlo( absl::Span slice_sizes) { HloComputation* computation = operand->parent(); CHECK_EQ(computation, start_indices->parent()); + int64 rank = start_indices->shape().dimensions(0); + std::vector scalar_start_indices; + for (int i = 0; i < rank; ++i) { + // TODO(b/118437727): Update callers to provide scalars directly. 
+ auto slice = computation->AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(start_indices->shape().element_type(), {1}), + start_indices, {i}, {i + 1}, {1})); + scalar_start_indices.push_back( + computation->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(start_indices->shape().element_type(), {}), + slice))); + } + std::vector scalar_start_indices_shapes( + rank, ShapeUtil::MakeShape(start_indices->shape().element_type(), {})); TF_ASSIGN_OR_RETURN( Shape dynamic_slice_shape, ShapeInference::InferDynamicSliceShape( - operand->shape(), start_indices->shape(), slice_sizes)); + operand->shape(), scalar_start_indices_shapes, slice_sizes)); return computation->AddInstruction(HloInstruction::CreateDynamicSlice( - dynamic_slice_shape, operand, start_indices, slice_sizes)); + dynamic_slice_shape, operand, scalar_start_indices, slice_sizes)); } StatusOr MakeDynamicUpdateSliceHlo( @@ -119,17 +134,31 @@ StatusOr MakeDynamicUpdateSliceHlo( HloComputation* computation = operand->parent(); CHECK_EQ(computation, update->parent()); CHECK_EQ(computation, start_indices->parent()); + int64 rank = start_indices->shape().dimensions(0); + std::vector scalar_start_indices; + for (int i = 0; i < rank; ++i) { + // TODO(b/118437727): Update callers to provide scalars directly. 
+ auto slice = computation->AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(start_indices->shape().element_type(), {1}), + start_indices, {i}, {i + 1}, {1})); + scalar_start_indices.push_back( + computation->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(start_indices->shape().element_type(), {}), + slice))); + } + std::vector scalar_start_indices_shapes( + rank, ShapeUtil::MakeShape(start_indices->shape().element_type(), {})); TF_ASSIGN_OR_RETURN( Shape dynamic_update_slice_shape, ShapeInference::InferDynamicUpdateSliceShape( - operand->shape(), update->shape(), start_indices->shape())); + operand->shape(), update->shape(), scalar_start_indices_shapes)); return computation->AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - dynamic_update_slice_shape, operand, update, start_indices)); + dynamic_update_slice_shape, operand, update, scalar_start_indices)); } -StatusOr MakeBroadcastHlo( - HloInstruction* operand, absl::Span broadcast_dimensions, - absl::Span result_shape_bounds) { +HloInstruction* MakeBroadcastHlo(HloInstruction* operand, + absl::Span broadcast_dimensions, + absl::Span result_shape_bounds) { HloComputation* computation = operand->parent(); Shape broadcast_shape = ShapeUtil::MakeShape(operand->shape().element_type(), result_shape_bounds); @@ -189,8 +218,7 @@ StatusOr MakeMapHlo(absl::Span operands, for (const HloInstruction* operand : operands) { CHECK_EQ(computation, operand->parent()); operand_shapes.push_back(&operand->shape()); - max_operand_rank = - std::max(max_operand_rank, ShapeUtil::Rank(operand->shape())); + max_operand_rank = std::max(max_operand_rank, operand->shape().rank()); } std::vector map_dims(max_operand_rank); std::iota(map_dims.begin(), map_dims.end(), 0); @@ -207,7 +235,7 @@ StatusOr MakeReduceHlo(HloInstruction* operand, HloOpcode binary_opcode, HloModule* module) { DCHECK_NE(nullptr, module); - std::vector all_dims(ShapeUtil::Rank(operand->shape())); + std::vector 
all_dims(operand->shape().rank()); std::iota(all_dims.begin(), all_dims.end(), 0); auto scalar_shape = ShapeUtil::MakeShape(operand->shape().element_type(), {}); @@ -366,9 +394,9 @@ StatusOr PadVectorWithZeros(HloInstruction* operand, return MakePadHlo(operand, zero, padding_config); } -StatusOr BroadcastZeros( - HloComputation* computation, PrimitiveType element_type, - absl::Span broadcast_dimensions) { +HloInstruction* BroadcastZeros(HloComputation* computation, + PrimitiveType element_type, + absl::Span broadcast_dimensions) { HloInstruction* zero = computation->AddInstruction( HloInstruction::CreateConstant(LiteralUtil::Zero(element_type))); return MakeBroadcastHlo(zero, /*broadcast_dimensions=*/{}, diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.h b/tensorflow/compiler/xla/service/hlo_creation_utils.h index 8e5ddbbd503a501bd493aec43a2ccd4db883ef0c..1c3174e9c89c16cb11589e7c0235bdf13eae6b85 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.h +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.h @@ -82,9 +82,9 @@ StatusOr MakeDynamicUpdateSliceHlo( // Creates a broadcast HLO instruction and adds it to the computation containing // `operand`. -StatusOr MakeBroadcastHlo( - HloInstruction* operand, absl::Span broadcast_dimensions, - absl::Span result_shape_bounds); +HloInstruction* MakeBroadcastHlo(HloInstruction* operand, + absl::Span broadcast_dimensions, + absl::Span result_shape_bounds); // Creates a GetTupleElement HLO instruction and adds it to the computation // containing `operand`. @@ -198,9 +198,9 @@ StatusOr PadVectorWithZeros(HloInstruction* operand, // Broadcasts a zero value of type `element_type` into a tensor with element // type `element_type` and dimension bounds `broadcast_dimensions`. The // broadcast instruction is emitted into `computation`. 
-StatusOr BroadcastZeros( - HloComputation* computation, PrimitiveType element_type, - absl::Span broadcast_dimensions); +HloInstruction* BroadcastZeros(HloComputation* computation, + PrimitiveType element_type, + absl::Span broadcast_dimensions); // Creates a HLO computation that takes arguments of type `domain` and produces // a value of type `range`. diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc index aaa9ec60eb3c4e0159ed40b37d772e0973d306ec..6025e6a77941369f75ebaa98bdf0979669b3a03c 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc @@ -56,9 +56,9 @@ TEST_F(HloCreationUtilsTest, CollapseFirst1Dim) { entry_computation->set_root_instruction(first_1_dims_collapsed); HloEvaluator evaluator; - TF_ASSERT_OK_AND_ASSIGN(Literal result_literal, - evaluator.Evaluate( - *module, {LiteralUtil::CreateR1({3, 4})})); + TF_ASSERT_OK_AND_ASSIGN( + Literal result_literal, + evaluator.Evaluate(*module, {LiteralUtil::CreateR1({3, 4})})); CHECK_EQ(result_literal, LiteralUtil::CreateR1({3, 4})); } @@ -77,10 +77,9 @@ TEST_F(HloCreationUtilsTest, CollapseFirst2Dims) { HloEvaluator evaluator; TF_ASSERT_OK_AND_ASSIGN( Literal result_literal, - evaluator.Evaluate( - *module, - {LiteralUtil::CreateR3( - {{{1, 2}, {3, 4}, {5, 6}}, {{-1, -2}, {-3, -4}, {-5, -6}}})})); + evaluator.Evaluate(*module, {LiteralUtil::CreateR3( + {{{1, 2}, {3, 4}, {5, 6}}, + {{-1, -2}, {-3, -4}, {-5, -6}}})})); CHECK_EQ(result_literal, LiteralUtil::CreateR2( {{1, 2}, {3, 4}, {5, 6}, {-1, -2}, {-3, -4}, {-5, -6}})); @@ -101,8 +100,7 @@ TEST_F(HloCreationUtilsTest, Prepend1DegenerateDim) { HloEvaluator evaluator; TF_ASSERT_OK_AND_ASSIGN( Literal result_literal, - evaluator.Evaluate(*module, - {LiteralUtil::CreateR1({9, 10})})); + evaluator.Evaluate(*module, {LiteralUtil::CreateR1({9, 10})})); CHECK_EQ(result_literal, LiteralUtil::CreateR2({{9, 
10}})); } @@ -121,8 +119,7 @@ TEST_F(HloCreationUtilsTest, Prepend2DegenerateDims) { HloEvaluator evaluator; TF_ASSERT_OK_AND_ASSIGN( Literal result_literal, - evaluator.Evaluate(*module, - {LiteralUtil::CreateR1({9, 10})})); + evaluator.Evaluate(*module, {LiteralUtil::CreateR1({9, 10})})); CHECK_EQ(result_literal, LiteralUtil::CreateR3({{{9, 10}}})); } @@ -141,7 +138,7 @@ TEST_F(HloCreationUtilsTest, Prepend2DegenerateDimsToScalar) { HloEvaluator evaluator; TF_ASSERT_OK_AND_ASSIGN( Literal result_literal, - evaluator.Evaluate(*module, {LiteralUtil::CreateR0(9)})); + evaluator.Evaluate(*module, {LiteralUtil::CreateR0(9)})); CHECK_EQ(result_literal, LiteralUtil::CreateR2({{9}})); } @@ -160,8 +157,8 @@ TEST_F(HloCreationUtilsTest, ExpandFirstDimInto3Dims) { HloEvaluator evaluator; TF_ASSERT_OK_AND_ASSIGN( Literal result_literal, - evaluator.Evaluate( - *module, {LiteralUtil::CreateR1({1, 2, 3, 4, 5, 6})})); + evaluator.Evaluate(*module, + {LiteralUtil::CreateR1({1, 2, 3, 4, 5, 6})})); CHECK_EQ(result_literal, LiteralUtil::CreateR3({{{1, 2}}, {{3, 4}}, {{5, 6}}})); } @@ -180,9 +177,9 @@ TEST_F(HloCreationUtilsTest, PadVectorWithZeros) { entry_computation->set_root_instruction(zero_padded_param); HloEvaluator evaluator; - TF_ASSERT_OK_AND_ASSIGN(Literal result_literal, - evaluator.Evaluate( - *module, {LiteralUtil::CreateR1({3, 4})})); + TF_ASSERT_OK_AND_ASSIGN( + Literal result_literal, + evaluator.Evaluate(*module, {LiteralUtil::CreateR1({3, 4})})); CHECK_EQ(result_literal, LiteralUtil::CreateR1({0, 0, 0, 3, 4, 0})); } @@ -194,15 +191,14 @@ TEST_F(HloCreationUtilsTest, BroadcastZeros_S32) { /*output_shape_dims=*/{2, 2}, ¶m, &entry_computation); - TF_ASSERT_OK_AND_ASSIGN( - HloInstruction * zeros, - BroadcastZeros(module->entry_computation(), S32, {2, 2})); + HloInstruction* zeros = + BroadcastZeros(module->entry_computation(), S32, {2, 2}); entry_computation->set_root_instruction(zeros); HloEvaluator evaluator; TF_ASSERT_OK_AND_ASSIGN( Literal result_literal, - 
evaluator.Evaluate(*module, {LiteralUtil::CreateR0(0)})); + evaluator.Evaluate(*module, {LiteralUtil::CreateR0(0)})); CHECK_EQ(result_literal, LiteralUtil::CreateR2({{0, 0}, {0, 0}})); } @@ -214,15 +210,14 @@ TEST_F(HloCreationUtilsTest, BroadcastZeros_F32) { /*output_shape_dims=*/{2, 2}, ¶m, &entry_computation); - TF_ASSERT_OK_AND_ASSIGN( - HloInstruction * zeros, - BroadcastZeros(module->entry_computation(), F32, {2, 2})); + HloInstruction* zeros = + BroadcastZeros(module->entry_computation(), F32, {2, 2}); entry_computation->set_root_instruction(zeros); HloEvaluator evaluator; - TF_ASSERT_OK_AND_ASSIGN(Literal result_literal, - evaluator.Evaluate( - *module, {LiteralUtil::CreateR0(0.0f)})); + TF_ASSERT_OK_AND_ASSIGN( + Literal result_literal, + evaluator.Evaluate(*module, {LiteralUtil::CreateR0(0.0f)})); CHECK_EQ(result_literal, LiteralUtil::CreateR2({{0.0f, 0.0f}, {0.0f, 0.0f}})); } diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index 3ed3d3c11c71dc534f193ba3ffb556b0eb0c80e4..3144a84805454488f417391f40ed6b9e9facc752 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -107,7 +107,7 @@ bool HloDataflowAnalysis::AreTransitiveUsesElementwiseOrTuple( return false; } } - if (!visited.count(user)) { + if (!visited.contains(user)) { stack.push_back(user); } } @@ -190,7 +190,7 @@ string HloDataflowAnalysis::ToString() const { for (const HloComputation* computation : module_.computations()) { for (const HloInstruction* instruction : computation->instructions()) { StrAppend(&out, " ", instruction->name(), ":\n"); - if (ShapeUtil::IsTuple(instruction->shape())) { + if (instruction->shape().IsTuple()) { GetInstructionValueSet(instruction) .ForEachElement([this, &instruction, &out]( const ShapeIndex& index, @@ -256,7 +256,7 @@ bool HloDataflowAnalysis::Phi( input_value_ids.push_back(value->id()); } } - 
std::sort(input_value_ids.begin(), input_value_ids.end()); + absl::c_sort(input_value_ids); input_value_ids.erase( std::unique(input_value_ids.begin(), input_value_ids.end()), input_value_ids.end()); @@ -271,8 +271,7 @@ bool HloDataflowAnalysis::Phi( if (current_value_defined_here) { VLOG(5) << "current_value_defined_here: " << current_value->ToString(); CHECK(current_value->is_phi()); - auto it = std::find(input_value_ids.begin(), input_value_ids.end(), - current_value->id()); + auto it = absl::c_find(input_value_ids, current_value->id()); if (it != input_value_ids.end()) { input_value_ids.erase(it); } @@ -921,8 +920,7 @@ StatusOr> HloDataflowAnalysis::Run( for (auto& pair : dataflow_analysis->values_) { dataflow_analysis->values_vector_.push_back(&pair.second); } - std::sort(dataflow_analysis->values_vector_.begin(), - dataflow_analysis->values_vector_.end(), HloValue::IdLessThan); + absl::c_sort(dataflow_analysis->values_vector_, HloValue::IdLessThan); TF_DCHECK_OK(dataflow_analysis->Verify()); @@ -937,9 +935,7 @@ Status HloDataflowAnalysis::Verify() const { for (const HloValue* value : values()) { for (const HloPosition& position : value->positions()) { const HloValueSet& value_set = GetValueSet(position); - TF_RET_CHECK(std::find(value_set.values().begin(), - value_set.values().end(), - value) != value_set.values().end()) + TF_RET_CHECK(absl::c_linear_search(value_set.values(), value)) << "Value set at position " << position << " does not contain value " << value->ToShortString(); } @@ -954,9 +950,7 @@ Status HloDataflowAnalysis::Verify() const { const HloValueSet& value_set = pair.second; const HloPosition position{instruction, index}; for (const HloValue* value : value_set.values()) { - TF_RET_CHECK(std::find(value->positions().begin(), - value->positions().end(), - position) != value->positions().end()) + TF_RET_CHECK(absl::c_linear_search(value->positions(), position)) << "Value set at position " << position << " unexpectedly contains value " << 
value->ToShortString(); } @@ -1041,11 +1035,10 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser( // Check if one operand of kAdd fused root is kDot or kConvolution. auto* add = user->fused_expression_root(); auto add_operand_it = - std::find_if(add->operands().begin(), add->operands().end(), - [&](HloInstruction* operand) { - return operand->opcode() == HloOpcode::kConvolution || - operand->opcode() == HloOpcode::kDot; - }); + absl::c_find_if(add->operands(), [&](HloInstruction* operand) { + return operand->opcode() == HloOpcode::kConvolution || + operand->opcode() == HloOpcode::kDot; + }); if (add_operand_it == add->operands().end()) { return false; } @@ -1100,16 +1093,15 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser( // *) The root instruction of the called computation is element-wise on // 'operand'. const bool found_caller_use = - std::find_if(uses.begin(), uses.end(), [user](const HloUse& use) { + absl::c_find_if(uses, [user](const HloUse& use) { return use.instruction == user; }) != uses.end(); auto* callee_root = user->to_apply()->root_instruction(); const bool found_elementwise_callee_use = - std::find_if( - uses.begin(), uses.end(), [callee_root](const HloUse& use) { - return use.instruction == callee_root && - callee_root->IsElementwiseOnOperand(use.operand_number); - }) != uses.end(); + absl::c_find_if(uses, [callee_root](const HloUse& use) { + return use.instruction == callee_root && + callee_root->IsElementwiseOnOperand(use.operand_number); + }) != uses.end(); return uses.size() == 2 && found_caller_use && found_elementwise_callee_use; } diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index f7a1f19a6f52befd58a405d0e406d7d0d37a8e57..4a7c4963b7b399e625da907b3810c42df7ee2bd3 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -73,8 +73,8 @@ class 
HloDataflowAnalysisTest : public HloTestBase, bool InstructionsMayInterfere(const HloOrdering& ordering, const HloInstruction* a, const HloInstruction* b) { - EXPECT_FALSE(ShapeUtil::IsTuple(a->shape())); - EXPECT_FALSE(ShapeUtil::IsTuple(b->shape())); + EXPECT_FALSE(a->shape().IsTuple()); + EXPECT_FALSE(b->shape().IsTuple()); return ordering.MayInterfere(analysis_->GetValueDefinedAt(a), analysis_->GetValueDefinedAt(b), *analysis_); } @@ -1882,8 +1882,8 @@ TEST_P(HloDataflowAnalysisTest, AddDependency) { HloModule AddDependency ENTRY %AddDependency (p: f32[3]) -> f32[3] { %p = f32[3] parameter(0) - %token = token[] after-all() - ROOT %add_dep = f32[3] add-dependency(f32[3] %p, token[] %token) + %token0 = token[] after-all() + ROOT %add_dep = f32[3] add-dependency(f32[3] %p, token[] %token0) } )"; TF_ASSERT_OK_AND_ASSIGN( @@ -1901,9 +1901,9 @@ ENTRY %AddDependency (p: f32[3]) -> f32[3] { EXPECT_FALSE(analysis->ValueIsDefinedAt(root)); } -INSTANTIATE_TEST_CASE_P(HloDataflowAnalysisInstantiation, - HloDataflowAnalysisTest, - ::testing::Values(false, true)); +INSTANTIATE_TEST_SUITE_P(HloDataflowAnalysisInstantiation, + HloDataflowAnalysisTest, + ::testing::Values(false, true)); class HloDataflowAnalysisTestBase : public HloTestBase { protected: @@ -1970,12 +1970,13 @@ TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) { // Create a DynamicUpdateSlice instruction of tuple element 1. 
auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({2}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2))); auto update = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({2.f, 2.f, 2.f}))); auto dynamic_update_slice = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - data_shape, gte1, update, starts)); + data_shape, gte1, update, + std::initializer_list({starts}))); builder.AddInstruction( HloInstruction::CreateTuple({gte0, dynamic_update_slice})); @@ -2012,12 +2013,13 @@ TEST_F(DoesNotUseOperandBufferTest, IndirectUses) { // Create a DynamicUpdateSlice instruction of tuple element 1. auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({2}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2))); auto update = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({2.f, 2.f, 2.f}))); auto dynamic_update_slice = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - data_shape, gte1, update, starts)); + data_shape, gte1, update, + std::initializer_list({starts}))); builder.AddInstruction( HloInstruction::CreateTuple({gte0, dynamic_update_slice})); @@ -2150,17 +2152,17 @@ TEST_F(CanShareOperandBufferWithUserTest, auto param = builder.AddInstruction( HloInstruction::CreateParameter(0, data_shape, "param0")); - auto index = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({0, 0}))); - auto ds = builder.AddInstruction( - HloInstruction::CreateDynamicSlice(slice_shape, param, index, {1, 2, 2})); + auto zero = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + auto ds = builder.AddInstruction(HloInstruction::CreateDynamicSlice( + slice_shape, param, {zero, zero}, {1, 2, 2})); - auto dus = builder.AddInstruction( - HloInstruction::CreateDynamicUpdateSlice(data_shape, param, ds, index)); + auto dus = 
builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + data_shape, param, ds, {zero, zero})); BuildModule(builder.Build()); auto fusion = computation_->CreateFusionInstruction( - {dus, ds, index}, HloInstruction::FusionKind::kLoop); + {dus, ds, zero}, HloInstruction::FusionKind::kLoop); RunAnalysis(); EXPECT_TRUE( @@ -2219,12 +2221,13 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) { // Create a DynamicUpdateSlice instruction of tuple element 1. auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({2}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2))); auto update = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({2.f, 2.f, 2.f}))); auto dynamic_update_slice = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - data_shape, gte1, update, starts)); + data_shape, gte1, update, + std::initializer_list({starts}))); builder.AddInstruction( HloInstruction::CreateTuple({gte0, dynamic_update_slice})); @@ -2259,12 +2262,13 @@ TEST_F(CanShareOperandBufferWithUserTest, // Create a DynamicUpdateSlice instruction of tuple element 1. 
auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({2}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2))); auto update = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({2.f, 2.f, 2.f}))); auto dynamic_update_slice = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - data_shape_bf16, convert1, update, starts)); + data_shape_bf16, convert1, update, + std::initializer_list({starts}))); auto convert2 = builder.AddInstruction( HloInstruction::CreateConvert(data_shape, dynamic_update_slice)); @@ -2290,10 +2294,13 @@ TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) { HloInstruction::CreateParameter(0, data_shape, "data")); auto update = builder.AddInstruction( HloInstruction::CreateParameter(1, update_shape, "update")); - auto starts = builder.AddInstruction( - HloInstruction::CreateParameter(2, starts_shape, "starts")); + auto start0 = builder.AddInstruction( + HloInstruction::CreateParameter(2, starts_shape, "start0")); + auto start1 = builder.AddInstruction( + HloInstruction::CreateParameter(3, starts_shape, "start1")); + auto dus = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - data_shape, data, update, starts)); + data_shape, data, update, {start0, start1})); BuildModuleAndRunAnalysis(builder.Build()); @@ -2304,7 +2311,9 @@ TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) { EXPECT_FALSE( dataflow_analysis_->CanShareOperandBufferWithUser(update, {}, dus, {})); EXPECT_FALSE( - dataflow_analysis_->CanShareOperandBufferWithUser(starts, {}, dus, {})); + dataflow_analysis_->CanShareOperandBufferWithUser(start0, {}, dus, {})); + EXPECT_FALSE( + dataflow_analysis_->CanShareOperandBufferWithUser(start1, {}, dus, {})); } TEST_F(CanShareOperandBufferWithUserTest, ScatterCanShare) { diff --git a/tensorflow/compiler/xla/service/hlo_dce.cc b/tensorflow/compiler/xla/service/hlo_dce.cc index 
7d35e251ca21951036336ff1a1eb4aabc87bc5ca..a5a11f09cf4f857b992e5ede3a9dbc5a937ce722 100644 --- a/tensorflow/compiler/xla/service/hlo_dce.cc +++ b/tensorflow/compiler/xla/service/hlo_dce.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -65,7 +66,7 @@ StatusOr HloDCE::Run(HloModule* module) { // Now DCE HloComputations. First, collect the computations that are // referenced by some remaining instruction. - std::unordered_set live_computations; + absl::flat_hash_set live_computations; if (HloComputation* entry_computation = module->entry_computation()) { live_computations.insert(entry_computation); } @@ -79,7 +80,7 @@ StatusOr HloDCE::Run(HloModule* module) { // Remove dead computations. for (auto* computation : module->MakeComputationPostOrder()) { - if (live_computations.count(computation) == 0) { + if (!live_computations.contains(computation)) { TF_RETURN_IF_ERROR(module->RemoveEmbeddedComputation(computation)); changed = true; } diff --git a/tensorflow/compiler/xla/service/hlo_dce_test.cc b/tensorflow/compiler/xla/service/hlo_dce_test.cc index 1fa4259a3e42286cbc911907eea563e6ca6f8611..b5d72b386f89568cc3066b2e497be98428d1ed0c 100644 --- a/tensorflow/compiler/xla/service/hlo_dce_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dce_test.cc @@ -43,9 +43,7 @@ class HloDceTest : public HloTestBase { // Returns whether the given instruction exists in the given computation. 
bool HasInstruction(const HloComputation& computation, const HloInstruction* instruction) { - return std::find(computation.instructions().begin(), - computation.instructions().end(), - instruction) != computation.instructions().end(); + return absl::c_linear_search(computation.instructions(), instruction); } }; diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.cc b/tensorflow/compiler/xla/service/hlo_domain_map.cc index c6d02f9f67bb599e496d20fc2acf2e627ed54438..7cdb7f6bdf26241cda4fabbb5ccaf6e6f7de39ce 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_map.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_map.cc @@ -230,10 +230,10 @@ HloDomainMap::MakeNonDomainInstructions( } } // sort instructions according to instructions_order - std::sort(instructions.begin(), instructions.end(), - [&instructions_order](HloInstruction* a, HloInstruction* b) { - return instructions_order.at(a) < instructions_order.at(b); - }); + absl::c_sort(instructions, + [&instructions_order](HloInstruction* a, HloInstruction* b) { + return instructions_order.at(a) < instructions_order.at(b); + }); return instructions; } diff --git a/tensorflow/compiler/xla/service/hlo_domain_test.cc b/tensorflow/compiler/xla/service/hlo_domain_test.cc index acdb42128e3d9a1fb912a466c9c2c3cbbe3d3f83..fd4fb0246d8d42ab7329c05dc23e386303cdce3c 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_test.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_test.cc @@ -195,10 +195,10 @@ HloModule Module ENTRY entry { p0 = (f32[4]) parameter(0) a = f32[4] get-tuple-element(p0), index=0 - token = token[] after-all() - b = (f32[4], u32[], token[]) send(a, token), channel_id=1, sharding={maximal device=0} + token0 = token[] after-all() + b = (f32[4], u32[], token[]) send(a, token0), channel_id=1, sharding={maximal device=0} c = token[] send-done(b), channel_id=1, sharding={maximal device=0} - d = (f32[4], u32[], token[]) recv(token), channel_id=2, sharding={maximal device=0} + d = (f32[4], u32[], token[]) 
recv(token0), channel_id=2, sharding={maximal device=0} e = (f32[4], token[]) recv-done(d), channel_id=2, sharding={maximal device=0} e_element = f32[4] get-tuple-element(e), index=0, sharding={maximal device=0} f = f32[4] add(a, e_element) @@ -235,12 +235,12 @@ TEST_F(HloDomainTest, CheckNoDomainAddedOnPureIOComputation) { HloModule Module ENTRY entry { - token = token[] after-all(), sharding={maximal device=-1} - a = (f32[4], u32[], token[]) recv(token), channel_id=1, sharding={maximal device=-1} + token0 = token[] after-all(), sharding={maximal device=-1} + a = (f32[4], u32[], token[]) recv(token0), channel_id=1, sharding={maximal device=-1} b = (f32[4], token[]) recv-done(a), channel_id=1, sharding={maximal device=-1} b_element = f32[4] get-tuple-element(b), index=0, sharding={maximal device=-1} c = f32[4] add(b_element, b_element), sharding={maximal device=-1} - d = (f32[4], u32[], token[]) send(c, token), channel_id=2, sharding={maximal device=-1} + d = (f32[4], u32[], token[]) send(c, token0), channel_id=2, sharding={maximal device=-1} ROOT e = token[] send-done(d), channel_id=2, sharding={maximal device=-1} } )"; @@ -259,12 +259,12 @@ TEST_F(HloDomainTest, CheckNormalizationOnPureIOComputation) { HloModule Module ENTRY entry { - token = token[] after-all(), sharding={maximal device=0} - a = (f32[4], u32[], token[]) recv(token), channel_id=1, sharding={maximal device=0} + token0 = token[] after-all(), sharding={maximal device=0} + a = (f32[4], u32[], token[]) recv(token0), channel_id=1, sharding={maximal device=0} b = (f32[4], token[]) recv-done(a), channel_id=1, sharding={maximal device=0} b_element = f32[4] get-tuple-element(b), index=0, sharding={maximal device=0} c = f32[4] add(b_element, b_element) - d = (f32[4], u32[], token[]) send(c, token), channel_id=2, sharding={maximal device=0} + d = (f32[4], u32[], token[]) send(c, token0), channel_id=2, sharding={maximal device=0} ROOT e = token[] send-done(d), channel_id=2, sharding={maximal device=0} } )"; 
@@ -344,8 +344,8 @@ TEST_F(HloDomainTest, CheckNormalizationOnInfeedTuple) { HloModule Module ENTRY entry { - token = token[] after-all() - infeed = ((f32[4], f32[4]), token[]) infeed(token), + token0 = token[] after-all() + infeed = ((f32[4], f32[4]), token[]) infeed(token0), sharding={{maximal device=1}, {maximal device=0}, {maximal device=0}} infeed.data = (f32[4], f32[4]) get-tuple-element(infeed), index=0, sharding={{maximal device=1}, {maximal device=0}} diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc index 72006e17e7e7ec09b62e88d05b695ec9f4c49647..9b0f2b2a0f4dd5d1d1191e9ab0637cc3034b50da 100644 --- a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc +++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc @@ -68,7 +68,7 @@ Shape GetConvertedTupleShape(const Shape& shape, PrimitiveType from_type, std::vector new_tuple_subshapes; for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { Shape subshape = ShapeUtil::GetTupleElementShape(shape, i); - CHECK(!ShapeUtil::IsTuple(subshape)); + CHECK(!subshape.IsTuple()); if (subshape.element_type() == from_type) { subshape = ShapeUtil::ChangeElementType(subshape, to_type); } @@ -92,7 +92,7 @@ HloInstruction* ConvertTupleElements(HloInstruction* hlo, HloInstruction* element = computation->AddInstruction( HloInstruction::CreateGetTupleElement(ele_shape, hlo, i)); const Shape& to_ele_shape = ShapeUtil::GetTupleElementShape(to_shape, i); - CHECK(!ShapeUtil::IsTuple(ele_shape)); + CHECK(!ele_shape.IsTuple()); if (ele_shape.element_type() != to_ele_shape.element_type()) { element = computation->AddInstruction( HloInstruction::CreateConvert(to_ele_shape, element)); @@ -141,10 +141,9 @@ StatusOr HloElementTypeConverter::Run(HloModule* module) { // These are ops with embedded computations where it suffices to convert // the embedded computations instead of converting the ops themselves. 
if (opcode == HloOpcode::kWhile || opcode == HloOpcode::kCall || - opcode == HloOpcode::kCrossReplicaSum || - opcode == HloOpcode::kFusion || opcode == HloOpcode::kMap || - opcode == HloOpcode::kReduce || opcode == HloOpcode::kReduceWindow || - opcode == HloOpcode::kScatter || + opcode == HloOpcode::kAllReduce || opcode == HloOpcode::kFusion || + opcode == HloOpcode::kMap || opcode == HloOpcode::kReduce || + opcode == HloOpcode::kReduceWindow || opcode == HloOpcode::kScatter || opcode == HloOpcode::kSelectAndScatter || opcode == HloOpcode::kConditional) { continue; @@ -191,7 +190,7 @@ StatusOr HloElementTypeConverter::Run(HloModule* module) { TF_RETURN_IF_ERROR(new_hlo->CopyAllControlDepsFrom(hlo)); new_hlo = ToElementType(new_hlo, eliminate_type_); - } else if (ShapeUtil::IsTuple(hlo->shape())) { + } else if (hlo->shape().IsTuple()) { Shape old_shape = hlo->shape(); Shape new_shape = GetConvertedTupleShape(hlo->shape(), eliminate_type_, replace_with_type_); diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc b/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc index c170e36c73ad2bef830e528de3ec72d38683d888..5b633784e2f306290ca6c096f67c657be1f188c8 100644 --- a/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc +++ b/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc @@ -28,15 +28,7 @@ using ::testing::Eq; using ::testing::Not; using ::testing::ResultOf; -class HloElementTypeConverterTest : public HloTestBase { - public: - std::unique_ptr CreateModuleFromHloString( - const string& hlo_string) { - return HloRunner::CreateModuleFromString(hlo_string, - GetDebugOptionsForTest()) - .ValueOrDie(); - } -}; +using HloElementTypeConverterTest = HloTestBase; TEST_F(HloElementTypeConverterTest, CustomCallsNotConverted) { const string& hlo_string = R"( @@ -47,7 +39,7 @@ TEST_F(HloElementTypeConverterTest, CustomCallsNotConverted) { custom_call_target="foo" } )"; - auto module = 
CreateModuleFromHloString(hlo_string); + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); HloElementTypeConverter type_converter(BF16, F32); TF_ASSERT_OK_AND_ASSIGN(bool converted, type_converter.Run(module.get())); EXPECT_FALSE(converted); @@ -57,13 +49,13 @@ TEST_F(HloElementTypeConverterTest, InfeedsOutfeedsNotConverted) { const string& hlo_string = R"( HloModule InfeedOutfeed ENTRY RoundTrip16MiBR1.v2 { - token = token[] after-all() - infeed = (bf16[4]{0}, token[]) infeed(token) + token0 = token[] after-all() + infeed = (bf16[4]{0}, token[]) infeed(token0) ROOT infeed.data = bf16[4]{0} get-tuple-element(infeed), index=0 - outfeed = token[] outfeed(infeed.data, token) + outfeed = token[] outfeed(infeed.data, token0) } )"; - auto module = CreateModuleFromHloString(hlo_string); + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); HloElementTypeConverter type_converter(BF16, F32); TF_ASSERT_OK_AND_ASSIGN(bool converted, type_converter.Run(module.get())); EXPECT_FALSE(converted); @@ -73,17 +65,16 @@ TEST_F(HloElementTypeConverterTest, OperationsInNestedTuplesConverted) { const string& hlo_string = R"( HloModule NestedTuples ENTRY NestedTuples.v5 { - constant.4 = bf16[] constant(42) constant.2 = f32[2]{0} constant({1, 2}) - constant.3 = bf16[] constant(42) - add = bf16[] add(constant.2, constant.3) - tuple = (f32[2]{0}, bf16[]) tuple(constant.2, add) + constant.3 = bf16[2]{0} constant({42, 42}) + add = bf16[2]{0} add(constant.2, constant.3) + tuple = (f32[2]{0}, bf16[2]{0}) tuple(constant.2, add) constant.5 = bf16[2]{0} constant({22, 44}) - ROOT tuple.1 = ((f32[2]{0}, bf16[]), bf16[2]{0}) tuple(tuple, constant.5) + ROOT tuple.1 = ((f32[2]{0}, bf16[2]{0}), bf16[2]{0}) tuple(tuple, constant.5) } )"; - auto module = CreateModuleFromHloString(hlo_string); + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); HloElementTypeConverter type_converter(BF16, F32); TF_ASSERT_OK_AND_ASSIGN(bool converted, 
type_converter.Run(module.get())); EXPECT_TRUE(converted); @@ -96,13 +87,13 @@ TEST_F(HloElementTypeConverterTest, BatchNormGradBF16Converted) { const string& hlo_string = R"( HloModule BatchNormGrad ENTRY BatchNormGrad.v6 { - constant.4 = bf16[2,2,2,1]{3,2,1,0} constant(bf16[2,2,2,1] { { /*i0=0*/ + constant.4 = bf16[2,2,2,1]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ {0}, {0} }, { /*i1=1*/ {0}, {0} } }, { /*i0=1*/ { /*i1=0*/ {0}, {0} }, { /*i1=1*/ {0}, {0} } } }) constant.5 = bf16[2]{0} constant({1, 1}) constant.6 = bf16[2]{0} constant({0, 0}) constant.7 = bf16[2]{0} constant({1, 1}) - constant.8 = bf16[2,2,2,1]{3,2,1,0} constant(bf16[2,2,2,1] { { /*i0=0*/ + constant.8 = bf16[2,2,2,1]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ {1}, {2} }, { /*i1=1*/ {3}, {4} } }, { /*i0=1*/ { /*i1=0*/ {5}, {6} }, { /*i1=1*/ {7}, {8} } } }) ROOT batch-norm-grad = (bf16[2,2,2,1]{3,2,1,0}, bf16[2]{0}, bf16[2]{0}) @@ -111,7 +102,7 @@ TEST_F(HloElementTypeConverterTest, BatchNormGradBF16Converted) { } )"; - auto module = CreateModuleFromHloString(hlo_string); + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); HloElementTypeConverter type_converter(BF16, F32); TF_ASSERT_OK_AND_ASSIGN(bool converted, type_converter.Run(module.get())); EXPECT_TRUE(converted); @@ -135,7 +126,7 @@ ENTRY main { ROOT rng = bf16[1,1000,20]{2,1,0} rng(constant.3, constant.4), distribution=rng_uniform } )"; - auto module = CreateModuleFromHloString(hlo_string); + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); HloElementTypeConverter type_converter(BF16, F32); TF_ASSERT_OK_AND_ASSIGN(bool converted, type_converter.Run(module.get())); EXPECT_TRUE(converted); @@ -161,7 +152,7 @@ ENTRY main { ROOT rng1 = bf16[1,1000,20]{2,1,0} rng(constant.3, constant.4), control-predecessors={%rng0}, distribution=rng_uniform } )"; - auto module = CreateModuleFromHloString(hlo_string); + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); HloElementTypeConverter 
type_converter(BF16, F32); TF_ASSERT_OK_AND_ASSIGN(bool converted, type_converter.Run(module.get())); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index 51a3fba1768aaf219b78ddc09a1c526448389d9e..56a1b6f43945adae18313546432b959f66a32dcf 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -18,9 +18,9 @@ limitations under the License. #include #include #include +#include #include #include -#include #include #include "absl/algorithm/container.h" @@ -29,16 +29,17 @@ limitations under the License. #include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/index_util.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" #include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_query.h" #include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/window_util.h" @@ -134,8 +135,44 @@ StatusOr Compare(const Shape& shape, HloOpcode opcode, return std::move(result); } +template <> +StatusOr Compare(const Shape& shape, HloOpcode opcode, + LiteralSlice lhs_literal, + LiteralSlice rhs_literal) { + std::function compare_op; + switch (opcode) { + case HloOpcode::kEq: + compare_op = [](complex128 lhs_el, complex128 rhs_el) { + return lhs_el == rhs_el; + }; + break; + case 
HloOpcode::kNe: + compare_op = [](complex128 lhs_el, complex128 rhs_el) { + return lhs_el != rhs_el; + }; + break; + default: + LOG(FATAL) << "unhandled HLO opcode for conversion to Comparison: " + << HloOpcodeString(opcode); + } + + Literal result(shape); + TF_RETURN_IF_ERROR( + result.Populate([&](absl::Span multi_index) { + return compare_op(lhs_literal.Get(multi_index), + rhs_literal.Get(multi_index)); + })); + + return std::move(result); +} + } // namespace +// Note that unsupported types by the typed visitor does not necessarily imply +// the non-typed HloEvaluator (parent evaluator) would not support them either +// in the type-agnostic handler. For e.g., HandleGetTupleElement in the parent +// type-agnostic evaluator will be able to accept Tuple primitive type, whereas +// HloEvaluatorTypedVisitor cannot. HloEvaluator::HloEvaluator(int64 max_loop_iterations) : max_loop_iterations_(max_loop_iterations) { typed_visitors_[PRED] = @@ -143,22 +180,14 @@ HloEvaluator::HloEvaluator(int64 max_loop_iterations) typed_visitors_[U8] = absl::make_unique>(this); typed_visitors_[U16] = - absl::make_unique([](HloInstruction*) { - return Unimplemented( - "HloEvaluator::HloEvaluatorTypedVisitor: unhandled primitive type: " - "U16."); - }); + absl::make_unique>(this); typed_visitors_[U32] = absl::make_unique>(this); typed_visitors_[U64] = absl::make_unique>(this); typed_visitors_[S8] = absl::make_unique>(this); typed_visitors_[S16] = - absl::make_unique([](HloInstruction*) { - return Unimplemented( - "HloEvaluator::HloEvaluatorTypedVisitor: unhandled primitive type: " - "S16."); - }); + absl::make_unique>(this); typed_visitors_[S32] = absl::make_unique>(this); typed_visitors_[S64] = @@ -171,6 +200,8 @@ HloEvaluator::HloEvaluator(int64 max_loop_iterations) absl::make_unique>(this); typed_visitors_[C64] = absl::make_unique>(this); + typed_visitors_[C128] = + absl::make_unique>(this); // Most of the evaluator computations we use don't support BF16 (e.g., // std::ceil, 
std::tanh). To make evaluator work with BF16, we set all @@ -196,65 +227,30 @@ HloEvaluator::HloEvaluator(int64 max_loop_iterations) }); } -template -StatusOr HloEvaluator::Evaluate( - const HloModule& module, absl::Span arg_literals) { - XLA_VLOG_LINES(2, "HloEvaluator::Evaluate module:\n" + module.ToString()); - - evaluated_.clear(); - arg_literals_.clear(); - for (const auto& literal_ptr : arg_literals) { - arg_literals_.push_back(&*literal_ptr); - } - - TF_RETURN_IF_ERROR(module.entry_computation()->Accept(this)); - - return GetEvaluatedLiteralFor(module.entry_computation()->root_instruction()) - .Clone(); -} - -template <> -StatusOr HloEvaluator::Evaluate( - const HloModule& module, absl::Span arg_literals) { - std::vector arg_literal_ptrs; - for (const auto& literal_ptr : arg_literals) { - arg_literal_ptrs.push_back(&literal_ptr); - } - return Evaluate(module, arg_literal_ptrs); -} - -template StatusOr HloEvaluator::Evaluate( const HloComputation& computation, - absl::Span arg_literals) { + absl::Span arg_literals) { CHECK(computation.parent() != nullptr); XLA_VLOG_LINES( 2, "HloEvaluator::Evaluate computation:\n" + computation.ToString()); - evaluated_.clear(); - arg_literals_.clear(); - for (const auto& literal_ptr : arg_literals) { - arg_literals_.push_back(&*literal_ptr); + if (arg_literals.size() != computation.num_parameters()) { + return InvalidArgument( + "Expected %d argument%s, but got %d.", computation.num_parameters(), + computation.num_parameters() == 1 ? 
"" : "s", arg_literals.size()); } - - TF_RETURN_IF_ERROR(computation.Accept(this)); - return GetEvaluatedLiteralFor(computation.root_instruction()).Clone(); -} - -template <> -StatusOr HloEvaluator::Evaluate( - const HloComputation& computation, absl::Span arg_literals) { - std::vector arg_literal_ptrs; - for (const auto& literal_ptr : arg_literals) { - arg_literal_ptrs.push_back(&literal_ptr); + for (int64 i = 0; i < arg_literals.size(); ++i) { + const auto& computation_shape = + computation.parameter_instruction(i)->shape(); + const auto& arg_shape = arg_literals[i]->shape(); + if (!ShapeUtil::Equal(computation_shape, arg_shape)) { + return InvalidArgument( + "Shape mismatch at parameter %d. Computation expected %s, but arg " + "was %s.", + i, ShapeUtil::HumanStringWithLayout(computation_shape), + ShapeUtil::HumanString(arg_shape)); + } } - return Evaluate(computation, arg_literal_ptrs); -} - -template -StatusOr HloEvaluator::Evaluate( - HloInstruction* instruction, absl::Span arg_literals) { - TF_RET_CHECK(hlo_query::AllOperandsAreParametersOrConstants(*instruction)); evaluated_.clear(); arg_literals_.clear(); @@ -262,33 +258,20 @@ StatusOr HloEvaluator::Evaluate( arg_literals_.push_back(&*literal_ptr); } - // Evaluate operands of Parameter type against the input literals which - // caches the evaluated literal results. - for (const auto operand : instruction->operands()) { - if (operand->opcode() == HloOpcode::kParameter) { - const Literal* input_literal = arg_literals_[operand->parameter_number()]; - VLOG(2) << "Parameter operand evaluated to: " - << input_literal->ToString(); - TF_RET_CHECK(ShapeUtil::Equal(operand->shape(), input_literal->shape())); - - evaluated_[operand] = input_literal->Clone(); - } + // Re-seed RNG, either from the configuration's seed or a monotonic + // per-evaluator seed (which prevents two evaluators from returning the same + // random sequence). 
+ if (computation.parent()->config().seed()) { + seed_ = computation.parent()->config().seed(); + } else { + // Start global_seed at a (true) random value. + static std::atomic global_seed{std::random_device()()}; + seed_ = global_seed.fetch_add(1); } + engine_.seed(seed_); - TF_RETURN_IF_ERROR(Preprocess(instruction)); - TF_RETURN_IF_ERROR(instruction->Visit(this)); - TF_RETURN_IF_ERROR(Postprocess(instruction)); - return GetEvaluatedLiteralFor(instruction).Clone(); -} - -template <> -StatusOr HloEvaluator::Evaluate( - HloInstruction* instruction, absl::Span arg_literals) { - std::vector arg_literal_ptrs; - for (const auto& literal : arg_literals) { - arg_literal_ptrs.push_back(&literal); - } - return Evaluate(instruction, arg_literal_ptrs); + TF_RETURN_IF_ERROR(computation.Accept(this)); + return GetEvaluatedLiteralFor(computation.root_instruction()).Clone(); } StatusOr HloEvaluator::Evaluate(HloInstruction* instruction) { @@ -396,16 +379,55 @@ StatusOr HloEvaluator::EvaluateDotOp( return Evaluate(cloned_instruction.get()); } +Status HloEvaluator::HandleBitcast(HloInstruction* bitcast) { + const Literal& operand_literal = GetEvaluatedLiteralFor(bitcast->operand(0)); + Literal result(bitcast->shape()); + TF_RET_CHECK(operand_literal.size_bytes() == result.size_bytes()); + memcpy(result.untyped_data(), operand_literal.untyped_data(), + operand_literal.size_bytes()); + evaluated_[bitcast] = std::move(result); + return Status::OK(); +} + +Status HloEvaluator::HandleGetDimensionSize( + HloInstruction* get_dimension_size) { + HloInstruction* operand = get_dimension_size->mutable_operand(0); + int64 dim = get_dimension_size->dimension(); + if (dynamic_dimension_inference_ == nullptr) { + return InvalidArgument( + "Evaluator cannot evaluate get_dimension_size without " + "set_dynamic_dimension_inference."); + } + HloInstruction* dynamic_size = + dynamic_dimension_inference_->GetDynamicSize(operand, {}, dim); + if (dynamic_size != nullptr) { + 
evaluated_[get_dimension_size] = + GetEvaluatedLiteralFor(dynamic_size).Clone(); + return Status::OK(); + } + + const Shape& shape = get_dimension_size->operand(0)->shape(); + Literal output(ShapeUtil::MakeShape(U32, {})); + output.PopulateWithValue( + static_cast(shape.dimensions(get_dimension_size->dimension()))); + evaluated_[get_dimension_size] = std::move(output); + return Status::OK(); +} + Status HloEvaluator::HandleParameter(HloInstruction* parameter) { + // Nothing to do other than sanity checks. Parameters' values are stored in + // arg_literals_. CHECK_LT(parameter->parameter_number(), arg_literals_.size()); + +#ifndef NDEBUG const Literal* input_literal = arg_literals_[parameter->parameter_number()]; VLOG(2) << "Parameter evaluated to: " << input_literal->ToString(); DCHECK(ShapeUtil::Equal(parameter->shape(), input_literal->shape())) << "parameter shape is: " << ShapeUtil::HumanString(parameter->shape()) << ", but input literal shape is: " << ShapeUtil::HumanString(input_literal->shape()); +#endif - evaluated_[parameter] = input_literal->Clone(); return Status::OK(); } @@ -430,8 +452,8 @@ Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) { // The result concatenate dimension is going to be the sum of all // concatenate dimensions of the operands taking part of the operation. 
const Shape& reference_shape = operands[0]->shape(); - CHECK(ShapeUtil::IsArray(reference_shape)); - const int64 rank = ShapeUtil::Rank(reference_shape); + CHECK(reference_shape.IsArray()); + const int64 rank = reference_shape.rank(); const int64 concat_dim = concatenate->dimensions()[0]; CHECK_GE(concat_dim, 0); CHECK_LT(concat_dim, rank); @@ -441,7 +463,7 @@ Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) { for (int64 i = 1; i < operands.size(); ++i) { const Shape& operand_shape = operands[i]->shape(); - CHECK(ShapeUtil::IsArray(operand_shape)); + CHECK(operand_shape.IsArray()); // Accumulate the concat dimension from all tensors taking part to the // operation. concat_dimensions[concat_dim] += @@ -468,15 +490,52 @@ Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) { Status HloEvaluator::HandleIsFinite(HloInstruction* is_finite) { auto operand = is_finite->operand(0); - if (!ShapeUtil::ElementIsFloating(operand->shape())) { - return InvalidArgument( - "expected element type in shape to be float for IsFinite op, got: %s", - PrimitiveType_Name(operand->shape().element_type())); - } + auto elem_ty = operand->shape().element_type(); + switch (elem_ty) { + case PRED: + case TUPLE: + case OPAQUE: + case TOKEN: + case S8: + case S16: + case S32: + case S64: + case U8: + case U16: + case U32: + case U64: + case C64: + case C128: + // Explicitly enumerate all types in this switch so that when we add a new + // type, we'll get a compile error here. 
+ case PRIMITIVE_TYPE_INVALID: + case PrimitiveType_INT_MIN_SENTINEL_DO_NOT_USE_: + case PrimitiveType_INT_MAX_SENTINEL_DO_NOT_USE_: + return InvalidArgument( + "expected element type in shape to be floating point, but " + "got: %s", + PrimitiveType_Name(elem_ty)); - switch (operand->shape().element_type()) { - case F16: - return Unimplemented("unhandled primitive type: F16."); + case F16: { + auto result_or = ElementWiseUnaryOpImpl( + is_finite, + [](Eigen::half elem_operand) { + return std::isfinite(static_cast(elem_operand)); + }, + GetEvaluatedLiteralFor(operand)); + TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or)); + break; + } + case BF16: { + auto result_or = ElementWiseUnaryOpImpl( + is_finite, + [](bfloat16 elem_operand) { + return std::isfinite(static_cast(elem_operand)); + }, + GetEvaluatedLiteralFor(operand)); + TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or)); + break; + } case F32: { auto result_or = ElementWiseUnaryOpImpl( is_finite, @@ -493,9 +552,6 @@ Status HloEvaluator::HandleIsFinite(HloInstruction* is_finite) { TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or)); break; } - default: - LOG(FATAL) << "HandleIsFinite: unknown/unhandled primitive type: " - << PrimitiveType_Name(operand->shape().element_type()); } return Status::OK(); @@ -518,6 +574,13 @@ Status HloEvaluator::HandleReal(HloInstruction* real) { TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or)); break; } + case C128: { + auto result_or = ElementWiseUnaryOpImpl( + real, [](complex128 elem_operand) { return std::real(elem_operand); }, + GetEvaluatedLiteralFor(operand)); + TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or)); + break; + } case F16: { auto result_or = ElementWiseUnaryOpImpl( real, [](Eigen::half elem_operand) { return elem_operand; }, @@ -548,11 +611,61 @@ Status HloEvaluator::HandleReal(HloInstruction* real) { } Status HloEvaluator::HandleImag(HloInstruction* imag) { - auto result_or = ElementWiseUnaryOpImpl( 
- imag, [](complex64 elem_operand) { return std::imag(elem_operand); }, - GetEvaluatedLiteralFor(imag->operand(0))); + auto operand = imag->operand(0); + switch (operand->shape().element_type()) { + case C64: { + auto result_or = ElementWiseUnaryOpImpl( + imag, [](complex64 elem_operand) { return std::imag(elem_operand); }, + GetEvaluatedLiteralFor(imag->operand(0))); - TF_ASSIGN_OR_RETURN(evaluated_[imag], std::move(result_or)); + TF_ASSIGN_OR_RETURN(evaluated_[imag], std::move(result_or)); + break; + } + case C128: { + auto result_or = ElementWiseUnaryOpImpl( + imag, [](complex128 elem_operand) { return std::imag(elem_operand); }, + GetEvaluatedLiteralFor(imag->operand(0))); + + TF_ASSIGN_OR_RETURN(evaluated_[imag], std::move(result_or)); + break; + } + default: + LOG(FATAL) << "HandleImag: unknown/unhandled primitive type: " + << PrimitiveType_Name(operand->shape().element_type()); + } + + return Status::OK(); +} + +Status HloEvaluator::HandleComplex(HloInstruction* complex) { + const Literal& real = GetEvaluatedLiteralFor(complex->operand(0)); + const Literal& imag = GetEvaluatedLiteralFor(complex->operand(1)); + TF_RET_CHECK(ShapeUtil::Compatible(real.shape(), imag.shape())); + + Literal result(complex->shape()); + switch (complex->shape().element_type()) { + case C64: { + TF_RETURN_IF_ERROR( + result.Populate([&](absl::Span multi_index) { + return std::complex(real.Get(multi_index), + imag.Get(multi_index)); + })); + break; + } + case C128: { + TF_RETURN_IF_ERROR( + result.Populate([&](absl::Span multi_index) { + return std::complex(real.Get(multi_index), + imag.Get(multi_index)); + })); + break; + } + default: + LOG(FATAL) << "HandleComplex: unknown/unhandled primitive type: " + << PrimitiveType_Name(complex->shape().element_type()); + } + + evaluated_[complex] = std::move(result); return Status::OK(); } @@ -589,8 +702,11 @@ Status HloEvaluator::HandleCompare(HloInstruction* compare) { evaluated_[compare], Compare(compare->shape(), opcode, lhs_literal, 
rhs_literal)); } break; - case U16: - return Unimplemented("unhandled primitive type: U16."); + case U16: { + TF_ASSIGN_OR_RETURN( + evaluated_[compare], + Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + } break; case U32: { TF_ASSIGN_OR_RETURN( evaluated_[compare], @@ -606,8 +722,11 @@ Status HloEvaluator::HandleCompare(HloInstruction* compare) { evaluated_[compare], Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); } break; - case S16: - return Unimplemented("unhandled primitive type: S16."); + case S16: { + TF_ASSIGN_OR_RETURN( + evaluated_[compare], + Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + } break; case S32: { TF_ASSIGN_OR_RETURN( evaluated_[compare], @@ -618,8 +737,11 @@ Status HloEvaluator::HandleCompare(HloInstruction* compare) { evaluated_[compare], Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); } break; - case F16: - return Unimplemented("unhandled primitive type: F16."); + case F16: { + TF_ASSIGN_OR_RETURN( + evaluated_[compare], + Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + } break; case BF16: { TF_ASSIGN_OR_RETURN(evaluated_[compare], Compare(compare->shape(), opcode, @@ -640,6 +762,11 @@ Status HloEvaluator::HandleCompare(HloInstruction* compare) { Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); } break; + case C128: { + TF_ASSIGN_OR_RETURN(evaluated_[compare], + Compare(compare->shape(), opcode, + lhs_literal, rhs_literal)); + } break; default: LOG(FATAL) << "HandleCompare: unknown primitive type: " << PrimitiveType_Name(lhs->shape().element_type()); @@ -1021,11 +1148,9 @@ Status HloEvaluator::HandleGather(HloInstruction* gather) { Status HloEvaluator::HandleBroadcast(HloInstruction* broadcast) { const Literal& operand = GetEvaluatedLiteralFor(broadcast->operand(0)); - TF_RET_CHECK(broadcast->dimensions().size() == - ShapeUtil::Rank(operand.shape())) + TF_RET_CHECK(broadcast->dimensions().size() == operand.shape().rank()) << "broadcast dimensions is 
of size: " << broadcast->dimensions().size() - << " and rank of operand_to_broadcast is: " - << ShapeUtil::Rank(operand.shape()); + << " and rank of operand_to_broadcast is: " << operand.shape().rank(); // Checks that operand's dimensions are the same as the broadcast's // dimensions along the dimensions to be broadcasted. for (int64 i = 0; i < broadcast->dimensions().size(); ++i) { @@ -1098,9 +1223,10 @@ Status HloEvaluator::HandleCall(HloInstruction* call) { } HloEvaluator embedded_evaluator; - Literal result = - embedded_evaluator.Evaluate(*computation, arg_literals) - .ConsumeValueOrDie(); + embedded_evaluator.set_dynamic_dimension_inference( + dynamic_dimension_inference_); + Literal result = embedded_evaluator.Evaluate(*computation, arg_literals) + .ConsumeValueOrDie(); evaluated_[call] = std::move(result); return Status::OK(); @@ -1116,7 +1242,9 @@ Status HloEvaluator::HandleFusion(HloInstruction* fusion) { fusion->fused_instructions_computation()->Clone( /*suffix=*/"clone_with_layout", &context); for (auto* instruction : cloned_fused_computation->instructions()) { - LayoutUtil::SetToDefaultLayout(instruction->mutable_shape()); + if (!LayoutUtil::HasLayout(instruction->shape())) { + LayoutUtil::SetToDefaultLayout(instruction->mutable_shape()); + } } auto readded_computation = empty_hlo_module.AddEntryComputation(std::move(cloned_fused_computation)); @@ -1130,9 +1258,10 @@ Status HloEvaluator::HandleFusion(HloInstruction* fusion) { } HloEvaluator embedded_evaluator; + embedded_evaluator.set_dynamic_dimension_inference( + dynamic_dimension_inference_); Literal result = - embedded_evaluator - .Evaluate(*readded_computation, arg_literals) + embedded_evaluator.Evaluate(*readded_computation, arg_literals) .ConsumeValueOrDie(); evaluated_[fusion] = std::move(result); @@ -1150,16 +1279,16 @@ Status HloEvaluator::HandleConditional(HloInstruction* conditional) { auto* false_computation = conditional->false_computation(); HloEvaluator embedded_evaluator; + 
embedded_evaluator.set_dynamic_dimension_inference( + dynamic_dimension_inference_); Literal result; if (pred.Get({})) { - result = embedded_evaluator - .Evaluate(*true_computation, - {&true_computation_arg}) - .ConsumeValueOrDie(); + result = + embedded_evaluator.Evaluate(*true_computation, {&true_computation_arg}) + .ConsumeValueOrDie(); } else { result = embedded_evaluator - .Evaluate(*false_computation, - {&false_computation_arg}) + .Evaluate(*false_computation, {&false_computation_arg}) .ConsumeValueOrDie(); } @@ -1206,18 +1335,21 @@ Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) { bool keep_going = true; int64 iteration_count = 0; HloEvaluator cond_evaluator(max_loop_iterations_); + cond_evaluator.set_dynamic_dimension_inference(dynamic_dimension_inference_); HloEvaluator loop_body_evaluator(max_loop_iterations_); + loop_body_evaluator.set_dynamic_dimension_inference( + dynamic_dimension_inference_); while (keep_going) { if (max_loop_iterations_ >= 0 && iteration_count++ > max_loop_iterations_) { return InvalidArgument("Loop %s exceeded loop iteration limit (%d).", while_hlo->name(), max_loop_iterations_); } TF_ASSIGN_OR_RETURN(auto cond_val, - cond_evaluator.Evaluate(*cond_comp, {&lcv})); + cond_evaluator.Evaluate(*cond_comp, {&lcv})); keep_going = cond_val.GetFirstElement(); if (keep_going) { - TF_ASSIGN_OR_RETURN(auto body_val, loop_body_evaluator.Evaluate( - *body_comp, {&lcv})); + TF_ASSIGN_OR_RETURN(auto body_val, + loop_body_evaluator.Evaluate(*body_comp, {&lcv})); VLOG(3) << "Loop iteration result: " << body_val.ToString(); lcv = std::move(body_val); cond_evaluator.ResetVisitStates(); @@ -1228,173 +1360,250 @@ Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) { return Status::OK(); } -// Key-value sort is a special snowflake: it's templated on two different -// element types, one for the keys, and one for the values. Jump through some -// hoops to make this work. 
namespace { -template -StatusOr EvaluateSortInternal(HloInstruction* sort, - const Literal& keys_literal, - const Literal& values_literal) { - auto rank = ShapeUtil::Rank(keys_literal.shape()); - TF_RET_CHECK( - ShapeUtil::SameDimensions(keys_literal.shape(), values_literal.shape())) - << "Sort keys and values must have the same dimensions"; - TF_RET_CHECK(sort->operand_count() >= 2) << "Expected key-value sort"; - // We need to sort an array of keys and an array of values, where the - // sorted order of the values is determined by the keys. The simplest(?) - // way to do this is to go to an array-of-pairs representation, sort the - // array using the keys, and then go back to pair-of-arrays. - VLOG(3) << "HandleSort keys_literal: " << keys_literal.ToString(); - VLOG(3) << "HandleSort values_literal: " << values_literal.ToString(); - - if (rank == 0) { - // Nothing to sort. - return LiteralUtil::MakeTuple({&keys_literal, &values_literal}); +template +Literal ExtractLiteralFromIndexPositions(const Literal& from, + absl::Span indices, + bool extract_as_scalar) { + if (extract_as_scalar) { + return LiteralUtil::CreateR0(from.Get({indices[0]})); + } + // We use a InlinedVector here because we need to convert it to an + // absl::Span later, and this would not work with std::vector. 
+ absl::InlinedVector values; + for (int64 index : indices) { + values.push_back(from.Get({index})); } + return LiteralUtil::CreateR1(values); +} - Literal keys_result_literal(keys_literal.shape()); - Literal values_result_literal(values_literal.shape()); +StatusOr ExtractFromIndexPositions(const Literal& from, + absl::Span indices, + bool extract_as_scalar = false) { + if (extract_as_scalar) { + CHECK_EQ(indices.size(), 1); + } + PrimitiveType type = from.shape().element_type(); + switch (type) { + case PRED: { + return ExtractLiteralFromIndexPositions(from, indices, + extract_as_scalar); + } + case U8: { + return ExtractLiteralFromIndexPositions(from, indices, + extract_as_scalar); + } + case S8: { + return ExtractLiteralFromIndexPositions(from, indices, + extract_as_scalar); + } + case BF16: { + return ExtractLiteralFromIndexPositions(from, indices, + extract_as_scalar); + } + case F16: { + return ExtractLiteralFromIndexPositions(from, indices, + extract_as_scalar); + } + case U16: { + return ExtractLiteralFromIndexPositions(from, indices, + extract_as_scalar); + } + case S16: { + return ExtractLiteralFromIndexPositions(from, indices, + extract_as_scalar); + } + case F32: { + return ExtractLiteralFromIndexPositions(from, indices, + extract_as_scalar); + } + case U32: { + return ExtractLiteralFromIndexPositions(from, indices, + extract_as_scalar); + } + case S32: { + return ExtractLiteralFromIndexPositions(from, indices, + extract_as_scalar); + } + case F64: { + return ExtractLiteralFromIndexPositions(from, indices, + extract_as_scalar); + } + case U64: { + return ExtractLiteralFromIndexPositions(from, indices, + extract_as_scalar); + } + case S64: { + return ExtractLiteralFromIndexPositions(from, indices, + extract_as_scalar); + } + default: + return InvalidArgument("Unsupported type for Sort: %s", + PrimitiveType_Name(type)); + } +} +} // namespace + +Status HloEvaluator::HandleSort(HloInstruction* sort) { + TF_RET_CHECK(sort->operand_count() >= 1) + << 
"Expected at least 1 operand for sort"; + for (int64 i = 1; i < sort->operand_count(); ++i) { + TF_RET_CHECK(ShapeUtil::SameDimensions(sort->operand(0)->shape(), + sort->operand(i)->shape())) + << "All Sort operands must have the same dimensions"; + } + + if (VLOG_IS_ON(3)) { + for (int64 i = 0; i < sort->operand_count(); ++i) { + VLOG(3) << "HandleSort operand " << i << " literal: " + << GetEvaluatedLiteralFor(sort->operand(i)).ToString(); + } + } + Shape key_shape = sort->operand(0)->shape(); + auto rank = key_shape.rank(); + PrimitiveType keys_type = key_shape.element_type(); + if (keys_type != F64 && keys_type != U64 && keys_type != S64 && + keys_type != F32 && keys_type != U32 && keys_type != S32 && + keys_type != BF16 && keys_type != F16 && keys_type != U16 && + keys_type != S16 && keys_type != U8 && keys_type != S8) { + return InvalidArgument("Unsupported type for Sort: %s", + PrimitiveType_Name(keys_type)); + } + std::vector result_literals; + result_literals.reserve(sort->operand_count()); + for (int64 i = 0; i < sort->operand_count(); ++i) { + result_literals.emplace_back(sort->operand(i)->shape()); + } std::vector zero_base(rank, 0); std::vector increment(rank, 1); int64 sort_dim = sort->dimensions(0); - int64 sort_dim_elements = keys_literal.shape().dimensions(sort_dim); + int64 sort_dim_elements = key_shape.dimensions(sort_dim); increment[sort_dim] = sort_dim_elements; // Iterate through each dimension except 'sort_dim'. TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus( - keys_literal.shape(), zero_base, - AsInt64Slice(keys_literal.shape().dimensions()), increment, + key_shape, zero_base, AsInt64Slice(key_shape.dimensions()), increment, [&](absl::Span indices) -> StatusOr { - // Extract a slice from the keys and values literals that correspond to + // Extract a slice from each operand literal that corresponds to // exactly the row in dimension 'sort_dim'. 
std::vector limit_indices(indices.begin(), indices.end()); - std::for_each(limit_indices.begin(), limit_indices.end(), - [](int64& index) { ++index; }); + absl::c_for_each(limit_indices, [](int64& index) { ++index; }); limit_indices[sort_dim] = sort_dim_elements; - TF_ASSIGN_OR_RETURN(auto keys_to_sort, - keys_literal.Slice(indices, limit_indices) - .Reshape({sort_dim_elements})); - const auto& keys_data = keys_to_sort.data(); - TF_ASSIGN_OR_RETURN(auto values_to_sort, - values_literal.Slice(indices, limit_indices) - .Reshape({sort_dim_elements})); - const auto& values_data = values_to_sort.data(); - using kv_pair = std::pair; - std::vector key_value_vector; - key_value_vector.reserve(keys_data.size()); - for (int i = 0; i < keys_data.size(); ++i) { - key_value_vector.push_back( - std::make_pair(keys_data[i], values_data[i])); + std::vector literals_to_sort; + literals_to_sort.reserve(sort->operand_count()); + for (int64 i = 0; i < sort->operand_count(); ++i) { + TF_ASSIGN_OR_RETURN(auto literal_to_sort, + GetEvaluatedLiteralFor(sort->operand(i)) + .Slice(indices, limit_indices) + .Reshape({sort_dim_elements})); + literals_to_sort.push_back(std::move(literal_to_sort)); } - std::stable_sort(key_value_vector.begin(), key_value_vector.end(), - [](const kv_pair& a, const kv_pair& b) { - return SafeLess(a.first, b.first); - }); - std::vector result_keys; - // We use a InlinedVector here because we need to convert it to an - // absl::Span later, and this would not work with std::vector. 
- absl::InlinedVector result_values; - for (const auto& key_value : key_value_vector) { - result_keys.push_back(key_value.first); - result_values.push_back(key_value.second); - } - Literal sorted_keys(ShapeUtil::MakeShape( - keys_literal.shape().element_type(), {sort_dim_elements})); - sorted_keys.PopulateR1(absl::Span(result_keys)); - Literal sorted_values(ShapeUtil::MakeShape( - values_literal.shape().element_type(), {sort_dim_elements})); - sorted_values.PopulateR1(absl::Span(result_values)); + std::vector indices_to_sort(sort_dim_elements); + std::iota(indices_to_sort.begin(), indices_to_sort.end(), 0); + std::stable_sort( + indices_to_sort.begin(), indices_to_sort.end(), + [keys_type, &literals_to_sort](int64 a, int64 b) { + switch (keys_type) { + case F64: { + auto key_lhs = literals_to_sort[0].Get({a}); + auto key_rhs = literals_to_sort[0].Get({b}); + return SafeLess(key_lhs, key_rhs); + } + case U64: { + auto key_lhs = literals_to_sort[0].Get({a}); + auto key_rhs = literals_to_sort[0].Get({b}); + return SafeLess(key_lhs, key_rhs); + } + case S64: { + auto key_lhs = literals_to_sort[0].Get({a}); + auto key_rhs = literals_to_sort[0].Get({b}); + return SafeLess(key_lhs, key_rhs); + } + case F32: { + auto key_lhs = literals_to_sort[0].Get({a}); + auto key_rhs = literals_to_sort[0].Get({b}); + return SafeLess(key_lhs, key_rhs); + } + case U32: { + auto key_lhs = literals_to_sort[0].Get({a}); + auto key_rhs = literals_to_sort[0].Get({b}); + return SafeLess(key_lhs, key_rhs); + } + case S32: { + auto key_lhs = literals_to_sort[0].Get({a}); + auto key_rhs = literals_to_sort[0].Get({b}); + return SafeLess(key_lhs, key_rhs); + } + case BF16: { + auto key_lhs = literals_to_sort[0].Get({a}); + auto key_rhs = literals_to_sort[0].Get({b}); + return SafeLess(key_lhs, key_rhs); + } + case F16: { + auto key_lhs = literals_to_sort[0].Get({a}); + auto key_rhs = literals_to_sort[0].Get({b}); + return SafeLess(key_lhs, key_rhs); + } + case U16: { + auto key_lhs = 
literals_to_sort[0].Get({a}); + auto key_rhs = literals_to_sort[0].Get({b}); + return SafeLess(key_lhs, key_rhs); + } + case S16: { + auto key_lhs = literals_to_sort[0].Get({a}); + auto key_rhs = literals_to_sort[0].Get({b}); + return SafeLess(key_lhs, key_rhs); + } + case U8: { + auto key_lhs = literals_to_sort[0].Get({a}); + auto key_rhs = literals_to_sort[0].Get({b}); + return SafeLess(key_lhs, key_rhs); + } + case S8: { + auto key_lhs = literals_to_sort[0].Get({a}); + auto key_rhs = literals_to_sort[0].Get({b}); + return SafeLess(key_lhs, key_rhs); + } + default: + // We should never reach here, because we checked earlier + // that 'key_type' is one of the cases above. + LOG(FATAL) << "Invalid key type in Sort: %s", + PrimitiveType_Name(keys_type); + return false; + } + }); std::vector slice_dimensions(rank, 1); slice_dimensions[sort_dim] = sort_dim_elements; std::vector start_indices(rank, 0); - TF_ASSIGN_OR_RETURN(auto sorted_keys_reshaped, - sorted_keys.Reshape(slice_dimensions)); - TF_RETURN_IF_ERROR(keys_result_literal.CopySliceFrom( - sorted_keys_reshaped, start_indices, indices, slice_dimensions)); - TF_ASSIGN_OR_RETURN(auto sorted_values_reshaped, - sorted_values.Reshape(slice_dimensions)); - TF_RETURN_IF_ERROR(values_result_literal.CopySliceFrom( - sorted_values_reshaped, start_indices, indices, slice_dimensions)); + for (int64 i = 0; i < sort->operand_count(); ++i) { + TF_ASSIGN_OR_RETURN( + Literal sorted_literal, + ExtractFromIndexPositions(literals_to_sort[i], indices_to_sort)); + TF_ASSIGN_OR_RETURN(auto sorted_literal_reshaped, + sorted_literal.Reshape(slice_dimensions)); + TF_RETURN_IF_ERROR(result_literals[i].CopySliceFrom( + sorted_literal_reshaped, start_indices, indices, + slice_dimensions)); + } return true; })); - Literal result_tuple; - result_tuple = - LiteralUtil::MakeTuple({&keys_result_literal, &values_result_literal}); - VLOG(3) << "HandleSort result_tuple: " << result_tuple.ToString(); - return std::move(result_tuple); -} - 
-template -StatusOr EvaluateSortCurried(HloInstruction* sort, - const Literal& keys_literal, - const Literal& values_literal) { - switch (values_literal.shape().element_type()) { - case PRED: - return EvaluateSortInternal(sort, keys_literal, - values_literal); - case F32: - return EvaluateSortInternal(sort, keys_literal, - values_literal); - case U32: - return EvaluateSortInternal(sort, keys_literal, - values_literal); - case S32: - return EvaluateSortInternal(sort, keys_literal, - values_literal); - case BF16: - return EvaluateSortInternal(sort, keys_literal, - values_literal); - default: - return InvalidArgument("Unsupported type for Sort"); - } -} - -StatusOr EvaluateSort(HloInstruction* sort, - const Literal& keys_literal, - const Literal& values_literal) { - switch (sort->operand(0)->shape().element_type()) { - case F32: - return EvaluateSortCurried(sort, keys_literal, values_literal); - case U32: - return EvaluateSortCurried(sort, keys_literal, values_literal); - case S32: - return EvaluateSortCurried(sort, keys_literal, values_literal); - case BF16: - return EvaluateSortCurried(sort, keys_literal, values_literal); - default: - return InvalidArgument("Unsupported type for Sort"); - } -} -} // namespace - -Status HloEvaluator::HandleSort(HloInstruction* sort) { - if (!ShapeUtil::IsTuple(sort->shape())) { - return DefaultAction(sort); + if (sort->operand_count() == 1) { + evaluated_[sort] = std::move(result_literals[0]); } else { - // This is a really stupid work-around for the fact it's hard to support a - // multi-value sort directly, due to the fact we need to template the - // evaluation function on all of the value types. 
- std::vector sort_results_backing; - for (int64 i = 0; i < sort->operand_count(); ++i) { - auto result = EvaluateSort(sort, GetEvaluatedLiteralFor(sort->operand(0)), - GetEvaluatedLiteralFor(sort->operand(i))); - if (!result.ok()) { - return result.status(); - } - sort_results_backing.push_back( - std::move(result.ValueOrDie().DecomposeTuple()[1])); - } - std::vector sort_results; - absl::c_transform(sort_results_backing, std::back_inserter(sort_results), + std::vector literal_ptrs; + absl::c_transform(result_literals, std::back_inserter(literal_ptrs), [](const Literal& literal) { return &literal; }); - evaluated_[sort] = LiteralUtil::MakeTuple(sort_results); - return Status::OK(); + + Literal result_tuple = LiteralUtil::MakeTuple(literal_ptrs); + VLOG(3) << "HandleSort result_tuple: " << result_tuple.ToString(); + + evaluated_[sort] = std::move(result_tuple); } + return Status::OK(); } Status HloEvaluator::HandleReduce(HloInstruction* reduce) { - if (!ShapeUtil::IsTuple(reduce->shape())) { + if (!reduce->shape().IsTuple()) { return DefaultAction(reduce); } else { auto first_element_type = reduce->shape().tuple_shapes(0).element_type(); @@ -1409,6 +1618,27 @@ Status HloEvaluator::HandleReduce(HloInstruction* reduce) { } } +Status HloEvaluator::HandleCustomCall(HloInstruction* custom_call) { + if (!custom_call_handler_) { + // No handler is registered; this means custom-calls are not allowed. + return DefaultAction(custom_call); + } + + // Evaluate input operands so the handler has access to the operand data. + std::vector operands; + operands.reserve(custom_call->operand_count()); + for (const HloInstruction* operand : custom_call->operands()) { + operands.push_back(&GetEvaluatedLiteralFor(operand)); + } + + // Synchronously issue the handler to populate the instruction output literal. 
+ TF_ASSIGN_OR_RETURN( + auto output, custom_call_handler_(custom_call, absl::MakeSpan(operands))); + + evaluated_[custom_call] = std::move(output); + return Status::OK(); +} + Status HloEvaluator::Preprocess(HloInstruction* hlo) { VLOG(2) << "About to visit HLO: " << hlo->ToString(); return ShapeUtil::ValidateShape(hlo->shape()); @@ -1426,16 +1656,46 @@ Status HloEvaluator::Postprocess(HloInstruction* hlo) { return Status::OK(); } -// Explicit instantiation of templatized Evaluate* methods. -// -template StatusOr HloEvaluator::Evaluate( - const HloModule& module, absl::Span arg_literals); +namespace { +template +std::unique_ptr> MatmulArray2DImpl( + const Array2D& lhs, const Array2D& rhs, + const std::function& impl_fn) { + CHECK_EQ(lhs.width(), rhs.height()); + int m = lhs.height(); + int n = rhs.width(); + int k = lhs.width(); + auto result = absl::make_unique>(m, n); + // Because Eigen is a header-oriented library, make sure that the Eigen code + // is the same as the code used by the CPU backend (otherwise the linker will + // randomly pick *some* definition). 
+ impl_fn( + /*run_options_ptr=*/nullptr, result->data(), rhs.data(), lhs.data(), n, m, + k, + /*transpose_lhs=*/0, + /*transpose_rhs=*/0); + return result; +} +} // namespace + +std::unique_ptr> HloEvaluator::MatmulArray2D( + const Array2D& lhs, const Array2D& rhs) { + return MatmulArray2DImpl( + lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF16); +} -template StatusOr HloEvaluator::Evaluate( - const HloComputation& computation, - absl::Span arg_literals); +std::unique_ptr> HloEvaluator::MatmulArray2D( + const Array2D& lhs, const Array2D& rhs) { + return MatmulArray2DImpl( + lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF32); +} -template StatusOr HloEvaluator::Evaluate( - HloInstruction* instruction, absl::Span arg_literals); +std::unique_ptr> HloEvaluator::MatmulArray2D( + const Array2D& lhs, const Array2D& rhs) { + return MatmulArray2DImpl( + lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF64); +} } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index d847900010c697d7d280ed8e4a9502f1c465ee07..72ea40bcd797def3bc0765986881792b8752d9e1 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -16,12 +16,16 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_ +#include #include #include "absl/container/node_hash_map.h" #include "absl/memory/memory.h" #include "absl/types/span.h" +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -42,16 +46,24 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // specified. explicit HloEvaluator(int64 max_loop_iterations = -1); - // Evaluates an HLO module and an array of pointers to literals. - // Returns the evaluated result as a literal if successful. + // Evaluates an HLO module and an array of pointers to literals. Returns the + // evaluated result as a literal if successful. + // // Precondition: The indices of arg_literals correspond to the parameter // numbers of the HLO parameters in the computation. See comment below for an // example. - // `LiteralPtr` accepts either Literal or const Literal* - // type. - template + // + // (Dummy template arg is to reduce the overloading priority of one overload + // so that Evaluate(module, {}) resolves unambiguously.) + StatusOr Evaluate(const HloModule& module, + absl::Span arg_literals) { + return Evaluate(*module.entry_computation(), arg_literals); + } + template StatusOr Evaluate(const HloModule& module, - absl::Span arg_literals); + absl::Span arg_literals) { + return Evaluate(*module.entry_computation(), arg_literals); + } // Evaluates an HLO computation and an array of pointers to literals. // Returns the evaluated result as a literal if successful. 
@@ -69,29 +81,24 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // where Parameter0 has parameter_number 0 and Parameter1 has parameter_number // 1 in this computation. The input literals array will then have its first // literal map to Parameter0 and the second map to Parameter1. - // `LiteralPtr` accepts either Literal or const Literal* - // type. - template + // + // (Dummy template arg is to reduce the overloading priority of one overload + // so that Evaluate(module, {}) resolves unambiguously.) StatusOr Evaluate(const HloComputation& computation, - absl::Span arg_literals); - - // Evaluates a single HLO instruction and an array of pointers to literals. - // Return the evaluated result as literal if successful. - // Precondition: - // 1. argument literals correspond to the input instruction's parameters in - // their post-ordering. - // 2. the instruction's operands must be of either Parameter or Constant type. - // `LiteralPtr` accepts either Literal or const Literal* - // type. - template - StatusOr Evaluate(HloInstruction* instruction, - absl::Span arg_literals); - - // Evaluates a single HLO instruction with constant operands. - // Returns the evaluated result as literal if successful. - // Precondition: - // 1. all operands of the input instruction are constants. - // 2. the instruction is not a Parameter operation. + absl::Span arg_literals); + template + StatusOr Evaluate(const HloComputation& computation, + absl::Span arg_literals) { + std::vector arg_literal_ptrs; + for (const auto& l : arg_literals) { + arg_literal_ptrs.push_back(&l); + } + return Evaluate(computation, arg_literal_ptrs); + } + + // Gets the value of running a single HLO instruction. + // + // All of the operands to this instruction must be constants. 
StatusOr Evaluate(HloInstruction* instruction); // Same as Evaluate, except returning false on error and accepts an output @@ -119,6 +126,39 @@ class HloEvaluator : public DfsHloVisitorWithDefault { const PrecisionConfig& precision_config, const Literal& lhs, const Literal& rhs); + void set_dynamic_dimension_inference( + DynamicDimensionInference* dynamic_dimension_inference) { + dynamic_dimension_inference_ = dynamic_dimension_inference; + } + + // Enable the fast path for certain operations like dot or convolution. + void set_use_fast_path(bool value) { use_fast_path_ = value; } + + // Handles evaluation of a custom-call op. + // Operand literals are provided in |operands| and implementations must + // populate |output| before returning. + using CustomCallHandler = std::function( + HloInstruction* custom_call, absl::Span operands)>; + + // Sets a handler that is called during evaluation for custom-call ops. + // If no handler is defined the default error behavior will occur. The handler + // will be provided evaluated literals for all operands and is expected to + // return an output literal of the appropriate shape. + void set_custom_call_handler( + std::function(HloInstruction* custom_call, + absl::Span operands)> + handler) { + custom_call_handler_ = std::move(handler); + } + + // Returns the result of a matrix multiply `lhs x rhs`. + static std::unique_ptr> MatmulArray2D( + const Array2D& lhs, const Array2D& rhs); + static std::unique_ptr> MatmulArray2D( + const Array2D& lhs, const Array2D& rhs); + static std::unique_ptr> MatmulArray2D( + const Array2D& lhs, const Array2D& rhs); + protected: // Make HloEvaluatorTypedVisitor a friend because it is logically part of this // class. @@ -144,6 +184,10 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // Operations that are type-agnostic or always return a specific type, such as // HandleIsFinite where boolean is always returned. 
// + Status HandleBitcast(HloInstruction* bitcast) override; + + Status HandleGetDimensionSize(HloInstruction* get_dimension_size) override; + Status HandleParameter(HloInstruction* parameter) override; Status HandleConstant(HloInstruction* constant) override; @@ -190,16 +234,51 @@ class HloEvaluator : public DfsHloVisitorWithDefault { Status HandleImag(HloInstruction* imag) override; + Status HandleComplex(HloInstruction* complex) override; + Status HandleReduce(HloInstruction* reduce) override; + Status HandleCustomCall(HloInstruction* custom_call) override; + + // Unsupported HLOs, note some of them (such as BatchNorm*) are typically + // expanded in a semantic-preserving way into other HLOs by adding exanpsion + // HLO pass to the HLO optimization pass during compilation, which can then be + // handled by the evaluator. + Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override { + return Unimplemented("BatchNormGrad HLO is unsupported by the evaluator."); + }; + Status HandleBatchNormInference( + HloInstruction* batch_norm_inference) override { + return Unimplemented( + "BatchNormInference HLO is unsupported by the evaluator."); + }; + Status HandleBatchNormTraining(HloInstruction* batch_norm_training) override { + return Unimplemented( + "BatchNormTraining HLO is unsupported by the evaluator."); + }; + Status HandleInfeed(HloInstruction* infeed) override { + return Unimplemented("Infeed HLO is unsupported by the evaluator."); + }; + Status HandleOutfeed(HloInstruction* outfeed) override { + return Unimplemented("Outfeed HLO is unsupported by the evaluator."); + }; + // Returns the already-evaluated literal result for the instruction. + // // A Constant instruction is considered evaluated and its literal will be // returned directly without looking up the cache. + // + // Similarly, a Parameter instruction is considered evaluated and its literal + // is looked up in arg_literals. 
+ // // Crash with log if the given instruction has not been evaluated previously. const Literal& GetEvaluatedLiteralFor(const HloInstruction* hlo) { if (hlo->IsConstant()) { return hlo->literal(); } + if (hlo->opcode() == HloOpcode::kParameter) { + return *arg_literals_.at(hlo->parameter_number()); + } auto it = evaluated_.find(hlo); CHECK(it != evaluated_.end()) << "could not find evaluated value for: " << hlo->ToString(); @@ -207,14 +286,23 @@ class HloEvaluator : public DfsHloVisitorWithDefault { } // Tracks the HLO instruction and its evaluated literal result. + // + // Parameters and constants aren't stored here, see implementation of + // GetEvaluatedLiteralFor. + // // TODO(b/35950897): have better memory management here to free instructions // that are no longer a parent for any other subsequent instruction in // post-orderring. + // // Must be cleared for each evaluation. - // Storing Literal in place require the container to have pointer stability so - // we cannot use flat_hash_map any more. + // + // Storing Literal in place requires the container to have pointer stability + // so we cannot use flat_hash_map any more. absl::node_hash_map evaluated_; + // Use fast path that uses eigen in the evaluator. + bool use_fast_path_ = false; + private: template static StatusOr ElementWiseUnaryOpImpl( @@ -245,9 +333,25 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // Max loop iterations to execute with no maximum if negative. int64 max_loop_iterations_; + // Module-level seed handle. + uint64 seed_; + // RNG engine. + std::minstd_rand0 engine_; + + // DynamicDimensionInference is used to evaluate GetDimensionSize, which + // returns the dynamic dimension size of its operand. + DynamicDimensionInference* dynamic_dimension_inference_; + + // Optional handler for custom_call ops. 
+ std::function(HloInstruction* custom_call, + absl::Span operands)> + custom_call_handler_; + TF_DISALLOW_COPY_AND_ASSIGN(HloEvaluator); }; +std::unique_ptr> MatmulArray2D(const Array2D& lhs, + const Array2D& rhs); } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_ diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index d95b6ad04f2c446b423a3aaef4de333ed2968883..fb8cd299cef06d549130cd56dd2c430c4c1a0387 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/reference_util.h" @@ -35,6 +36,7 @@ limitations under the License. #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -49,20 +51,18 @@ namespace { static std::array use_bf16_params{true, false}; -class HloEvaluatorTest : public ::testing::WithParamInterface, - public HloTestBase { - protected: - HloEvaluatorTest() : HloTestBase(), use_bfloat16_(GetParam()) { - evaluator_ = absl::make_unique(); - } +// Test fixture for the HloEvaluator. +// +// In bf16 mode, all f32 shapes are converted to bf16 before running. +class HloEvaluatorTest : public HloTestBase { + public: + HloEvaluatorTest() : use_bfloat16_(false) {} Literal Evaluate(absl::Span arg_literals = {}) { if (use_bfloat16_) { - // In BF16 mode, we convert all F32 type to BF16 and evaluate the module. 
- auto type_converter = HloElementTypeConverter(F32, BF16); - type_converter.Run(m_.get()).ValueOrDie(); + HloElementTypeConverter(F32, BF16).Run(m_.get()).ValueOrDie(); } - return evaluator_->Evaluate(*m_->entry_computation(), arg_literals) + return evaluator_.Evaluate(*m_->entry_computation(), arg_literals) .ConsumeValueOrDie(); } @@ -72,16 +72,12 @@ class HloEvaluatorTest : public ::testing::WithParamInterface, Literal EvaluateWithModule( HloModule* module, absl::Span arg_literals = {}) { if (use_bfloat16_) { - // In BF16 mode, we convert all F32 type to BF16 and evaluate the module. - auto type_converter = HloElementTypeConverter(F32, BF16); - type_converter.Run(module).ValueOrDie(); + HloElementTypeConverter(F32, BF16).Run(m_.get()).ValueOrDie(); } - return evaluator_->Evaluate(*module->entry_computation(), arg_literals) + return evaluator_.Evaluate(*module->entry_computation(), arg_literals) .ConsumeValueOrDie(); } - std::unique_ptr evaluator_; - void TestUnaryOp(HloOpcode opcode, Literal expected, Literal input, float aabs = 0) { HloComputation::Builder b(TestName()); @@ -115,16 +111,27 @@ class HloEvaluatorTest : public ::testing::WithParamInterface, EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } - bool use_bfloat16_; + protected: + explicit HloEvaluatorTest(bool use_bfloat16) : use_bfloat16_(use_bfloat16) {} + HloEvaluator evaluator_; + + const bool use_bfloat16_; std::unique_ptr m_ = CreateNewVerifiedModule(); }; -#define XLA_TYPED_TEST_P(test_case_name, test_name, test_type1) \ - TEST_P(test_case_name, test_name) +// Lets you write TEST_Ps that run twice, once with and once without bf16. 
+class HloEvaluatorBf16Test : public ::testing::WithParamInterface, + public HloEvaluatorTest { + protected: + HloEvaluatorBf16Test() : HloEvaluatorTest(/*use_bfloat16=*/GetParam()) {} +}; + +INSTANTIATE_TEST_SUITE_P(HloEvaluatorTest_Instantiation, HloEvaluatorBf16Test, + ::testing::ValuesIn(use_bf16_params)); // Verifies that HloEvaluator evaluates a HLO instruction that performs clamp // with 3 operands. -TEST_P(HloEvaluatorTest, DoesClamp) { +TEST_P(HloEvaluatorBf16Test, DoesClamp) { auto low = LiteralUtil::CreateR2({{0.f, 2.f}, {2.f, 4.f}}); auto value = LiteralUtil::CreateR2({{0.f, 5.f}, {0.f, 4.f}}); auto high = LiteralUtil::CreateR2({{2.f, 4.f}, {4.f, 4.f}}); @@ -145,7 +152,7 @@ TEST_P(HloEvaluatorTest, DoesClamp) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) { +TEST_P(HloEvaluatorBf16Test, DISABLED_DoesClampSpecialBroadcast) { auto low = LiteralUtil::CreateR0(0.f); auto value = LiteralUtil::CreateR2({{-1.f, 0.f}, {1.f, 2.f}}); auto high = LiteralUtil::CreateR0(1.f); @@ -168,7 +175,7 @@ TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) { // Verifies that HloEvaluator evaluates a HLO instruction that performs select // with 3 operands. -TEST_P(HloEvaluatorTest, DoesSelect) { +TEST_P(HloEvaluatorBf16Test, DoesSelect) { auto pred = LiteralUtil::CreateR2({{true, false}, {false, true}}); auto on_true = LiteralUtil::CreateR2({{2.f, 4.f}, {4.f, 4.f}}); auto on_false = LiteralUtil::CreateR2({{0.f, 5.f}, {0.f, 4.f}}); @@ -193,7 +200,7 @@ TEST_P(HloEvaluatorTest, DoesSelect) { // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise addition with 2 operands. 
-TEST_P(HloEvaluatorTest, DoesAdd) { +TEST_F(HloEvaluatorTest, DoesAdd) { auto lhs = LiteralUtil::CreateR2({{1, 0}, {-100, 4}}); auto rhs = LiteralUtil::CreateR2({{2, 4}, {4, 4}}); auto expected = LiteralUtil::CreateR2({{3, 4}, {-96, 8}}); @@ -202,7 +209,7 @@ TEST_P(HloEvaluatorTest, DoesAdd) { } // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise and with 2 operands. -TEST_P(HloEvaluatorTest, DoesAnd) { +TEST_P(HloEvaluatorBf16Test, DoesAnd) { auto lhs = LiteralUtil::CreateR2({{1, 0}, {-100, 4}}); auto rhs = LiteralUtil::CreateR2({{2, 4}, {4, 4}}); auto expected = LiteralUtil::CreateR2({{0, 0}, {4, 4}}); @@ -211,7 +218,7 @@ TEST_P(HloEvaluatorTest, DoesAnd) { } // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise or with 2 operands. -TEST_P(HloEvaluatorTest, DoesOr) { +TEST_F(HloEvaluatorTest, DoesOr) { auto lhs = LiteralUtil::CreateR2({{1, 0}, {-100, 4}}); auto rhs = LiteralUtil::CreateR2({{2, 4}, {4, 4}}); auto expected = LiteralUtil::CreateR2({{3, 4}, {-100, 4}}); @@ -220,7 +227,7 @@ TEST_P(HloEvaluatorTest, DoesOr) { } // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise or with 2 operands. -TEST_P(HloEvaluatorTest, DoesXor) { +TEST_F(HloEvaluatorTest, DoesXor) { auto lhs = LiteralUtil::CreateR2({{1, 0}, {-100, 4}}); auto rhs = LiteralUtil::CreateR2({{2, 4}, {4, 4}}); auto expected = LiteralUtil::CreateR2({{3, 4}, {-104, 0}}); @@ -229,7 +236,7 @@ TEST_P(HloEvaluatorTest, DoesXor) { } // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise multiply with 2 operands. 
-TEST_P(HloEvaluatorTest, DoesMultiply) { +TEST_F(HloEvaluatorTest, DoesMultiply) { auto lhs = LiteralUtil::CreateR2({{-1, 0}, {-100, 4}}); auto rhs = LiteralUtil::CreateR2( {{std::numeric_limits::min(), 4}, {4, 4}}); @@ -240,14 +247,14 @@ TEST_P(HloEvaluatorTest, DoesMultiply) { } // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise divide with 2 operands. -TEST_P(HloEvaluatorTest, DoesDivideInt64) { +TEST_F(HloEvaluatorTest, DoesDivideInt64) { auto lhs = LiteralUtil::CreateR2({{1, 0}, {-100, 4}}); auto rhs = LiteralUtil::CreateR2({{2, 4}, {4, 4}}); auto expected = LiteralUtil::CreateR2({{0, 0}, {-25, 1}}); TestBinaryOp(HloOpcode::kDivide, std::move(expected), std::move(lhs), std::move(rhs)); } -TEST_P(HloEvaluatorTest, DoesDivideDouble) { +TEST_P(HloEvaluatorBf16Test, DoesDivideDouble) { auto lhs = LiteralUtil::CreateR2({{1.0, 0.0}, {-100.0, 4.0}}); auto rhs = LiteralUtil::CreateR2({{2.2, 4.0}, {4.0, 4.0}}); auto expected = @@ -258,41 +265,41 @@ TEST_P(HloEvaluatorTest, DoesDivideDouble) { // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise abs op with 1 operand. 
-TEST_P(HloEvaluatorTest, DoesAbsR2) { +TEST_F(HloEvaluatorTest, DoesAbsR2) { auto operand = LiteralUtil::CreateR2({{1, -20}, {-100, 4}}); auto expected = LiteralUtil::CreateR2({{1, 20}, {100, 4}}); TestUnaryOp(HloOpcode::kAbs, std::move(expected), std::move(operand)); } -TEST_P(HloEvaluatorTest, DoesAbsR0) { +TEST_P(HloEvaluatorBf16Test, DoesAbsR0) { auto operand = LiteralUtil::CreateR0(-1.0f); auto expected = LiteralUtil::CreateR0(1.0f); TestUnaryOp(HloOpcode::kAbs, std::move(expected), std::move(operand)); } -TEST_P(HloEvaluatorTest, DoesAbsR1WithZeroSize) { +TEST_P(HloEvaluatorBf16Test, DoesAbsR1WithZeroSize) { auto operand = LiteralUtil::CreateR1({}); auto expected = LiteralUtil::CreateR1({}); TestUnaryOp(HloOpcode::kAbs, std::move(expected), std::move(operand)); } -TEST_P(HloEvaluatorTest, DoesNegateR2) { +TEST_F(HloEvaluatorTest, DoesNegateR2) { auto operand = LiteralUtil::CreateR2( {{0, std::numeric_limits::min()}, {-1, 4}}); auto expected = LiteralUtil::CreateR2( {{0, std::numeric_limits::min()}, {1, -4}}); TestUnaryOp(HloOpcode::kNegate, std::move(expected), std::move(operand)); } -TEST_P(HloEvaluatorTest, DoesCosR2) { +TEST_P(HloEvaluatorBf16Test, DoesCosR2) { auto operand = LiteralUtil::CreateR2({{0, M_PI}, {-M_PI, 2 * M_PI}}); auto expected = LiteralUtil::CreateR2({{1, -1}, {-1, 1}}); TestUnaryOp(HloOpcode::kCos, std::move(expected), std::move(operand), use_bfloat16_ ? 0.031250 : 9.5367431640625E-7); } -TEST_P(HloEvaluatorTest, DoesSinR2) { +TEST_P(HloEvaluatorBf16Test, DoesSinR2) { auto operand = LiteralUtil::CreateR2({{0, M_PI}, {-M_PI, 2 * M_PI}}); auto expected = LiteralUtil::CreateR2({{0, 0}, {0, 0}}); TestUnaryOp(HloOpcode::kSin, std::move(expected), std::move(operand), use_bfloat16_ ? 
0.031250 : 9.5367431640625E-7); } -TEST_P(HloEvaluatorTest, DoesNotR2) { +TEST_F(HloEvaluatorTest, DoesNotR2) { auto operand = LiteralUtil::CreateR2({{0, std::numeric_limits::min()}, {-1, std::numeric_limits::max()}}); @@ -301,9 +308,22 @@ TEST_P(HloEvaluatorTest, DoesNotR2) { {0, std::numeric_limits::min()}}); TestUnaryOp(HloOpcode::kNot, std::move(expected), std::move(operand)); } + +TEST_F(HloEvaluatorTest, DoesRealC128) { + auto x = LiteralUtil::CreateR1({{1, 0}, {-100, 4}}); + auto expected_real = LiteralUtil::CreateR1({1, -100}); + TestUnaryOp(HloOpcode::kReal, std::move(expected_real), std::move(x)); +} + +TEST_F(HloEvaluatorTest, DoesImagC128) { + auto x = LiteralUtil::CreateR1({{1, 0}, {-100, 4}}); + auto expected_imag = LiteralUtil::CreateR1({0, 4}); + TestUnaryOp(HloOpcode::kImag, std::move(expected_imag), std::move(x)); +} + // Verifies that HloEvaluator evaluates a HLO Computation with non-parameter nor // constant operands. -TEST_P(HloEvaluatorTest, DoesTraverseInstructions) { +TEST_F(HloEvaluatorTest, DoesTraverseInstructions) { auto lhs = LiteralUtil::CreateR2({{1, 0}, {-100, 4}}); auto rhs = LiteralUtil::CreateR2({{2, 4}, {4, 4}}); auto rhs2 = LiteralUtil::CreateR2({{1, -20}, {-100, 4}}); @@ -333,7 +353,7 @@ TEST_P(HloEvaluatorTest, DoesTraverseInstructions) { } // Verifies Reshape operation is correctly evaluated. -TEST_P(HloEvaluatorTest, DoesReshape) { +TEST_F(HloEvaluatorTest, DoesReshape) { HloComputation::Builder b(TestName()); const int64 dimensions[] = {11, 8, 7, 5, 9}; TF_ASSERT_OK_AND_ASSIGN(auto literal, @@ -359,7 +379,7 @@ TEST_P(HloEvaluatorTest, DoesReshape) { } // Verifies Broadcast operation is correctly evaluated. 
-TEST_P(HloEvaluatorTest, DoesBroadcast) { +TEST_F(HloEvaluatorTest, DoesBroadcast) { HloComputation::Builder b(TestName()); auto input_literal = LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}}); auto output_literal = LiteralUtil::CreateR3( @@ -375,7 +395,7 @@ TEST_P(HloEvaluatorTest, DoesBroadcast) { EXPECT_TRUE(LiteralTestUtil::Equal(result, output_literal)); } -TEST_P(HloEvaluatorTest, DoesBroadcastScalar) { +TEST_F(HloEvaluatorTest, DoesBroadcastScalar) { HloComputation::Builder b(TestName()); auto input_literal = LiteralUtil::CreateR0(111); auto output_literal = LiteralUtil::CreateR2( @@ -394,7 +414,7 @@ TEST_P(HloEvaluatorTest, DoesBroadcastScalar) { EXPECT_TRUE(LiteralTestUtil::Equal(result, output_literal)); } -TEST_P(HloEvaluatorTest, DoesConcatenateSimple) { +TEST_F(HloEvaluatorTest, DoesConcatenateSimple) { HloComputation::Builder b(TestName()); HloInstruction* operand1 = b.AddInstruction(HloInstruction::CreateConstant( @@ -416,7 +436,7 @@ TEST_P(HloEvaluatorTest, DoesConcatenateSimple) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) { +TEST_F(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) { HloComputation::Builder b(TestName()); HloInstruction* operand1 = b.AddInstruction( @@ -437,7 +457,7 @@ TEST_P(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, ConvertWithSameLayout) { +TEST_P(HloEvaluatorBf16Test, ConvertWithSameLayout) { HloComputation::Builder b(TestName()); auto input_literal = LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}}); @@ -456,7 +476,7 @@ TEST_P(HloEvaluatorTest, ConvertWithSameLayout) { EXPECT_TRUE(LiteralTestUtil::Equal(result, expected)); } -TEST_P(HloEvaluatorTest, ConvertWithDifferentLayout) { +TEST_P(HloEvaluatorBf16Test, ConvertWithDifferentLayout) { HloComputation::Builder b(TestName()); auto input_literal = LiteralUtil::CreateR2WithLayout( @@ -489,7 
+509,7 @@ PaddingConfig CreatePaddingConfig( return padding_config; } -TEST_P(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) { +TEST_F(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) { auto operand = LiteralUtil::CreateR2({{}, {}}); HloComputation::Builder b(TestName()); auto operand_instruction = @@ -514,7 +534,7 @@ TEST_P(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) { +TEST_P(HloEvaluatorBf16Test, Pad4DFloatArrayWithInteriorPadding) { HloComputation::Builder b(TestName()); Array4D input_array(3, 2, 1, 1, {1, 2, 3, 4, 5, 6}); @@ -549,7 +569,7 @@ TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, NegativePadding2D) { +TEST_P(HloEvaluatorBf16Test, NegativePadding2D) { HloComputation::Builder b(TestName()); // input_array: @@ -591,7 +611,7 @@ TEST_P(HloEvaluatorTest, NegativePadding2D) { EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(0.031250))); } -TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) { +TEST_P(HloEvaluatorBf16Test, NegativeAndInteriorPadding2D) { HloComputation::Builder b(TestName()); // f32[4,3] { @@ -630,7 +650,7 @@ TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, DotRank2AndRank1) { +TEST_P(HloEvaluatorBf16Test, DotRank2AndRank1) { HloComputation::Builder b(TestName()); // lhs: @@ -676,7 +696,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank1) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, DotRank1AndRank2) { +TEST_P(HloEvaluatorBf16Test, DotRank1AndRank2) { HloComputation::Builder b(TestName()); // lhs: @@ -714,7 +734,7 @@ TEST_P(HloEvaluatorTest, DotRank1AndRank2) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, DotRank2AndRank2) { 
+TEST_P(HloEvaluatorBf16Test, DotRank2AndRank2) { HloComputation::Builder b(TestName()); // lhs: @@ -764,7 +784,51 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, SimpleConv1D) { +TEST_P(HloEvaluatorBf16Test, DotRank4AndRank4) { + HloComputation::Builder b(TestName()); + + auto lhs_array = absl::make_unique>(2, 2, 3, 1); + lhs_array->FillIota(1.0f); + auto lhs_literal = LiteralUtil::CreateR4FromArray4D(*lhs_array); + HloInstruction* lhs_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal))); + + auto rhs_array = absl::make_unique>(2, 2, 3, 1); + rhs_array->FillIota(2.0f); + auto rhs_literal = LiteralUtil::CreateR4FromArray4D(*rhs_array); + HloInstruction* rhs_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); + + Shape shape = ShapeUtil::MakeShape(F32, {2, 1, 1}); + DotDimensionNumbers dot_dnums; + + dot_dnums.add_lhs_batch_dimensions(0); + dot_dnums.add_rhs_batch_dimensions(0); + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_lhs_contracting_dimensions(2); + dot_dnums.add_rhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(2); + b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction, + rhs_instruction, dot_dnums, + DefaultPrecisionConfig(2))); + m_->AddEntryComputation(b.Build()); + + Literal result = Evaluate(); + float expected_1 = 0; + for (float i = 1.0f; i < 7.0f; ++i) { + expected_1 += i * i + i; + } + float expected_2 = 0; + for (float i = 7.0f; i < 13.0f; ++i) { + expected_2 += i * i + i; + } + auto expected_array = Array3D({{{expected_1}}, {{expected_2}}}); + auto expected = LiteralUtil::CreateR3FromArray3D(expected_array); + + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); +} + +TEST_P(HloEvaluatorBf16Test, SimpleConv1D) { HloComputation::Builder b(TestName()); Array3D lhs_array = {{{1, 2, 3}}}; @@ -802,7 +866,7 @@ TEST_P(HloEvaluatorTest, 
SimpleConv1D) { Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 3}); b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, - window, dnums, DefaultPrecisionConfig(2))); + /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -813,7 +877,7 @@ TEST_P(HloEvaluatorTest, SimpleConv1D) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { +TEST_P(HloEvaluatorBf16Test, Simple4x4Conv2DWith2x2Kernel) { HloComputation::Builder b(TestName()); Array4D lhs_array(1, 1, 4, 4); @@ -857,7 +921,7 @@ TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4}); b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, - window, dnums, DefaultPrecisionConfig(2))); + /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -876,7 +940,7 @@ TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) { +TEST_P(HloEvaluatorBf16Test, Conv2DGeneralDimensionsReversed) { HloComputation::Builder b(TestName()); // clang-format off @@ -941,7 +1005,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) { Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, - window, dnums, DefaultPrecisionConfig(2))); + /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -957,7 +1021,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } 
-TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) { +TEST_P(HloEvaluatorBf16Test, Conv2DGeneralDimensions) { HloComputation::Builder b(TestName()); // clang-format off @@ -1019,7 +1083,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) { Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, - window, dnums, DefaultPrecisionConfig(2))); + /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1035,7 +1099,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { +TEST_P(HloEvaluatorBf16Test, DilatedBaseConv2DWithHighPadding) { HloComputation::Builder b(TestName()); Array4D lhs_array(1, 1, 4, 4); @@ -1079,7 +1143,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 7, 7}); b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, - window, dnums, DefaultPrecisionConfig(2))); + /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1099,7 +1163,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { +TEST_P(HloEvaluatorBf16Test, DilatedBaseConv2DWithLowAndHighPadding) { HloComputation::Builder b(TestName()); Array4D lhs_array(1, 1, 4, 4); @@ -1143,7 +1207,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 8, 8}); b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, - window, dnums, 
DefaultPrecisionConfig(2))); + /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1164,7 +1228,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, +TEST_P(HloEvaluatorBf16Test, DilatedWindowAndBaseConv2DWithDifferentLowAndHighPaddingAndStrides) { HloComputation::Builder b(TestName()); @@ -1215,7 +1279,7 @@ TEST_P(HloEvaluatorTest, Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 9, 3}); b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, - window, dnums, DefaultPrecisionConfig(2))); + /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1237,7 +1301,7 @@ TEST_P(HloEvaluatorTest, EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, Conv2DGroupedConvolution) { +TEST_P(HloEvaluatorBf16Test, Conv2DGroupedConvolution) { HloComputation::Builder b(TestName()); std::vector input_dims = {1, 2, 2, 4}; std::vector filter_dims = {2, 2, 2, 8}; @@ -1286,7 +1350,8 @@ TEST_P(HloEvaluatorTest, Conv2DGroupedConvolution) { Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 8}); b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, - /*feature_group_count=*/2, window, dnums, DefaultPrecisionConfig(2))); + /*feature_group_count=*/2, /*batch_group_count=*/1, window, dnums, + DefaultPrecisionConfig(2))); m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1372,7 +1437,7 @@ void BM_ReducePrecisely(int num_iters) { BENCHMARK(BM_ReducePrecisely); -TEST_P(HloEvaluatorTest, ReduceAdd) { +TEST_P(HloEvaluatorBf16Test, ReduceAdd) { HloComputation::Builder b(TestName()); // arg: @@ -1414,7 +1479,7 @@ TEST_P(HloEvaluatorTest, ReduceAdd) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, 
result)); } -TEST_P(HloEvaluatorTest, ReduceWindowMax) { +TEST_P(HloEvaluatorBf16Test, ReduceWindowMax) { HloComputation::Builder b(TestName()); // arg: @@ -1465,7 +1530,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowMax) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, ReduceWindowMaxWindowDilation) { +TEST_P(HloEvaluatorBf16Test, ReduceWindowMaxWindowDilation) { HloComputation::Builder b(TestName()); // arg: @@ -1517,7 +1582,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowMaxWindowDilation) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, ReduceWindowAdd) { +TEST_P(HloEvaluatorBf16Test, ReduceWindowAdd) { HloComputation::Builder b(TestName()); // arg: @@ -1574,7 +1639,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) { +TEST_P(HloEvaluatorBf16Test, ReduceWindowAdd6D) { HloComputation::Builder b(TestName()); // arg: f32[4,4,4,4,4,4] full of ones. Using small dims to limit run-time. 
@@ -1637,7 +1702,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) { EXPECT_TRUE(LiteralTestUtil::Equal(result_literal, result)); } -TEST_P(HloEvaluatorTest, StridedSlice) { +TEST_P(HloEvaluatorBf16Test, StridedSlice) { HloComputation::Builder b(TestName()); // arg: @@ -1671,7 +1736,7 @@ TEST_P(HloEvaluatorTest, StridedSlice) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, DynamicSlice) { +TEST_P(HloEvaluatorBf16Test, DynamicSlice) { HloComputation::Builder b(TestName()); // arg: @@ -1687,12 +1752,14 @@ TEST_P(HloEvaluatorTest, DynamicSlice) { HloInstruction* operand = b.AddInstruction( HloInstruction::CreateConstant(std::move(operand_literal))); - auto start_indices = b.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({0, 1}))); + auto zero = b.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + auto one = b.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); - b.AddInstruction(HloInstruction::CreateDynamicSlice(shape, operand, - start_indices, {2, 3})); + b.AddInstruction( + HloInstruction::CreateDynamicSlice(shape, operand, {zero, one}, {2, 3})); m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1707,7 +1774,7 @@ TEST_P(HloEvaluatorTest, DynamicSlice) { // Verifies that the HloEvaluator's implementation goes along with existing // backends' behavior, although this is not required by the spec. 
-TEST_P(HloEvaluatorTest, DynamicSliceModSlice) { +TEST_P(HloEvaluatorBf16Test, DynamicSliceModSlice) { HloComputation::Builder b(TestName()); // arg: @@ -1723,12 +1790,14 @@ TEST_P(HloEvaluatorTest, DynamicSliceModSlice) { HloInstruction* operand = b.AddInstruction( HloInstruction::CreateConstant(std::move(operand_literal))); - auto start_indices = b.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({2, 1}))); + auto two = b.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2))); + auto one = b.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); - b.AddInstruction(HloInstruction::CreateDynamicSlice(shape, operand, - start_indices, {2, 3})); + b.AddInstruction( + HloInstruction::CreateDynamicSlice(shape, operand, {two, one}, {2, 3})); m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1741,7 +1810,7 @@ TEST_P(HloEvaluatorTest, DynamicSliceModSlice) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, DynamicSliceUpdate) { +TEST_P(HloEvaluatorBf16Test, DynamicSliceUpdate) { HloComputation::Builder b(TestName()); // arg: @@ -1757,15 +1826,17 @@ TEST_P(HloEvaluatorTest, DynamicSliceUpdate) { HloInstruction* operand = b.AddInstruction( HloInstruction::CreateConstant(std::move(operand_literal))); - auto start_indices = b.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({0, 1}))); + auto zero = b.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + auto one = b.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); auto update = b.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR2({{-2.0, -3.0}, {-6.0, -7.0}}))); Shape shape = ShapeUtil::MakeShape(F64, {2, 3}); b.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - shape, operand, update, start_indices)); + shape, operand, update, {zero, one})); 
m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1778,7 +1849,7 @@ TEST_P(HloEvaluatorTest, DynamicSliceUpdate) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, SetAndGetTuples) { +TEST_P(HloEvaluatorBf16Test, SetAndGetTuples) { HloComputation::Builder b(TestName()); // arg: @@ -1814,7 +1885,7 @@ TEST_P(HloEvaluatorTest, SetAndGetTuples) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) { +TEST_P(HloEvaluatorBf16Test, SetAndGetNestedTuples) { HloComputation::Builder b(TestName()); // arg: @@ -1853,7 +1924,7 @@ TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, Reverse) { +TEST_P(HloEvaluatorBf16Test, Reverse) { HloComputation::Builder b(TestName()); // Input shape is float[4x3x2x1]. @@ -1906,7 +1977,7 @@ TEST_P(HloEvaluatorTest, Reverse) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -TEST_P(HloEvaluatorTest, EvaluateWithSubstitutions) { +TEST_P(HloEvaluatorBf16Test, EvaluateWithSubstitutions) { HloComputation::Builder b(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {4}); @@ -1930,7 +2001,7 @@ TEST_P(HloEvaluatorTest, EvaluateWithSubstitutions) { // Check that EvaluateWithSubstitutions works if one of the operands to the op // we're evaluating is a constant. 
-TEST_P(HloEvaluatorTest, EvaluateWithSubstitutionsWithConstantOperand) { +TEST_P(HloEvaluatorBf16Test, EvaluateWithSubstitutionsWithConstantOperand) { HloComputation::Builder b(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {4}); @@ -1953,7 +2024,7 @@ TEST_P(HloEvaluatorTest, EvaluateWithSubstitutionsWithConstantOperand) { LiteralUtil::CreateR1({11, 22, 33, 44}), result.ValueOrDie())); } -TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV1) { +TEST_F(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV1) { const char* hlo_text = R"( HloModule TensorFlowGatherV1 @@ -1977,7 +2048,7 @@ ENTRY main { Evaluate({&operand, &start_indices}))); } -TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV2) { +TEST_F(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV2) { const char* hlo_text = R"( HloModule TensorFlowGatherV2 @@ -2001,7 +2072,7 @@ ENTRY main { Evaluate({&operand, &start_indices}))); } -TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherMultipleBatchDims) { +TEST_F(HloEvaluatorTest, EvaluateGather_TensorFlowGatherMultipleBatchDims) { const char* hlo_text = R"( HloModule TensorFlowGatherMultipleBatchDims @@ -2026,7 +2097,7 @@ ENTRY main { Evaluate({&operand, &start_indices}))); } -TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherNd) { +TEST_F(HloEvaluatorTest, EvaluateGather_TensorFlowGatherNd) { const char* hlo_text = R"( HloModule TensorFlowGatherNd @@ -2052,7 +2123,7 @@ ENTRY main { Evaluate({&operand, &start_indices}))); } -TEST_P(HloEvaluatorTest, +TEST_F(HloEvaluatorTest, EvaluateGather_TensorFlowGatherNdNonDefaultIndexVectorDim) { const char* hlo_text = R"( HloModule TensorFlowGatherNd @@ -2079,7 +2150,7 @@ ENTRY main { Evaluate({&operand, &start_indices}))); } -TEST_P(HloEvaluatorTest, EvaluateGather_DynamicSlice) { +TEST_F(HloEvaluatorTest, EvaluateGather_DynamicSlice) { const char* hlo_text = R"( HloModule DynamicSlice @@ -2102,7 +2173,7 @@ ENTRY main { Evaluate({&operand, &start_indices}))); } -TEST_P(HloEvaluatorTest, 
EvaluateGather_BatchDynamicSlice) { +TEST_F(HloEvaluatorTest, EvaluateGather_BatchDynamicSlice) { const char* hlo_text = R"( HloModule BatchDynamicSlice @@ -2126,7 +2197,7 @@ ENTRY main { Evaluate({&operand, &start_indices}))); } -TEST_P(HloEvaluatorTest, EvaluateGather_ZeroDimBounds) { +TEST_F(HloEvaluatorTest, EvaluateGather_ZeroDimBounds) { const char* hlo_text = R"( HloModule TensorFlowGatherV1 @@ -2148,7 +2219,7 @@ ENTRY main { Evaluate({&operand, &start_indices}))); } -TEST_P(HloEvaluatorTest, EvaluateGather_NoOutputWindowDims) { +TEST_F(HloEvaluatorTest, EvaluateGather_NoOutputWindowDims) { const string hlo_text = R"( HloModule GatherXd @@ -2173,7 +2244,7 @@ ENTRY main { Evaluate({&operand, &start_indices}))); } -TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV1_Update) { +TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV1_Update) { const char* hlo_text = R"( HloModule TensorFlowScatterV1 @@ -2204,7 +2275,7 @@ ENTRY main { Evaluate({&operand, &scatter_indices, &updates}))); } -TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV2_Update) { +TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV2_Update) { const char* hlo_text = R"( HloModule TensorFlowScatterV2 @@ -2236,7 +2307,7 @@ ENTRY main { Evaluate({&operand, &scatter_indices, &updates}))); } -TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_Add) { +TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_Add) { const char* hlo_text = R"( HloModule TensorFlowScatter @@ -2268,7 +2339,7 @@ ENTRY main { Evaluate({&operand, &scatter_indices, &updates}))); } -TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_Mul) { +TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_Mul) { const char* hlo_text = R"( HloModule TensorFlowScatter @@ -2300,7 +2371,7 @@ ENTRY main { Evaluate({&operand, &scatter_indices, &updates}))); } -TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_F32) { +TEST_P(HloEvaluatorBf16Test, 
EvaluateScatter_TensorFlowScatter_F32) { const char* hlo_text = R"( HloModule TensorFlowScatter @@ -2334,7 +2405,7 @@ ENTRY main { Evaluate({&operand, &scatter_indices, &updates}), ErrorSpec{0.1, 0.01})); } -TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_RepeatedIndices) { +TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_RepeatedIndices) { const char* hlo_text = R"( HloModule TensorFlowScatter @@ -2366,7 +2437,7 @@ ENTRY main { Evaluate({&operand, &scatter_indices, &updates}))); } -TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_MultipleBatchDims) { +TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_MultipleBatchDims) { const char* hlo_text = R"( HloModule TensorFlowScatterMultipleBatchDims @@ -2399,7 +2470,7 @@ ENTRY main { Evaluate({&operand, &scatter_indices, &updates}))); } -TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterNd) { +TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterNd) { const char* hlo_text = R"( HloModule TensorFlowScatterNd @@ -2435,7 +2506,7 @@ ENTRY main { expected, Evaluate({&operand, &scatter_indices, &updates}))); } -TEST_P(HloEvaluatorTest, +TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterNd_NonDefaultIndexVectorDim) { const char* hlo_text = R"( HloModule TensorFlowScatterNdNonDefaultIndexVectorDim @@ -2472,7 +2543,7 @@ ENTRY main { expected, Evaluate({&operand, &scatter_indices, &updates}))); } -TEST_P(HloEvaluatorTest, EvaluateScatter_DynamicUpdateSlice) { +TEST_F(HloEvaluatorTest, EvaluateScatter_DynamicUpdateSlice) { const char* hlo_text = R"( HloModule DynamicUpdateSlice @@ -2504,7 +2575,7 @@ ENTRY main { expected, Evaluate({&operand, &scatter_indices, &updates}))); } -TEST_P(HloEvaluatorTest, EvaluateScatter_BatchDynamicUpdateSlice) { +TEST_F(HloEvaluatorTest, EvaluateScatter_BatchDynamicUpdateSlice) { const char* hlo_text = R"( HloModule BatchDynamicUpdateSlice @@ -2536,7 +2607,7 @@ ENTRY main { expected, Evaluate({&operand, &scatter_indices, &updates}))); } 
-TEST_P(HloEvaluatorTest, EvaluateScatter_ZeroDimBounds) { +TEST_F(HloEvaluatorTest, EvaluateScatter_ZeroDimBounds) { const char* hlo_text = R"( HloModule TensorFlowScatter_ZeroDimBounds @@ -2565,7 +2636,7 @@ ENTRY main { operand, Evaluate({&operand, &scatter_indices, &updates}))); } -TEST_P(HloEvaluatorTest, EvaluateScatter_NoUpdateWindowDims) { +TEST_F(HloEvaluatorTest, EvaluateScatter_NoUpdateWindowDims) { const string hlo_text = R"( HloModule Scatter_NoUpdateWindowDims @@ -2598,7 +2669,7 @@ ENTRY main { expected, Evaluate({&operand, &scatter_indices, &updates}))); } -TEST_P(HloEvaluatorTest, EvaluateScatter_NegativeIndices) { +TEST_F(HloEvaluatorTest, EvaluateScatter_NegativeIndices) { const char* hlo_text = R"( HloModule TensorFlowScatter_NegativeIndices @@ -2633,7 +2704,7 @@ ENTRY main { {&operand, &scatter_indices, &updates}))); } -TEST_P(HloEvaluatorTest, EvaluateScatter_OobIndices) { +TEST_F(HloEvaluatorTest, EvaluateScatter_OobIndices) { const string hlo_text = R"( HloModule BatchDynamicUpdateSlice @@ -2669,7 +2740,7 @@ ENTRY main { {&operand, &scatter_indices, &updates}))); } -TEST_P(HloEvaluatorTest, EvaluateScatter_OobUpdateWindow) { +TEST_F(HloEvaluatorTest, EvaluateScatter_OobUpdateWindow) { const char* hlo_text = R"( HloModule TensorFlowScatterNd_OobUpdateWindow @@ -2708,7 +2779,7 @@ ENTRY main { // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise comparison with 2 bfloat16 operands. 
-TEST_P(HloEvaluatorTest, DoesCompareBF16) { +TEST_F(HloEvaluatorTest, DoesCompareBF16) { // lhs >= rhs auto lhs = LiteralUtil::CreateR2( {{bfloat16(0.25), bfloat16(0.35), bfloat16(0.125)}, @@ -2722,7 +2793,7 @@ TEST_P(HloEvaluatorTest, DoesCompareBF16) { std::move(rhs)); } -TEST_P(HloEvaluatorTest, Bf16Reduction) { +TEST_P(HloEvaluatorBf16Test, Bf16Reduction) { const string hlo_text = R"( HloModule Bf16Reduction @@ -2746,7 +2817,7 @@ ENTRY main { EXPECT_TRUE(LiteralTestUtil::Equal(expected, Evaluate({&arg}))); } -TEST_P(HloEvaluatorTest, SliceWithDifferentLayout) { +TEST_P(HloEvaluatorBf16Test, SliceWithDifferentLayout) { // Regression test for b/114735354. const string hlo_text = R"( HloModule SliceWithDifferentLayout @@ -2765,8 +2836,322 @@ ENTRY main { EXPECT_TRUE(LiteralTestUtil::Equal(arg, actual)); } -INSTANTIATE_TEST_CASE_P(HloEvaluatorTest_Instantiation, HloEvaluatorTest, - ::testing::ValuesIn(use_bf16_params)); +TEST_P(HloEvaluatorBf16Test, Bitcast) { + // Regression test for b/114735354. + constexpr absl::string_view hlo_text_base = R"( +HloModule Bitcast + +ENTRY main { + param = %s[32,121]{1,0} parameter(0) + ROOT bitcast = %s[121,32,1]{0,1,2} bitcast(%s[32,121]{1,0} param) +} +)"; + string hlo_text; + if (use_bfloat16_) { + hlo_text = absl::StrFormat(hlo_text_base, "bf16", "bf16", "bf16"); + } else { + hlo_text = absl::StrFormat(hlo_text_base, "f32", "f32", "f32"); + } + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + auto args = MakeFakeArguments(m_.get()).ConsumeValueOrDie(); + Literal actual = Evaluate({&args[0]}); + if (use_bfloat16_) { + EXPECT_TRUE( + absl::c_equal(args[0].data(), actual.data())); + } else { + EXPECT_TRUE(absl::c_equal(args[0].data(), actual.data())); + } +} + +// Check that s32 under/overflow doesn't trigger a ubsan failure. 
+TEST_F(HloEvaluatorTest, Int32Overflow) { + constexpr absl::string_view hlo_text = R"( +HloModule Test + +ENTRY main { + c1 = s32[] constant(1073741824) // 2^30 + sum = s32[] add(c1, c1) // 2^31, i.e. INT_MIN + + c2 = s32[] constant(-2147483648) // -2^31 + sub = s32[] subtract(c2, c1) // -2^31 - 2^30, underflows + + mul = s32[] multiply(c1, c1) + ROOT tuple = (s32[], s32[], s32[]) tuple(sum, sub, mul) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + std::vector actual = Evaluate({}).DecomposeTuple(); + ASSERT_EQ(actual.size(), 3); + + uint32 pow30 = uint32{1} << 30; + uint32 pow31 = uint32{1} << 31; + EXPECT_EQ(actual[0].GetFirstElement(), static_cast(pow31)); + EXPECT_EQ(actual[1].GetFirstElement(), + static_cast(-(pow31 + pow30))); + EXPECT_EQ(actual[2].GetFirstElement(), + static_cast(pow31 * pow31)); +} + +TEST_F(HloEvaluatorTest, GetDimensionSize) { + constexpr absl::string_view hlo_text = R"( +HloModule Test + +ENTRY main { + size = u32[] parameter(0) + + data = s32[4] parameter(1) + + sum = s32[4] add(data, data) + + ROOT dynamic_size = u32[] get-dimension-size(sum), dimensions={0} +} +)"; + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + + // Set up dynamic parameter binding. + TF_CHECK_OK(m_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{0, {}}, + DynamicParameterBinding::DynamicDimension{1, {}, 0})); + + TF_ASSERT_OK_AND_ASSIGN(DynamicDimensionInference dynamic_dimension_inference, + DynamicDimensionInference::Run(m_.get())); + + evaluator_.set_dynamic_dimension_inference(&dynamic_dimension_inference); + Literal size_arg = LiteralUtil::CreateR0(3); + Literal data_arg = LiteralUtil::CreateR1({1, 2, 3, 4}); + + Literal actual = Evaluate({&size_arg, &data_arg}); + + EXPECT_EQ(actual.GetFirstElement(), static_cast(3)); +} + +// Check that we get a useful error if we pass inputs of the wrong shape. 
+TEST_F(HloEvaluatorTest, EvaluateWithWrongInputShapes) { + constexpr absl::string_view hlo_text = R"( +HloModule Test + +ENTRY main { + p0 = s32[1] parameter(0) + ROOT sum = s32[1] add(p0, p0) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + Literal input_wrong_shape = LiteralUtil::CreateR1({0, 1}); + + EXPECT_EQ(HloEvaluator() + .Evaluate(*m_, {&input_wrong_shape}) + .status() + .error_message(), + "Shape mismatch at parameter 0. Computation expected s32[1]{0}, " + "but arg was s32[2]."); + EXPECT_EQ(HloEvaluator() + .Evaluate(*m_->entry_computation(), {&input_wrong_shape}) + .status() + .error_message(), + "Shape mismatch at parameter 0. Computation expected s32[1]{0}, " + "but arg was s32[2]."); +} + +// Check that we get a useful error if we pass too many or too few inputs. +TEST_F(HloEvaluatorTest, EvaluateWithWrongNumberOfInputs) { + constexpr absl::string_view hlo_text = R"( +HloModule Test + +ENTRY main { + p0 = s32[1] parameter(0) + ROOT sum = s32[1] add(p0, p0) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + Literal input = LiteralUtil::CreateR1({0}); + + EXPECT_EQ( + HloEvaluator().Evaluate(*m_, {&input, &input}).status().error_message(), + "Expected 1 argument, but got 2."); + EXPECT_EQ(HloEvaluator() + .Evaluate(*m_->entry_computation(), {&input, &input}) + .status() + .error_message(), + "Expected 1 argument, but got 2."); +} + +TEST_F(HloEvaluatorTest, PreserveFusionInputLayout) { + constexpr absl::string_view hlo_text = R"( + HloModule FusionInputLayout + + fused_computation { + param_0 = f32[20,20]{0,1} parameter(0) + ROOT bitcast = f32[20,20]{1,0} bitcast(param_0) + } + + ENTRY kernel_entry { + parameter.0 = f32[20,20]{0,1} parameter(0) + ROOT fusion = f32[20,20]{1,0} fusion(parameter.0), + kind=kLoop, calls=fused_computation + })"; + + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + auto args = MakeFakeArguments(m_.get()).ConsumeValueOrDie(); + Literal actual 
= Evaluate({&args[0]}); + EXPECT_TRUE(absl::c_equal(args[0].data(), actual.data())); +} + +TEST_F(HloEvaluatorTest, PreserveFusionOutputLayout) { + constexpr absl::string_view hlo_text = R"( + HloModule FusionOutputLayout + + fused_computation { + param_0 = f32[20,20]{1,0} parameter(0) + ROOT bitcast = f32[20,20]{0,1} bitcast(param_0) + } + + ENTRY kernel_entry { + parameter.0 = f32[20,20]{1,0} parameter(0) + ROOT fusion = f32[20,20]{0,1} fusion(parameter.0), + kind=kLoop, calls=fused_computation + })"; + + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + auto args = MakeFakeArguments(m_.get()).ConsumeValueOrDie(); + Literal actual = Evaluate({&args[0]}); + EXPECT_TRUE(absl::c_equal(args[0].data(), actual.data())); +} + +TEST_F(HloEvaluatorTest, PreserveMOFusionOutputLayout) { + constexpr absl::string_view hlo_text = R"( + HloModule MOFusionOutputLayout + + fused_computation { + param_0 = f32[20,20]{1,0} parameter(0) + bitcast = f32[20,20]{0,1} bitcast(param_0) + ROOT tuple = (f32[20,20]{0,1}) tuple(bitcast) + } + + ENTRY kernel_entry { + parameter.0 = f32[20,20]{1,0} parameter(0) + ROOT fusion = (f32[20,20]{0,1}) fusion(parameter.0), + kind=kLoop, calls=fused_computation + })"; + + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + auto args = MakeFakeArguments(m_.get()).ConsumeValueOrDie(); + Literal actual_tuple = Evaluate({&args[0]}); + std::vector actual_literals = actual_tuple.DecomposeTuple(); + EXPECT_TRUE( + absl::c_equal(args[0].data(), actual_literals[0].data())); +} + +// Tests that custom_calls fail to evaluate when no handler is specified. 
+TEST_F(HloEvaluatorTest, EvaluateCustomCall_NoHandler) { + constexpr absl::string_view hlo_text = R"( + HloModule EvaluateCustomCall_NoHandler + ENTRY kernel_entry { + parameter.0 = u32[2,2]{1,0} parameter(0) + ROOT test_root = (u32[2,2]{1,0}) custom-call(parameter.0), + custom_call_target="_my_custom_call" + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + auto args = MakeFakeArguments(m_.get()).ConsumeValueOrDie(); + EXPECT_EQ(HloEvaluator().Evaluate(*m_, {&args[0]}).status().code(), + ::tensorflow::error::UNIMPLEMENTED); +} + +// Tests when a custom_call handler returns an error. +TEST_F(HloEvaluatorTest, EvaluateCustomCall_HandlerError) { + constexpr absl::string_view hlo_text = R"( + HloModule EvaluateCustomCall_HandlerError + ENTRY kernel_entry { + parameter.0 = u32[2,2]{1,0} parameter(0) + ROOT test_root = (u32[2,2]{1,0}) custom-call(parameter.0), + custom_call_target="_my_custom_call" + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + auto args = MakeFakeArguments(m_.get()).ConsumeValueOrDie(); + HloEvaluator evaluator; + evaluator.set_custom_call_handler( + [](HloInstruction* custom_call, absl::Span operands) { + return InternalError("Test error"); + }); + EXPECT_EQ(evaluator.Evaluate(*m_, {&args[0]}).status().code(), + ::tensorflow::error::INTERNAL); +} + +// Tests the custom_call handler on calls with many inputs. +// We sum the operands so that we can verify the operand and output literals +// are properly mapped for access. 
+TEST_F(HloEvaluatorTest, EvaluateCustomCall_ManyInputs) { + constexpr absl::string_view hlo_text = R"( + HloModule EvaluateCustomCall_ManyInputs + ENTRY kernel_entry { + parameter.0 = u32[1]{0} parameter(0) + parameter.1 = u32[1]{0} parameter(1) + ROOT test_root = u32[1]{0} custom-call(parameter.0, parameter.1), + custom_call_target="_my_custom_call" + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + auto args = MakeFakeArguments(m_.get()).ConsumeValueOrDie(); + HloEvaluator evaluator; + evaluator.set_custom_call_handler( + [](HloInstruction* custom_call, absl::Span operands) { + EXPECT_EQ(HloOpcode::kCustomCall, custom_call->opcode()); + EXPECT_EQ("_my_custom_call", custom_call->custom_call_target()); + EXPECT_EQ(2, custom_call->operand_count()); + EXPECT_EQ(2, operands.size()); + auto output = Literal::CreateFromShape(custom_call->shape()); + auto operand0_data = operands[0]->data(); + auto operand1_data = operands[1]->data(); + auto output_data = output.data(); + output_data[0] = operand0_data[0] + operand1_data[0]; + return output; + }); + TF_ASSERT_OK_AND_ASSIGN( + Literal actual_literal, + evaluator.Evaluate(*m_->entry_computation(), {&args[0], &args[1]})); + auto arg0_data = args[0].data(); + auto arg1_data = args[1].data(); + std::vector expected_data = {arg0_data[0] + arg1_data[0]}; + EXPECT_TRUE(absl::c_equal(expected_data, actual_literal.data())); +} + +TEST_F(HloEvaluatorTest, IsFiniteF16) { + constexpr absl::string_view hlo_text = R"( + HloModule test + + ENTRY IsFiniteTest { + c = f16[6] constant({nan, 7, nan, -1, inf, -inf}) + ROOT is-finite = pred[6] is-finite(c) + })"; + + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN( + Literal actual_literal, + HloEvaluator().Evaluate(*m_->entry_computation(), {})); + EXPECT_THAT(actual_literal.data(), + ::testing::ElementsAre(false, true, false, true, false, false)); +} + +TEST_F(HloEvaluatorTest, IsFiniteBf16) { + constexpr 
absl::string_view hlo_text = R"( + HloModule test + + ENTRY IsFiniteTest { + c = bf16[6] constant({nan, 7, nan, -1, inf, -inf}) + ROOT is-finite = pred[6] is-finite(c) + })"; + + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN( + Literal actual_literal, + HloEvaluator().Evaluate(*m_->entry_computation(), {})); + EXPECT_THAT(actual_literal.data(), + ::testing::ElementsAre(false, true, false, true, false, false)); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index b87fc3e34012e75ee07bff6c1e113dce404f83cb..742a389ed04eb7303197467587223486c780a31e 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -22,7 +22,9 @@ limitations under the License. #include "absl/base/casts.h" #include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" +#include "absl/meta/type_traits.h" #include "absl/types/optional.h" +#include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h" @@ -38,9 +40,8 @@ namespace xla { // Anyway this is relatively safe as-is because hlo_evaluator_typed_visitor.h is // a "private" header that's not exposed outside of hlo_evaluator.cc. template -using is_complex_t = std::is_same; -template -using is_complex64_t = std::is_same; +using is_complex_t = + absl::disjunction, std::is_same>; // It's UB to use std::sort with std::less, because of NaNs. Define // "safe" less functions which are actually strict weak orders. 
-NaN and NaN @@ -82,6 +83,26 @@ bool SafeLess(const NativeT& a, const NativeT& b) { return SafeLess(static_cast(a), static_cast(b)); } +// ToArithmeticSafeType(T t): +// - converts `t` to the bitwise-equivalent `unsigned T` if T is a signed +// integer, and +// - otherwise returns `t` unchanged. +// +// It's UB in C++ to under/overflow a signed integer, so we wrap all arithmetic +// in this type to force 2's complement behavior. +template ::value && + std::is_signed::value>::type* = nullptr> +typename std::make_unsigned::type ToArithmeticSafeType(T t) { + return static_cast::type>(t); +} +template ::value || + !std::is_signed::value>::type* = nullptr> +T ToArithmeticSafeType(T t) { + return std::move(t); +} + // Templated DfsHloVisitor for use by HloEvaluator. // // Typically ReturnT here indicates the resulting literal type of each evaluated @@ -105,6 +126,12 @@ bool SafeLess(const NativeT& a, const NativeT& b) { template class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { private: + Status UnsupportedTypeError(HloInstruction* instruction) { + return InvalidArgument( + "Unsupported type for %s: %s", HloOpcodeString(instruction->opcode()), + PrimitiveType_Name(instruction->shape().element_type())); + } + // Get the value in the given literal static_cast as a double. template < typename NativeT, @@ -185,7 +212,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { template < typename NativeT, - typename std::enable_if::value>::type* = nullptr> + typename std::enable_if::value>::type* = nullptr> Status HandleAbs(HloInstruction* abs) { const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(abs->operand(0)); @@ -204,6 +231,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // specifying the ElementwiseT explicitly as C64 is needed below. 
if (abs->operand(0)->shape().element_type() == C64) { return HandleAbs(abs); + } else if (abs->operand(0)->shape().element_type() == C128) { + return HandleAbs(abs); } return HandleAbs(abs); } @@ -224,7 +253,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { typename NativeT, typename std::enable_if::value>::type* = nullptr> Status HandleRound(HloInstruction* round) { - return InvalidArgument("Unsupported type for Round"); + return UnsupportedTypeError(round); } Status HandleRound(HloInstruction* round) override { @@ -246,7 +275,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { typename NativeT, typename std::enable_if::value>::type* = nullptr> Status HandleCeil(HloInstruction* ceil) { - return InvalidArgument("Unsupported type for Ceil"); + return UnsupportedTypeError(ceil); } Status HandleCeil(HloInstruction* ceil) override { @@ -297,8 +326,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { template < typename NativeT, typename std::enable_if::value>::type* = nullptr> - Status HandleExpm1(HloInstruction* floor) { - return InvalidArgument("Unsupported type for Expm1"); + Status HandleExpm1(HloInstruction* expm1) { + return UnsupportedTypeError(expm1); } Status HandleExpm1(HloInstruction* floor) override { @@ -321,7 +350,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { typename NativeT, typename std::enable_if::value>::type* = nullptr> Status HandleFloor(HloInstruction* floor) { - return InvalidArgument("Unsupported type for Floor"); + return UnsupportedTypeError(floor); } Status HandleFloor(HloInstruction* floor) override { @@ -351,12 +380,12 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { template < typename NativeT, typename std::enable_if::value>::type* = nullptr> - Status HandleLog1p(HloInstruction* floor) { - return InvalidArgument("Unsupported type for Log1p"); + Status HandleLog1p(HloInstruction* log1p) { + return UnsupportedTypeError(log1p); } - 
Status HandleLog1p(HloInstruction* floor) override { - return HandleLog1p(floor); + Status HandleLog1p(HloInstruction* log1p) override { + return HandleLog1p(log1p); } template ::value>::type* = nullptr> Status HandleNot(HloInstruction* not_) { - return InvalidArgument("Unsupported type for Not"); + return UnsupportedTypeError(not_); } Status HandleNot(HloInstruction* not_) override { @@ -476,7 +505,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { template ::value>::type* = nullptr> Status HandleAtan2(HloInstruction* atan2) { - return InvalidArgument("Unsupported type for Atan2"); + return UnsupportedTypeError(atan2); } Status HandleAtan2(HloInstruction* atan2) override { @@ -491,47 +520,25 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } - template ::value && - !std::is_floating_point::value>::type* = nullptr> - Status HandleMultiply(HloInstruction* multiply) { - using type = typename std::make_unsigned::type; - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[multiply], - ElementWiseBinaryOp(multiply, - [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) { - return NativeT(type(lhs_elem) * type(rhs_elem)); - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value || - std::is_floating_point::value || - is_complex_t::value>::type* = nullptr> - Status HandleMultiply(HloInstruction* multiply) { + Status HandleMultiply(HloInstruction* multiply) override { TF_ASSIGN_OR_RETURN( parent_->evaluated_[multiply], - ElementWiseBinaryOp(multiply, - [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) { - return lhs_elem * rhs_elem; - })); + ElementWiseBinaryOp( + multiply, [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) { + return ElementwiseT(ToArithmeticSafeType(lhs_elem) * + ToArithmeticSafeType(rhs_elem)); + })); return Status::OK(); } - Status HandleMultiply(HloInstruction* multiply) override { - return HandleMultiply(multiply); - } - Status HandleSubtract(HloInstruction* 
subtract) override { TF_ASSIGN_OR_RETURN( parent_->evaluated_[subtract], - ElementWiseBinaryOp(subtract, - [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) { - return lhs_elem - rhs_elem; - })); + ElementWiseBinaryOp( + subtract, [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) { + return ElementwiseT(ToArithmeticSafeType(lhs_elem) - + ToArithmeticSafeType(rhs_elem)); + })); return Status::OK(); } @@ -539,7 +546,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { TF_ASSIGN_OR_RETURN(parent_->evaluated_[add], ElementWiseBinaryOp(add, [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) { - return lhs_elem + rhs_elem; + return ElementwiseT(ToArithmeticSafeType(lhs_elem) + + ToArithmeticSafeType(rhs_elem)); })); return Status::OK(); } @@ -624,7 +632,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { typename NativeT, typename std::enable_if::value>::type* = nullptr> Status HandleMaximum(HloInstruction* maximum) { - return InvalidArgument("Unsupported type for Maximum"); + return UnsupportedTypeError(maximum); } Status HandleMaximum(HloInstruction* maximum) override { @@ -659,7 +667,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { typename NativeT, typename std::enable_if::value>::type* = nullptr> Status HandleMinimum(HloInstruction* minimum) { - return InvalidArgument("Unsupported type for Minimum"); + return UnsupportedTypeError(minimum); } Status HandleMinimum(HloInstruction* minimum) override { @@ -667,11 +675,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } Status HandlePower(HloInstruction* power) override { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[power], - ElementWiseBinaryOp(power, [](ElementwiseT lhs_el, - ElementwiseT rhs_el) { - return std::pow(lhs_el, rhs_el); - })); + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[power], + ElementWiseBinaryOp( + power, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { + return lhs_el == ElementwiseT(0) && rhs_el == ElementwiseT(0) + ? 
static_cast(1) + : std::pow(lhs_el, rhs_el); + })); return Status::OK(); } @@ -724,7 +735,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { typename NativeT, typename std::enable_if::value>::type* = nullptr> Status HandleRemainder(HloInstruction* remainder) { - return InvalidArgument("Unsupported type for Remainder"); + return UnsupportedTypeError(remainder); } Status HandleRemainder(HloInstruction* remainder) override { @@ -746,14 +757,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { template ::value>::type* = nullptr> Status HandleAnd(HloInstruction* and_) { - return InvalidArgument("Unsupported type for And"); + return UnsupportedTypeError(and_); } template < typename NativeT, typename std::enable_if::value>::type* = nullptr> Status HandleAnd(HloInstruction* and_) { - return InvalidArgument("Unsupported type for And"); + return UnsupportedTypeError(and_); } Status HandleAnd(HloInstruction* and_) override { @@ -775,7 +786,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { template ::value>::type* = nullptr> Status HandleOr(HloInstruction* or_) { - return InvalidArgument("Unsupported type for Or"); + return UnsupportedTypeError(or_); } template < @@ -804,14 +815,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { template ::value>::type* = nullptr> Status HandleXor(HloInstruction* xor_) { - return InvalidArgument("Unsupported type for Xor"); + return UnsupportedTypeError(xor_); } template < typename NativeT, typename std::enable_if::value>::type* = nullptr> Status HandleXor(HloInstruction* xor_) { - return InvalidArgument("Unsupported type for Xor"); + return UnsupportedTypeError(xor_); } Status HandleXor(HloInstruction* xor_) override { @@ -836,8 +847,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { typename std::enable_if::value || std::is_same::value>::type* = nullptr> - Status HandleShiftLeft(HloInstruction*) { - return InvalidArgument("Unsupported 
type for ShiftLeft"); + Status HandleShiftLeft(HloInstruction* shift) { + return UnsupportedTypeError(shift); } Status HandleShiftLeft(HloInstruction* shl) override { @@ -866,8 +877,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { typename std::enable_if::value || std::is_same::value>::type* = nullptr> - Status HandleShiftRightArithmetic(HloInstruction*) { - return InvalidArgument("Unsupported type for ShiftRightArithmetic"); + Status HandleShiftRightArithmetic(HloInstruction* shift) { + return UnsupportedTypeError(shift); } Status HandleShiftRightArithmetic(HloInstruction* shra) override { @@ -897,8 +908,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { typename std::enable_if::value || std::is_same::value>::type* = nullptr> - Status HandleShiftRightLogical(HloInstruction*) { - return InvalidArgument("Unsupported type for ShiftRightLogical"); + Status HandleShiftRightLogical(HloInstruction* shift) { + return UnsupportedTypeError(shift); } Status HandleShiftRightLogical(HloInstruction* shrl) override { @@ -911,7 +922,11 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { Status HandleClamp(HloInstruction* clamp) { std::function clamp_op = [](ElementwiseT low, ElementwiseT value, ElementwiseT high) { - return std::fmin(high, std::fmax(value, low)); + if (std::isnan(low) || std::isnan(high)) { + return static_cast(NAN); + } + return static_cast( + std::fmin(high, std::fmax(value, low))); }; TF_ASSIGN_OR_RETURN( parent_->evaluated_[clamp], @@ -923,8 +938,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { template < typename NativeT, typename std::enable_if::value>::type* = nullptr> - Status HandleClamp(HloInstruction*) { - return InvalidArgument("Unsupported type for Clamp"); + Status HandleClamp(HloInstruction* clamp) { + return UnsupportedTypeError(clamp); } Status HandleClamp(HloInstruction* clamp) override { @@ -933,7 +948,7 @@ class HloEvaluatorTypedVisitor : public 
DfsHloVisitorWithDefault { Status HandleSelect(HloInstruction* select) override { CHECK(!ShapeUtil::IsScalar(select->operand(0)->shape())); - CHECK(ShapeUtil::IsArray(select->shape())); + CHECK(select->shape().IsArray()); std::function select_op = [](bool pred, ReturnT on_true, ReturnT on_false) { if (pred) { @@ -986,8 +1001,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { TF_CHECK_OK(ShapeUtil::ValidateShape(lhs_shape)); TF_CHECK_OK(ShapeUtil::ValidateShape(rhs_shape)); - CHECK(ShapeUtil::IsArray(lhs_shape)); - CHECK(ShapeUtil::IsArray(rhs_shape)); + CHECK(lhs_shape.IsArray()); + CHECK(rhs_shape.IsArray()); CHECK(ShapeUtil::SameElementType(lhs_shape, rhs_shape)); CHECK(ShapeUtil::SameElementType(lhs_shape, result_shape)); @@ -998,16 +1013,16 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { CHECK_GE(num_spatial_dims, 0); CHECK_EQ(window.dimensions_size(), num_spatial_dims); - const auto lhs_rank = ShapeUtil::Rank(lhs_shape); - const auto rhs_rank = ShapeUtil::Rank(rhs_shape); + const auto lhs_rank = lhs_shape.rank(); + const auto rhs_rank = rhs_shape.rank(); CHECK_EQ(num_spatial_dims + 2, lhs_rank); CHECK_EQ(num_spatial_dims + 2, rhs_rank); - TF_ASSIGN_OR_RETURN( - auto inferred_return_shape, - ShapeInference::InferConvolveShape( - lhs_shape, rhs_shape, conv->feature_group_count(), window, dnums)); + TF_ASSIGN_OR_RETURN(auto inferred_return_shape, + ShapeInference::InferConvolveShape( + lhs_shape, rhs_shape, conv->feature_group_count(), + conv->batch_group_count(), window, dnums)); CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) << "return shape set to: " << ShapeUtil::HumanString(result_shape) << " but is inferred to be: " @@ -1030,12 +1045,13 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { auto lhs_literal_data = lhs_literal.data(); auto rhs_literal_data = rhs_literal.data(); - int64 feature_group_count = conv->feature_group_count(); + const int64 feature_group_count = 
conv->feature_group_count(); + const int64 batch_group_count = conv->batch_group_count(); auto func = [&window_shape, &dnums, &lhs_shape, &rhs_shape, &window, &lhs_dim_multipliers, &rhs_dim_multipliers, lhs_literal_data, - rhs_literal_data, - feature_group_count](const absl::Span out_index) { + rhs_literal_data, feature_group_count, + batch_group_count](const absl::Span out_index) { // Dimension number applicable for input (lhs). const int64 input_batch_dim = dnums.input_batch_dimension(); const int64 input_z_dim = dnums.input_feature_dimension(); @@ -1048,6 +1064,12 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const int64 input_z_size = ShapeUtil::GetDimension(lhs_shape, input_z_dim); + + const int64 input_batch_size = + ShapeUtil::GetDimension(lhs_shape, input_batch_dim); + + const int64 batch_group_size = input_batch_size / batch_group_count; + // The size of an input feature group. const int64 input_feature_group_size = input_z_size / feature_group_count; @@ -1063,11 +1085,15 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const int64 feature_group_index = out_index[output_z_dim] / output_feature_group_size; + const int64 batch_group_index = out_index[output_z_dim]; + ElementwiseT result_val = static_cast(0); DimensionVector rhs_spatial_index(dnums.kernel_spatial_dimensions_size(), 0); // Convolve input feature with kernel. + // The mechanism indexes into the correct LHS (input) and RHS (kernel) + // locations and accumulates multiplications for a given output index. do { // Find corresponding spatial dimension index for input (lhs). 
int64 lhs_linear_spatial_index = 0; @@ -1120,11 +1146,24 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { feature_group_index * input_feature_group_size + rhs_iz; int64 lhs_linear_index = lhs_linear_spatial_index; + lhs_linear_index += out_index[output_batch_dim] * lhs_dim_multipliers[input_batch_dim]; + + // We are scraping only the diagonal elements in the resultant + // convolution output when batch_group_count is greater than 1, + // where 1 is the default. No scraping is done in that case. + // This approach works out automatically for 'groups' in batches + // with group_size > 1, because we already descend down the batch + // dimension for the 'output_batch_dim' above. + lhs_linear_index += + ((batch_group_index * batch_group_size) % input_batch_size) * + lhs_dim_multipliers[input_batch_dim]; + lhs_linear_index += iz * lhs_dim_multipliers[input_z_dim]; int64 rhs_linear_index = rhs_linear_spatial_index; + rhs_linear_index += out_index[output_z_dim] * rhs_dim_multipliers[kernel_output_z_dim]; rhs_linear_index += rhs_iz * rhs_dim_multipliers[kernel_input_z_dim]; @@ -1148,23 +1187,31 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } Status HandleDot(HloInstruction* dot) override { - auto lhs = dot->operand(0); - auto rhs = dot->operand(1); - CHECK(ShapeUtil::IsArray(dot->shape())); - CHECK(ShapeUtil::IsArray(lhs->shape())); - CHECK(ShapeUtil::IsArray(rhs->shape())); + if (dot->dot_dimension_numbers().rhs_contracting_dimensions_size() == 1 && + parent_->use_fast_path_) { + return HandleDot(dot); + } + return HandleDotSlowPath(dot); + } + + template ::value>::type* = nullptr> + Status HandleDot(HloInstruction* dot) { + const HloInstruction* lhs = dot->operand(0); + const HloInstruction* rhs = dot->operand(1); + CHECK(dot->shape().IsArray()); + CHECK(lhs->shape().IsArray()); + CHECK(rhs->shape().IsArray()); const auto& dnums = dot->dot_dimension_numbers(); - const auto lhs_rank = ShapeUtil::Rank(lhs->shape()); - const auto 
rhs_rank = ShapeUtil::Rank(rhs->shape()); + const int64 lhs_rank = lhs->shape().rank(); + const int64 rhs_rank = rhs->shape().rank(); CHECK(ShapeUtil::SameElementType(lhs->shape(), rhs->shape())); CHECK(ShapeUtil::SameElementType(lhs->shape(), dot->shape())); // There must be 1 and only 1 Contracting dimension for lhs and rhs. - CHECK_EQ(dnums.lhs_contracting_dimensions_size(), 1); - CHECK_EQ(dnums.rhs_contracting_dimensions_size(), 1); const int64 lhs_contracting_dimension = dnums.lhs_contracting_dimensions(0); const int64 rhs_contracting_dimension = dnums.rhs_contracting_dimensions(0); // Contracted dimension sizes must be the same. @@ -1174,8 +1221,56 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { << lhs->shape().dimensions(lhs_contracting_dimension) << " rhs contracted dimension: " << rhs->shape().dimensions(rhs_contracting_dimension); - const int64 contracted_dimension_size = - lhs->shape().dimensions(lhs_contracting_dimension); + + // The fast path is for a simple rank 2 dot with default layout operands. 
+ if (lhs_rank == 2 && rhs_rank == 2 && lhs_contracting_dimension == 1 && + rhs_contracting_dimension == 0 && + LayoutUtil::Equal(lhs->shape().layout(), + LayoutUtil::GetDefaultLayoutForR2()) && + LayoutUtil::Equal(rhs->shape().layout(), + LayoutUtil::GetDefaultLayoutForR2()) && + LayoutUtil::Equal(dot->shape().layout(), + LayoutUtil::GetDefaultLayoutForR2())) { + const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); + const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); + const int64 contracted_dimension_size = + lhs->shape().dimensions(lhs_contracting_dimension); + Array2D lhs_array(lhs->shape().dimensions(0), + contracted_dimension_size); + lhs_array.SetValues(lhs_literal.data()); + Array2D rhs_array(contracted_dimension_size, + rhs->shape().dimensions(1)); + rhs_array.SetValues(rhs_literal.data()); + std::unique_ptr> result_array = + HloEvaluator::MatmulArray2D(lhs_array, rhs_array); + Literal result(dot->shape()); + result.PopulateR2FromArray2D(*result_array); + parent_->evaluated_[dot] = std::move(result); + return Status::OK(); + } + return HandleDotSlowPath(dot); + } + + template ::value>::type* = nullptr> + Status HandleDot(HloInstruction* dot) { + return HandleDotSlowPath(dot); + } + + Status HandleDotSlowPath(HloInstruction* dot) { + auto lhs = dot->operand(0); + auto rhs = dot->operand(1); + CHECK(dot->shape().IsArray()); + CHECK(lhs->shape().IsArray()); + CHECK(rhs->shape().IsArray()); + + const auto& dnums = dot->dot_dimension_numbers(); + + const auto lhs_rank = lhs->shape().rank(); + const auto rhs_rank = rhs->shape().rank(); + + CHECK(ShapeUtil::SameElementType(lhs->shape(), rhs->shape())); + CHECK(ShapeUtil::SameElementType(lhs->shape(), dot->shape())); const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); @@ -1190,7 +1285,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // in lhs_index or rhs_index where the i'th 
result index should go. absl::InlinedVector, kInlineRank> result_index_locations; - result_index_locations.reserve(lhs_rank + rhs_rank - 2); + result_index_locations.reserve( + (lhs_rank - dnums.lhs_contracting_dimensions_size()) + + (rhs_rank - dnums.rhs_contracting_dimensions_size())); // The first components in the output shape are the LHS and RHS batch // dimensions: @@ -1202,18 +1299,32 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // Then we have the LHS and RHS non-contracting dimensions, if any: for (int64 i = 0; i < lhs_rank; i++) { - if (i != lhs_contracting_dimension && + if (!absl::c_linear_search(dnums.lhs_contracting_dimensions(), i) && !absl::c_linear_search(dnums.lhs_batch_dimensions(), i)) { result_index_locations.push_back({&lhs_index[i], nullptr}); } } for (int64 i = 0; i < rhs_rank; i++) { - if (i != rhs_contracting_dimension && + if (!absl::c_linear_search(dnums.rhs_contracting_dimensions(), i) && !absl::c_linear_search(dnums.rhs_batch_dimensions(), i)) { result_index_locations.push_back({&rhs_index[i], nullptr}); } } + absl::InlinedVector accumulate_index_sizes; + accumulate_index_sizes.reserve(dnums.lhs_contracting_dimensions_size()); + absl::InlinedVector, kInlineRank> + accumulate_index_locations; + accumulate_index_locations.reserve(dnums.lhs_contracting_dimensions_size()); + for (int64 i = 0; i < dnums.lhs_contracting_dimensions_size(); ++i) { + const int64 lhs_dnum = dnums.lhs_contracting_dimensions(i); + const int64 rhs_dnum = dnums.rhs_contracting_dimensions(i); + accumulate_index_locations.push_back( + {&lhs_index[lhs_dnum], &rhs_index[rhs_dnum]}); + const int64 dim_size = lhs->shape().dimensions(lhs_dnum); + accumulate_index_sizes.push_back(dim_size); + } + const int64 total_contraction_size = Product(accumulate_index_sizes); Literal result(dot->shape()); TF_RETURN_IF_ERROR( result.Populate([&](absl::Span result_index) { @@ -1227,13 +1338,30 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { 
} // Accumulates resulting product along the contracted dimension. - for (int64 i = 0; i < contracted_dimension_size; ++i) { - lhs_index[lhs_contracting_dimension] = i; - rhs_index[rhs_contracting_dimension] = i; + absl::InlinedVector accumulate_index( + accumulate_index_sizes.size(), 0); + for (int64 k = 0; k < total_contraction_size; k++) { + for (int64 i = 0; i < accumulate_index_sizes.size(); ++i) { + *(accumulate_index_locations[i].first) = accumulate_index[i]; + *(accumulate_index_locations[i].second) = accumulate_index[i]; + } result_val += static_cast(lhs_literal.Get(lhs_index)) * static_cast(rhs_literal.Get(rhs_index)); + + // If there are no contracting dimension accumulate_index_sizes is + // empty, do not try to count down from -1 to 0 since it is and + // infinite loop. + if (!accumulate_index_sizes.empty()) { + for (int64 i = accumulate_index_sizes.size() - 1; i >= 0; --i) { + int64 value = ++accumulate_index[i]; + if (value != accumulate_index_sizes[i]) { + break; + } + accumulate_index[i] = 0; + } + } } return static_cast(result_val); @@ -1244,10 +1372,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } Status HandlePad(HloInstruction* pad) override { - CHECK(ShapeUtil::IsArray(pad->operand(0)->shape())); + CHECK(pad->operand(0)->shape().IsArray()); // Padding value must be scalar. 
CHECK(ShapeUtil::IsScalar(pad->operand(1)->shape())); - CHECK_EQ(ShapeUtil::Rank(pad->operand(0)->shape()), + CHECK_EQ(pad->operand(0)->shape().rank(), pad->padding_config().dimensions_size()); TF_ASSIGN_OR_RETURN(auto inferred_return_shape, @@ -1270,9 +1398,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const Literal& evaluated_operand = parent_->GetEvaluatedLiteralFor(pad->operand(0)); - std::vector input_index(ShapeUtil::Rank(evaluated_operand.shape()), - 0); - std::vector target_index(ShapeUtil::Rank(result.shape()), 0); + std::vector input_index(evaluated_operand.shape().rank(), 0); + std::vector target_index(result.shape().rank(), 0); // Loop through each element of the operand, assign them to the // corresponding index of the resulting padded literal. @@ -1315,10 +1442,12 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { auto operand = dynamic_slice->operand(0); auto start_indices = dynamic_slice->operand(1); auto result_shape = dynamic_slice->shape(); - TF_ASSIGN_OR_RETURN(auto inferred_return_shape, - ShapeInference::InferDynamicSliceShape( - operand->shape(), start_indices->shape(), - dynamic_slice->dynamic_slice_sizes())); + TF_ASSIGN_OR_RETURN( + auto inferred_return_shape, + ShapeInference::InferDynamicSliceShape( + operand->shape(), + Cast(dynamic_slice)->index_shapes(), + dynamic_slice->dynamic_slice_sizes())); TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) << "return shape is set to: " << ShapeUtil::HumanString(result_shape) << " but is inferred to be: " @@ -1327,33 +1456,39 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { primitive_util::IsIntegralType(start_indices->shape().element_type())); const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); - const Literal& start_indices_literal = - parent_->GetEvaluatedLiteralFor(start_indices); switch (start_indices->shape().element_type()) { case S32: { TF_ASSIGN_OR_RETURN( 
parent_->evaluated_[dynamic_slice], - DynamicSlice(operand_literal, start_indices_literal, - result_shape)); + DynamicSlice( + operand_literal, + absl::MakeConstSpan(dynamic_slice->operands()).subspan(1), + result_shape)); } break; case S64: { TF_ASSIGN_OR_RETURN( parent_->evaluated_[dynamic_slice], - DynamicSlice(operand_literal, start_indices_literal, - result_shape)); + DynamicSlice( + operand_literal, + absl::MakeConstSpan(dynamic_slice->operands()).subspan(1), + result_shape)); } break; case U32: { TF_ASSIGN_OR_RETURN( parent_->evaluated_[dynamic_slice], - DynamicSlice(operand_literal, start_indices_literal, - result_shape)); + DynamicSlice( + operand_literal, + absl::MakeConstSpan(dynamic_slice->operands()).subspan(1), + result_shape)); } break; case U64: { TF_ASSIGN_OR_RETURN( parent_->evaluated_[dynamic_slice], - DynamicSlice(operand_literal, start_indices_literal, - result_shape)); + DynamicSlice( + operand_literal, + absl::MakeConstSpan(dynamic_slice->operands()).subspan(1), + result_shape)); } break; default: LOG(FATAL) << "HandleDynamicSlice: unhandled primitive type for " @@ -1373,7 +1508,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { TF_ASSIGN_OR_RETURN( auto inferred_return_shape, ShapeInference::InferDynamicUpdateSliceShape( - operand->shape(), update->shape(), start_indices->shape())); + operand->shape(), update->shape(), + Cast(dynamic_update_slice) + ->index_shapes())); TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) << "return shape is set to: " << ShapeUtil::HumanString(result_shape) << " but is inferred to be: " @@ -1384,33 +1521,39 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); const Literal& update_literal = parent_->GetEvaluatedLiteralFor(update); - const Literal& start_indices_literal = - parent_->GetEvaluatedLiteralFor(start_indices); switch (start_indices->shape().element_type()) { case S32: 
{ TF_ASSIGN_OR_RETURN( parent_->evaluated_[dynamic_update_slice], - DynamicUpdateSlice(operand_literal, update_literal, - start_indices_literal)); + DynamicUpdateSlice( + operand_literal, update_literal, + absl::MakeConstSpan(dynamic_update_slice->operands()) + .subspan(2))); } break; case S64: { TF_ASSIGN_OR_RETURN( parent_->evaluated_[dynamic_update_slice], - DynamicUpdateSlice(operand_literal, update_literal, - start_indices_literal)); + DynamicUpdateSlice( + operand_literal, update_literal, + absl::MakeConstSpan(dynamic_update_slice->operands()) + .subspan(2))); } break; case U32: { TF_ASSIGN_OR_RETURN( parent_->evaluated_[dynamic_update_slice], - DynamicUpdateSlice(operand_literal, update_literal, - start_indices_literal)); + DynamicUpdateSlice( + operand_literal, update_literal, + absl::MakeConstSpan(dynamic_update_slice->operands()) + .subspan(2))); } break; case U64: { TF_ASSIGN_OR_RETURN( parent_->evaluated_[dynamic_update_slice], - DynamicUpdateSlice(operand_literal, update_literal, - start_indices_literal)); + DynamicUpdateSlice( + operand_literal, update_literal, + absl::MakeConstSpan(dynamic_update_slice->operands()) + .subspan(2))); } break; default: LOG(FATAL) << "HandleDynamicUpdateSlice: unhandled primitive type for " @@ -1447,7 +1590,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } Literal computed_result = - embedded_evaluator.Evaluate(*computation, arg_literals) + embedded_evaluator.Evaluate(*computation, arg_literals) .ConsumeValueOrDie(); // Clear visit states so that the we can use the evaluate again on // the same computation. 
@@ -1505,6 +1648,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); break; } + case C128: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } default: LOG(FATAL) << "HandleMap: unhandled primitive type for " "input operand: " @@ -1515,80 +1662,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } - template ::value && - !std::is_same::value>::type* = nullptr> - Status HandleSort(HloInstruction* sort) { - auto keys = sort->operand(0); - TF_RET_CHECK(sort->operand_count() == 1) - << "Typed visitor does not support key-value sort"; - - const Literal& keys_literal = parent_->GetEvaluatedLiteralFor(keys); - int64 sort_dim = sort->dimensions(0); - int64 sort_dim_elements = keys->shape().dimensions(sort_dim); - int64 rank = ShapeUtil::Rank(keys->shape()); - if (rank == 0) { - // Nothing to sort. - parent_->evaluated_[sort] = keys_literal.Clone(); - return Status::OK(); - } - Literal result_literal(keys_literal.shape()); - std::vector zero_base(rank, 0); - std::vector increment(rank, 1); - increment[sort_dim] = sort_dim_elements; - // Iterate through each dimension except 'sort_dim'. - TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus( - keys->shape(), zero_base, AsInt64Slice(keys->shape().dimensions()), - increment, [&](absl::Span indices) -> StatusOr { - // Extract a slice from the literal that corresponds to exactly the - // row in dimension 'sort_dim'. 
- std::vector limit_indices(indices.begin(), indices.end()); - std::for_each(limit_indices.begin(), limit_indices.end(), - [](int64& index) { ++index; }); - limit_indices[sort_dim] = sort_dim_elements; - TF_ASSIGN_OR_RETURN(auto row_to_sort, - keys_literal.Slice(indices, limit_indices) - .Reshape({sort_dim_elements})); - const auto& row_data = row_to_sort.data(); - - std::vector result_data(row_data.begin(), row_data.end()); - std::stable_sort(result_data.begin(), result_data.end(), - [](const NativeT& a, const NativeT& b) { - return SafeLess(a, b); - }); - Literal sorted_row(ShapeUtil::MakeShape(keys->shape().element_type(), - {sort_dim_elements})); - sorted_row.PopulateR1(absl::Span(result_data)); - std::vector slice_dimensions(rank, 1); - slice_dimensions[sort_dim] = sort_dim_elements; - TF_ASSIGN_OR_RETURN(auto sorted_row_reshaped, - sorted_row.Reshape(slice_dimensions)); - std::vector start_indices(rank, 0); - TF_RETURN_IF_ERROR(result_literal.CopySliceFrom( - sorted_row_reshaped, start_indices, indices, slice_dimensions)); - return true; - })); - parent_->evaluated_[sort] = std::move(result_literal); - return Status::OK(); - } - - template ::value || - std::is_same::value>::type* = - nullptr> - Status HandleSort(HloInstruction* sort) { - return InvalidArgument("Unsupported type for Sort"); - } - Status HandleSort(HloInstruction* sort) override { - return HandleSort(sort); + return UnsupportedTypeError(sort); } Status HandleReduce(HloInstruction* hlo) override { HloReduceInstruction* reduce = Cast(hlo); int64 num_args = reduce->inputs().size(); - bool has_tuple_output = ShapeUtil::IsTuple(reduce->shape()); + bool has_tuple_output = reduce->shape().IsTuple(); absl::Span dimensions(reduce->dimensions()); HloComputation* function = reduce->to_apply(); @@ -1619,7 +1700,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // All args and results have the same dimensions, so pick an arbitrary one. 
const Shape& arg_shape = arg_literals[0]->shape(); - const Shape& result_shape = ShapeUtil::IsTuple(reduce->shape()) + const Shape& result_shape = reduce->shape().IsTuple() ? reduce->shape().tuple_shapes(0) : reduce->shape(); const auto arg_dimensions = AsInt64Slice(arg_shape.dimensions()); @@ -1708,7 +1789,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { [](Literal& literal) { return &literal; }); TF_ASSIGN_OR_RETURN(Literal computed_result, - embedded_evaluator.Evaluate( + embedded_evaluator.Evaluate( *function, embedded_operands_ptrs)); // Clear visit states so that we can use the evaluator again on // the same computation. @@ -1786,7 +1867,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); const Literal& source_literal = parent_->GetEvaluatedLiteralFor(source); - int64 rank = ShapeUtil::Rank(operand_literal.shape()); + int64 rank = operand_literal.shape().rank(); HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); DimensionVector source_index(rank, 0); @@ -1824,8 +1905,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { selected_val_literal.Set({}, *selected_val); Literal computed_result = embedded_evaluator - .Evaluate( - *select, {&selected_val_literal, &curr_val_literal}) + .Evaluate(*select, + {&selected_val_literal, &curr_val_literal}) .ConsumeValueOrDie(); bool selected = !computed_result.Get({}); if (selected) { @@ -1846,9 +1927,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { scattered_literal.Set({}, scattered); Literal computed_result = embedded_evaluator - .Evaluate( - *scatter, - {&source_literal_scatter, &scattered_literal}) + .Evaluate(*scatter, + {&source_literal_scatter, &scattered_literal}) .ConsumeValueOrDie(); result.Set(operand_index, computed_result.Get({})); // Clear visit states so that the we can use the evaluator again @@ -1898,7 +1978,7 @@ class 
HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { operand->shape().element_type(), window_dimension_sizes); DimensionVector window_index(window.dimensions_size()); - DimensionVector operand_index(ShapeUtil::Rank(operand_literal.shape())); + DimensionVector operand_index(operand_literal.shape().rank()); HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); Literal result(reduce_window->shape()); @@ -1922,8 +2002,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { LiteralUtil::CreateR0(result_val); Literal computed_result = embedded_evaluator - .Evaluate( - *function, {&result_val_literal, &curr_val_literal}) + .Evaluate(*function, + {&result_val_literal, &curr_val_literal}) .ConsumeValueOrDie(); // Clear visit states so that the we can use the evaluate again @@ -2285,9 +2365,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { LiteralUtil::CreateR0(updates.Get(update_index)); Literal updated_result = embedded_evaluator - .Evaluate( - *scatter->to_apply(), - {&result_value_literal, &update_value_literal}) + .Evaluate(*scatter->to_apply(), + {&result_value_literal, &update_value_literal}) .ConsumeValueOrDie(); // Clear visit states so that the we can use the evaluate again on the // same computation. 
@@ -2329,7 +2408,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { << " but is inferred to be: " << ShapeUtil::HumanString(inferred_return_shape); - const int64 rank = ShapeUtil::Rank(operand->shape()); + const int64 rank = operand->shape().rank(); const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); auto func = [&](absl::Span out_index) { DimensionVector operand_index(rank); @@ -2357,7 +2436,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { std::is_same::value || std::is_same::value)>::type* = nullptr> Status HandleClz(HloInstruction* clz) { - return InvalidArgument("Unsupported type for Clz"); + return UnsupportedTypeError(clz); } template ::value || is_complex_t::value>::type* = nullptr> Status HandleSin(HloInstruction* sin) { - return InvalidArgument("Unsupported type for Sin"); + return UnsupportedTypeError(sin); } Status HandleSin(HloInstruction* sin) override { @@ -2425,7 +2504,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { typename std::enable_if::value || is_complex_t::value>::type* = nullptr> Status HandleCos(HloInstruction* cos) { - return InvalidArgument("Unsupported type for Cos"); + return UnsupportedTypeError(cos); } Status HandleCos(HloInstruction* cos) override { @@ -2526,7 +2605,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { template ::value>::type* = nullptr> Status HandleReducePrecision(HloInstruction* reduce_precision) { - return InvalidArgument("Double not supported for reduce precision"); + return InvalidArgument("Double is not supported for reduce precision"); } template < @@ -2534,46 +2613,159 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { typename std::enable_if::value || is_complex_t::value>::type* = nullptr> Status HandleReducePrecision(HloInstruction* reduce_precision) { - return InvalidArgument("Unsupported type for reduce precision"); + return UnsupportedTypeError(reduce_precision); } Status 
HandleReducePrecision(HloInstruction* reduce_precision) override { return HandleReducePrecision(reduce_precision); } - template ::value || - std::is_floating_point::value>::type* = nullptr> + template < + typename NativeT, + typename std::enable_if< + std::is_same::value || + std::is_same::value || + std::is_integral::value || is_complex_t::value || + std::is_floating_point::value>::type* = nullptr> Status HandleIota(HloInstruction* instruction) { auto* iota = Cast(instruction); + const int64 iota_size = iota->shape().dimensions(iota->iota_dimension()); // Avoid using std::vector since std::vector does not convert to // absl::Span. - absl::InlinedVector data( - iota->shape().dimensions(iota->iota_dimension())); - std::iota(data.begin(), data.end(), 0); + absl::InlinedVector data(iota_size); + // We don't use std::iota for two reasons: + // + // (1) std:iota does not support bfloat16 and float16. + // + // (2) std::iota saturates for floating point types when the value is not + // representable, but the definition of HLO iota is the value as a + // 64-bit integer cast to the native type. + for (int64 i = 0; i < iota_size; ++i) { + // static_cast is required for Eigen::half (F16). 
+ data[i] = static_cast(i); + } auto result = LiteralUtil::CreateR1(data); - if (ShapeUtil::Rank(iota->shape()) > 1) { + if (iota->shape().rank() > 1) { TF_ASSIGN_OR_RETURN( parent_->evaluated_[iota], result.Broadcast(iota->shape(), {iota->iota_dimension()})); } else { - TF_RET_CHECK(ShapeUtil::Rank(iota->shape()) == 1); + TF_RET_CHECK(iota->shape().rank() == 1); parent_->evaluated_[iota] = std::move(result); } return Status::OK(); } + template < + typename NativeT, + typename std::enable_if< + !(std::is_same::value || + std::is_same::value || + std::is_integral::value || is_complex_t::value || + std::is_floating_point::value)>::type* = nullptr> + Status HandleIota(HloInstruction* iota) { + return UnsupportedTypeError(iota); + } + Status HandleIota(HloInstruction* iota) override { + return HandleIota(iota); + } + template ::value || std::is_floating_point::value)>::type* = nullptr> - Status HandleIota(HloInstruction* iota) { - return InvalidArgument("Unsupported type for iota"); + Status HandleRng(HloInstruction* random) { + return UnsupportedTypeError(random); } - Status HandleIota(HloInstruction* iota) override { - return HandleIota(iota); + template ::value)>::type* = nullptr> + Status HandleRng(HloInstruction* random) { + RandomDistribution distribution = random->random_distribution(); + const auto result_shape = random->shape(); + Literal result(result_shape); + + switch (distribution) { + case RNG_UNIFORM: { + const Literal& low = + parent_->GetEvaluatedLiteralFor(random->operand(0)); + const Literal& high = + parent_->GetEvaluatedLiteralFor(random->operand(1)); + + std::uniform_real_distribution generator( + low.Get({}), high.Get({})); + + TF_RETURN_IF_ERROR( + result.Populate([&](absl::Span /*indexes*/) { + return generator(parent_->engine_); + })); + break; + } + case RNG_NORMAL: { + const Literal& mean = + parent_->GetEvaluatedLiteralFor(random->operand(0)); + const Literal& stddev = + parent_->GetEvaluatedLiteralFor(random->operand(1)); + + 
std::normal_distribution generator(mean.Get({}), + stddev.Get({})); + + TF_RETURN_IF_ERROR( + result.Populate([&](absl::Span /*indexes*/) { + return generator(parent_->engine_); + })); + break; + } + default: + return UnimplementedStrCat("The distribution ", + RandomDistribution_Name(distribution), + " is not implemented."); + } + parent_->evaluated_[random] = std::move(result); + return Status::OK(); + } + template ::value)>::type* = + nullptr> + Status HandleRng(HloInstruction* random) { + RandomDistribution distribution = random->random_distribution(); + const auto result_shape = random->shape(); + Literal result(result_shape); + + switch (distribution) { + case RNG_UNIFORM: { + const Literal& low = + parent_->GetEvaluatedLiteralFor(random->operand(0)); + const Literal& high = + parent_->GetEvaluatedLiteralFor(random->operand(1)); + + // Note std::uniform_int_distribution assumes interval is closed, i.e., + // [low, high], but we want [low, high) instead. Hence high-1 is used as + // the upper range. + std::uniform_int_distribution generator( + low.Get({}), high.Get({}) - 1); + + TF_RETURN_IF_ERROR( + result.Populate([&](absl::Span /*indexes*/) { + return static_cast(generator(parent_->engine_)); + })); + break; + } + case RNG_NORMAL: { + return Unimplemented( + "Normal distribution is not supported for integral types."); + } + default: + return UnimplementedStrCat("The distribution ", + RandomDistribution_Name(distribution), + " is not implemented."); + } + parent_->evaluated_[random] = std::move(result); + return Status::OK(); + } + Status HandleRng(HloInstruction* random) override { + return HandleRng(random); } private: @@ -2587,7 +2779,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // // This lets you calculate LI given the multidimensional indices in any order. 
static DimensionVector MakeDimMultipliers(const Shape& shape) { - DimensionVector v(ShapeUtil::Rank(shape)); + DimensionVector v(shape.rank()); int64 scale = 1; for (auto dim : LayoutUtil::MinorToMajor(shape)) { v[dim] = scale; @@ -2604,7 +2796,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const Shape& window_shape, const Window& window, const Shape& base_shape, const absl::Span& window_count_index, const std::function&)>& f) { - const int64 rank = ShapeUtil::Rank(base_shape); + const int64 rank = base_shape.rank(); DimensionVector window_index(rank); std::fill(window_index.begin(), window_index.end(), 0); do { @@ -2635,12 +2827,27 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } template - StatusOr DynamicSlice(const Literal& operand_literal, - const Literal& start_indices_literal, - const Shape& result_shape) { - auto start_indices_typed = start_indices_literal.data(); - std::vector start(start_indices_typed.begin(), - start_indices_typed.end()); + StatusOr DynamicSlice( + const Literal& operand_literal, + absl::Span start_indices, + const Shape& result_shape) { + std::vector start; + // TODO(b/118437727): Remove the R1 code-path. Note that to distinguish + // between the cases, this currently assumes there is at least 1 index. That + // is wrong in the general case, because for scalar indices, if the operand + // is scalar, then there are no indices. This problem with resolve itself. + const HloInstruction* first_index = start_indices[0]; + if (first_index->shape().rank() == 1) { + auto start_indices_typed = + parent_->GetEvaluatedLiteralFor(first_index).data(); + start = std::vector(start_indices_typed.begin(), + start_indices_typed.end()); + } else { + for (HloInstruction* index : start_indices) { + start.push_back( + parent_->GetEvaluatedLiteralFor(index).GetFirstElement()); + } + } // Clamp the start indices so the slice is in-bounds w.r.t the operand. 
for (int64 i = 0; i < start.size(); ++i) { @@ -2666,14 +2873,28 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } template - StatusOr DynamicUpdateSlice(const Literal& operand_literal, - const Literal& update_literal, - const Literal& start_indices_literal) { + StatusOr DynamicUpdateSlice( + const Literal& operand_literal, const Literal& update_literal, + absl::Span start_indices) { auto result = operand_literal.Clone(); - auto start_indices_typed = start_indices_literal.data(); - const auto rank = ShapeUtil::Rank(result.shape()); - std::vector start(start_indices_typed.begin(), - start_indices_typed.end()); + const auto rank = result.shape().rank(); + std::vector start; + // TODO(b/118437727): Remove the R1 code-path. Note that to distinguish + // between the cases, this currently assumes there is at least 1 index. That + // is wrong in the general case, because for scalar indices, if the operand + // is scalar, then there are no indices. This problem with resolve itself. + const HloInstruction* first_index = start_indices[0]; + if (first_index->shape().rank() == 1) { + auto start_indices_typed = + parent_->GetEvaluatedLiteralFor(first_index).data(); + start = std::vector(start_indices_typed.begin(), + start_indices_typed.end()); + } else { + for (HloInstruction* index : start_indices) { + start.push_back( + parent_->GetEvaluatedLiteralFor(index).GetFirstElement()); + } + } // Clamp the update start indices so the slice is in-bounds w.r.t the // operand. 
for (int64 i = 0; i < rank; ++i) { @@ -2790,6 +3011,7 @@ extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_complex128.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_complex128.cc new file mode 100644 index 0000000000000000000000000000000000000000..1f48140ee4f6ca9415bef80c83664213109dbf9f --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_complex128.cc @@ -0,0 +1,22 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_int16.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_int16.cc new file mode 100644 index 0000000000000000000000000000000000000000..e54285a1577a3f3c97fba5ba6c2f969299ab599e --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_int16.cc @@ -0,0 +1,22 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_uint16.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_uint16.cc new file mode 100644 index 0000000000000000000000000000000000000000..cc708952d20a00429944c8388a84a0e610c2f38f --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_uint16.cc @@ -0,0 +1,22 @@ +/* Copyright 2019 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc b/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc index 5be9dba3aa49d63c580cd6f5800f608667826b6a..df06cf8c53ec8407f8b44c9126ed4fb5409f8ef3 100644 --- a/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc +++ b/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc @@ -45,7 +45,7 @@ TEST_F(HloExecutionProfileTest, Basic) { auto shape_size_function = [&](const Shape& shape) { const int64 pointer_size = 8; - if (ShapeUtil::IsOpaque(shape)) { + if (shape.IsOpaque()) { return pointer_size; } return ShapeUtil::ByteSizeOf(shape, pointer_size); diff --git a/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.cc b/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.cc index c919dbd82d3668c477bf37074f1d56f8cb7d9506..862b2029718bbd802b69d789b66683a4edfa2367 100644 --- a/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.cc +++ b/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/shape_inference.h" @@ -25,7 +26,9 @@ namespace xla { namespace { -StatusOr ReplaceGetSize(HloInstruction* instr) { +StatusOr ReplaceGetSize( + HloInstruction* instr, + const DynamicDimensionInference* dynamic_dimension_inference) { if (instr->opcode() != HloOpcode::kGetDimensionSize) { return false; } @@ -36,10 +39,18 @@ StatusOr ReplaceGetSize(HloInstruction* instr) { instr->operand(0)->shape(), instr->dimension())); TF_RET_CHECK(ShapeUtil::Equal(instr->shape(), legal_shape)); TF_RET_CHECK(ShapeUtil::HasPrimitiveType(instr->shape(), U32)); - uint32 size = instr->operand(0)->shape().dimensions(instr->dimension()); - HloInstruction* new_instr = computation->AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(size))); - TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(new_instr)); + HloInstruction* operand = instr->mutable_operand(0); + int64 dim = instr->dimension(); + HloInstruction* dynamic_size = + dynamic_dimension_inference->GetDynamicSize(operand, {}, dim); + if (dynamic_size != nullptr) { + TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(dynamic_size)); + } else { + uint32 size = instr->operand(0)->shape().dimensions(dim); + HloInstruction* new_instr = computation->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(size))); + TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(new_instr)); + } return true; } @@ -48,10 +59,13 @@ StatusOr ReplaceGetSize(HloInstruction* instr) { StatusOr HloGetDimensionSizeRewriter::Run(HloModule* module) { bool changed = false; HloProto proto; + TF_ASSIGN_OR_RETURN(DynamicDimensionInference inference, + DynamicDimensionInference::Run(module)); *proto.mutable_hlo_module() = 
module->ToProto(); for (auto* computation : module->computations()) { for (auto instruction : computation->instructions()) { - TF_ASSIGN_OR_RETURN(bool replaced, ReplaceGetSize(instruction)); + TF_ASSIGN_OR_RETURN(bool replaced, + ReplaceGetSize(instruction, &inference)); changed = changed || replaced; } } diff --git a/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h b/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h index 30f44c23a835b3bcc935caaa917e040e07c4e703..9aa79fe66b665c48ec871c4188e44ba2056de3ad 100644 --- a/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h +++ b/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h @@ -21,7 +21,9 @@ limitations under the License. namespace xla { -// Pass to replace a kGetDimensionSize instruction with a constant instruction. +// Pass to replace a kGetDimensionSize instruction with a hlo instruction +// representing the dynamic size if the dimension is dynamic, otherwise a +// constant instruction representing the static size. class HloGetDimensionSizeRewriter : public HloModulePass { public: absl::string_view name() const override { diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 0745c735e5263f4f3a8d22cef0adf1010f28b9b4..46ee99923ee9b6d852e6190cc8de6afe0b99457e 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -24,9 +24,9 @@ limitations under the License. #include #include #include -#include #include +#include "absl/container/flat_hash_map.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" @@ -39,6 +39,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/window_util.h" @@ -112,11 +113,6 @@ class NodeFilter { result == kSomeUsersOmitted; } - bool ShowFusionSubcomputation(const HloInstruction* instr) const { - CHECK_EQ(instr->opcode(), HloOpcode::kFusion); - return Show(instr) && !SomeOrAllOperandsOmitted(instr); - } - private: std::function filter_; }; @@ -241,34 +237,28 @@ string HtmlLikeStringSanitize(absl::string_view s) { // it to a short string lets us tell the user what the subcomputation is without // drawing it as a graph. optional MatchTrivialComputation(const HloComputation* computation) { + namespace m = match; + if (computation->instruction_count() != 3) { return nullopt; } - HloInstruction* root = computation->root_instruction(); - if (root->operand_count() != 2) { - return nullopt; - } - - // Check that both of the operands to the root are parameters. - const HloInstruction* operand0 = root->operand(0); - const HloInstruction* operand1 = root->operand(1); - if (operand0->opcode() != HloOpcode::kParameter || - operand1->opcode() != HloOpcode::kParameter) { - return nullopt; - } - - // Check that the two operands of root are param0 and param1. All of the - // opcodes we recognize are commutative, so we're OK with either order. 
- auto n0 = operand0->parameter_number(); - auto n1 = operand1->parameter_number(); - if (!(n0 == 0 && n1 == 1) && !(n1 == 0 && n0 == 1)) { + const HloInstruction *param0, *param1; + if (!Match(root, m::Op() + .WithNumOperands(2) + .WithShape(m::Shape().IsEffectiveScalar()) + .WithBinaryOperandsAnyOrder( + m::Parameter(¶m0, 0) + .WithShape(m::Shape().IsEffectiveScalar()), + m::Parameter(¶m1, 1) + .WithShape(m::Shape().IsEffectiveScalar())))) { return nullopt; } - // If the params are reversed, check that the operation being performed is - // commutative. - if (n0 == 1) { + // If the params are reversed (i.e. operand0 is param1 and operand1 is + // param0), check that the operation being performed is commutative. + if (root->operand(0) == param1) { + CHECK_EQ(root->operand(1), param0); switch (root->opcode()) { case HloOpcode::kLe: case HloOpcode::kGe: @@ -280,13 +270,6 @@ optional MatchTrivialComputation(const HloComputation* computation) { } } - // Check that the root and params are all effective scalars. - if (!ShapeUtil::IsEffectiveScalar(root->shape()) || - !ShapeUtil::IsEffectiveScalar(operand0->shape()) || - !ShapeUtil::IsEffectiveScalar(operand1->shape())) { - return nullopt; - } - // If we recognize the root's opcode, we've successfully pattern-matched! switch (root->opcode()) { case HloOpcode::kAdd: @@ -397,7 +380,7 @@ class HloDotDumper { // Each HloInstruction dumped gets a monotically-increasing node ID. This // must start at 1, because that's where graphviz's accounting starts. int64 next_node_id_ = 1; - std::unordered_map node_ids_; + absl::flat_hash_map node_ids_; // The "root" tag doesn't have an associated HloInstruction pointer, so we // need to store it outside the map. @@ -414,7 +397,7 @@ class HloDotDumper { // Each HloComputation that's emitted gets a monotonically-increasing ID. int64 next_cluster_id_ = 1; - std::unordered_map cluster_ids_; + absl::flat_hash_map cluster_ids_; // Edges to print from Footer(). 
Edges come at the end because graphviz is // unhappy if an edge from a subcomputation to a node in the outer computation @@ -424,7 +407,7 @@ class HloDotDumper { // When coloring by sharding information, we track the sharding string // representation to color association, by round-robin the color schemes. - std::unordered_map + absl::flat_hash_map sharding_colors_; int64 next_shard_color_ = 0; }; @@ -578,8 +561,8 @@ bool HloDotDumper::ShouldShowSubcomputation(const HloComputation* subcomp) { } // Show the subcomputation if we're showing any of its members. - return std::any_of( - subcomp->instructions().begin(), subcomp->instructions().end(), + return absl::c_any_of( + subcomp->instructions(), [&](const HloInstruction* instr) { return filter_.Show(instr); }); } @@ -750,17 +733,16 @@ bool HloDotDumper::ShouldMergeIntoUsers(const HloInstruction* instr) const { return true; } const int kMinUsersToOmit = 3; - return instr->opcode() == HloOpcode::kParameter && - ShapeUtil::IsTuple(instr->shape()) && !instr->IsFused() && - std::count_if(instr->users().begin(), instr->users().end(), - [&](const HloInstruction* user) { - return filter_.Show(user); - }) > kMinUsersToOmit && - std::all_of(instr->users().begin(), instr->users().end(), - [&](const HloInstruction* user) { - return !filter_.Show(user) || - user->opcode() == HloOpcode::kGetTupleElement; - }); + return instr->opcode() == HloOpcode::kParameter && instr->shape().IsTuple() && + !instr->IsFused() && + absl::c_count_if(instr->users(), + [&](const HloInstruction* user) { + return filter_.Show(user); + }) > kMinUsersToOmit && + absl::c_all_of(instr->users(), [&](const HloInstruction* user) { + return !filter_.Show(user) || + user->opcode() == HloOpcode::kGetTupleElement; + }); } string HloDotDumper::DumpInstruction(const HloInstruction* instr) { @@ -833,7 +815,7 @@ string HloDotDumper::GetInstructionNodeInlinedOperands( // Print the literal value of constants with <= K elements. 
optional elem_count; - if (ShapeUtil::IsArray(shape)) { + if (shape.IsArray()) { elem_count = 1; for (int64 dim : shape.dimensions()) { *elem_count *= dim; @@ -917,12 +899,11 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { // the same color as a parameter. Unless the merged-in parameter is a // parameter to a fusion node that is bound to a constant -- these aren't // "real" parameters from the user's perspective. - if (std::any_of(instr->operands().begin(), instr->operands().end(), - [&](const HloInstruction* operand) { - return operand->opcode() == HloOpcode::kParameter && - ShouldMergeIntoUsers(operand) && - TryGetFusionParameterConstant(operand) == nullptr; - })) { + if (absl::c_any_of(instr->operands(), [&](const HloInstruction* operand) { + return operand->opcode() == HloOpcode::kParameter && + ShouldMergeIntoUsers(operand) && + TryGetFusionParameterConstant(operand) == nullptr; + })) { return parameter_color; } @@ -1047,7 +1028,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kMap: case HloOpcode::kGetDimensionSize: return kGray; - case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllReduce: case HloOpcode::kAllToAll: case HloOpcode::kCollectivePermute: case HloOpcode::kInfeed: @@ -1056,6 +1037,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kRecvDone: case HloOpcode::kSend: case HloOpcode::kSendDone: + case HloOpcode::kReplicaId: return kBrown; case HloOpcode::kCall: case HloOpcode::kConditional: @@ -1303,7 +1285,7 @@ NodeFilter MakeNodeRadiusAroundFilter(const HloInstruction* root, int64 radius) { // First, find the neighborhood of nodes with distance from root <= radius. // These nodes are our initial set of "normal" nodes. 
- std::unordered_map nodes; + absl::flat_hash_map nodes; std::deque> worklist; worklist.push_back({root, 0}); while (!worklist.empty()) { @@ -1324,7 +1306,7 @@ NodeFilter MakeNodeRadiusAroundFilter(const HloInstruction* root, // are not interesting to the graph at hand. if (instr == root || instr->opcode() != HloOpcode::kTuple) { for (const HloInstruction* operand : instr->operands()) { - if (!nodes.count(operand)) { + if (!nodes.contains(operand)) { worklist.push_back({operand, depth + 1}); } } @@ -1352,7 +1334,7 @@ NodeFilter MakeNodeRadiusAroundFilter(const HloInstruction* root, continue; } for (const HloInstruction* user : instr->users()) { - if (!nodes.count(user)) { + if (!nodes.contains(user)) { worklist.push_back({user, depth + 1}); } } @@ -1361,7 +1343,7 @@ NodeFilter MakeNodeRadiusAroundFilter(const HloInstruction* root, auto is_displayed = [&](const HloInstruction* instr) { // Constants are displayed inline with their users; they're never omitted. // Nodes in subcomputations are always shown. - return nodes.count(instr) > 0 || instr->opcode() == HloOpcode::kConstant || + return nodes.contains(instr) || instr->opcode() == HloOpcode::kConstant || instr->parent() != root->parent(); }; @@ -1372,12 +1354,11 @@ NodeFilter MakeNodeRadiusAroundFilter(const HloInstruction* root, NodeFilterResult& filter_result = kv.second; const auto& operands = instr->operands(); - if (std::any_of(operands.begin(), operands.end(), is_displayed) && - !std::all_of(operands.begin(), operands.end(), is_displayed)) { + if (absl::c_any_of(operands, is_displayed) && + !absl::c_all_of(operands, is_displayed)) { // Mark nodes with some operands omitted appropriately. filter_result = kSomeOperandsOmitted; - } else if (!operands.empty() && - std::none_of(operands.begin(), operands.end(), is_displayed)) { + } else if (!operands.empty() && absl::c_none_of(operands, is_displayed)) { // Mark nodes with *all* operands omitted appropriately. 
filter_result = kOmitNodeOperands; } @@ -1385,8 +1366,7 @@ NodeFilter MakeNodeRadiusAroundFilter(const HloInstruction* root, // Promote nodes with type kSomeUsersOmitted to kNormalNode if all of their // users made it into the graph. if (filter_result == kSomeUsersOmitted && - std::all_of(instr->users().begin(), instr->users().end(), - is_displayed)) { + absl::c_all_of(instr->users(), is_displayed)) { filter_result = kNormalNode; } } @@ -1491,14 +1471,15 @@ string ExportGraph(const string& graph, GraphRendererInterface::GraphKind graph_kind, const DebugOptions& debug_options) { string path = debug_options.xla_hlo_graph_path(); - if (!path.empty()) { + if (!path.empty() && !debug_options.xla_hlo_dump_as_html()) { return SaveGraph(graph, graph_kind, path); } else { auto graph_renderer = GraphRendererRegistry::Default()->GetDefaultRenderer(); CHECK(graph_renderer != nullptr) << "No registered renderer for the HLO graph. " - "Use --xla_hlo_graph_path=PATH to export to local file system"; + "Use --xla_hlo_graph_path=PATH --xla_hlo_dump_as_html=false to " + "export to local file system"; return graph_renderer->RenderGraph(graph, graph_kind, debug_options); } } @@ -1606,5 +1587,145 @@ string MaybeDumpHloModule(const HloModule& module, const string& label, return graph_url; } +string WrapDotInHTML(const string& dot) { + static const char html_prefix[] = R"html( + + + + + + + + + + + +
+ + + +)html"; + + return html_prefix + dot + html_suffix; +} + +string RenderDotAsHTMLFile(const string& dot, + const DebugOptions& debug_options) { + string html = WrapDotInHTML(dot); + + auto env = tensorflow::Env::Default(); + std::vector dirs; + string output_dir = debug_options.xla_hlo_graph_path(); + if (output_dir.empty()) { + env->GetLocalTempDirectories(&dirs); + } else { + dirs.push_back(output_dir); + } + // Try each directory, as they might be full, have inappropriate + // permissions or have different problems at times. + string output; + for (const string& dir : dirs) { + string filename = tensorflow::io::JoinPath(dir, "graph-"); + if (env->CreateUniqueFileName(&filename, ".html")) { + output = filename; + break; + } + } + if (output.empty()) { + LOG(FATAL) << "Failed to create unique output file name."; + } + TF_CHECK_OK(tensorflow::WriteStringToFile(env, output, html)); + return "file://" + output; +} + } // namespace hlo_graph_dumper } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.h b/tensorflow/compiler/xla/service/hlo_graph_dumper.h index de1eefab776f9c3d2c73959a5cd267e938a78a32..8e51454ef1cf992386cc7325e32705c08bf7712f 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.h +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.h @@ -81,6 +81,12 @@ string DumpAllPathsFromTo(const HloInstruction& from, const HloInstruction& to, void DumpText(const HloModule& module, const string& label, const string& directory_path, bool do_prefix = true); +// Renders DOT graph as inline SVG and saves it in an HTML file in a temprary +// directory or directory specified via --xla_hlo_graph_path. Returns the file +// URI pointing to the file. +string RenderDotAsHTMLFile(const string& dot, + const DebugOptions& debug_options); + // Graph renderers may be added using a registration mechanism, e.g.: // XLA_REGISTER_GRAPH_RENDERER(AGraphRendererClass, 100) // The renderer with the highest numeric priority value is used. 
diff --git a/tensorflow/compiler/xla/service/hlo_graph_html_renderer.cc b/tensorflow/compiler/xla/service/hlo_graph_html_renderer.cc new file mode 100644 index 0000000000000000000000000000000000000000..84c4cf18df69816c611f4eb159ba247320ebc20e --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_graph_html_renderer.cc @@ -0,0 +1,43 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Implementation of an DOT graph renderer that uses Javascript to render DOT to +// SVG in a browser. 
+ +#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" + +namespace xla { +namespace hlo_graph_dumper { +namespace { + +class GraphHtmlRenderer : public GraphRendererInterface { + public: + string RenderGraph(const string& graph, GraphKind graph_kind, + const DebugOptions& debug_options) override { + switch (graph_kind) { + case DOT_GRAPH: + return RenderDotAsHTMLFile(graph, debug_options); + default: + LOG(FATAL) << "Only DOT graphs can be rendered"; + } + } +}; + +XLA_REGISTER_GRAPH_RENDERER(GraphHtmlRenderer); + +} // namespace +} // namespace hlo_graph_dumper +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc b/tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc index 6e1597fd03db0a78aa560340b7b9b64fe500df0c..b01c00121b3363630b83a1e49d0027a66f3a9e1a 100644 --- a/tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc +++ b/tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc @@ -17,22 +17,34 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_module.h" namespace xla { + +bool HloInputOutputAliasConfig::OutputHasAlias( + const ShapeIndex& output_index) const { + return alias_.element(output_index).has_value(); +} + Status HloInputOutputAliasConfig::SetUpAlias(const ShapeIndex& output_index, int64 param_number, - const ShapeIndex& param_index) { + const ShapeIndex& param_index, + AliasKind kind) { + TF_RET_CHECK(kind == AliasKind::kUserAlias || kind == AliasKind::kSystemAlias) + << kind; TF_RET_CHECK(ShapeUtil::IndexIsValid(alias_.shape(), output_index)) << absl::StrCat("Tring to set up alias at ", output_index.ToString(), " which is an invalid index for shape ", ShapeUtil::HumanString(alias_.shape())); + TF_RET_CHECK(param_number >= 0) << param_number; + TF_RET_CHECK(!OutputHasAlias(output_index)) + << "Output index " << output_index << " already has an alias setup"; // Output can't be aliased with multiple parameters. TF_RET_CHECK(!alias_.element(output_index)) << absl::StrFormat( "Trying to set up output alias for param %lld at %s but failed: output " "index %s is already aliased with param %lld at %s", param_number, param_index.ToString(), output_index.ToString(), - alias_.element(output_index)->first, - alias_.element(output_index)->second.ToString()); + alias_.element(output_index)->parameter_number, + alias_.element(output_index)->parameter_index.ToString()); (*alias_.mutable_element(output_index)) = - std::make_pair(param_number, param_index); + Alias(kind, param_number, param_index); VLOG(4) << "Set up alias between output index " << output_index.ToString() << " and parameter " << param_index << " at index " << param_index.ToString(); @@ -42,15 +54,24 @@ Status HloInputOutputAliasConfig::SetUpAlias(const ShapeIndex& output_index, HloInputOutputAliasProto HloInputOutputAliasConfig::ToProto() const { HloInputOutputAliasProto result; alias_.ForEachElement( - [&](const ShapeIndex& index, - const absl::optional>& data) { + [&](const ShapeIndex& index, 
const absl::optional& data) { if (data) { HloInputOutputAliasProto::AliasEntryProto entry; + switch (data->kind) { + case AliasKind::kUserAlias: + entry.set_kind(HloInputOutputAliasProto::USER_ALIAS); + break; + case AliasKind::kSystemAlias: + entry.set_kind(HloInputOutputAliasProto::SYSTEM_ALIAS); + break; + default: + LOG(FATAL) << "Unknown alias kind " << data->kind; + } for (int64 i : index) { entry.add_output_shape_index(i); } - entry.set_parameter_number(data->first); - for (int64 i : data->second) { + entry.set_parameter_number(data->parameter_number); + for (int64 i : data->parameter_index) { entry.add_parameter_shape_index(i); } result.add_entries()->Swap(&entry); @@ -66,14 +87,18 @@ StatusOr HloInputOutputAliasConfig::CreateFromProto( proto.entries()) { ShapeIndex output_index(entry.output_shape_index().begin(), entry.output_shape_index().end()); - int64 param_number = entry.parameter_number(); ShapeIndex param_index(entry.parameter_shape_index().begin(), entry.parameter_shape_index().end()); + // Handle backward compatibility with existing protos, which only knew of + // system aliases. + AliasKind kind = AliasKind::kSystemAlias; + if (entry.kind() == HloInputOutputAliasProto::USER_ALIAS) { + kind = AliasKind::kUserAlias; + } TF_RETURN_IF_ERROR( - result.SetUpAlias(output_index, param_number, param_index)); + result.SetUpAlias(output_index, param_number, param_index, kind)); } - return result; } @@ -81,45 +106,44 @@ string HloInputOutputAliasConfig::ToString() const { std::vector pieces; pieces.push_back("HloInputOutputAliasConfig"); - ForEachAlias([&](const ShapeIndex& output_index, int64 param_number, - const ShapeIndex& param_index) { + ForEachAlias([&](const ShapeIndex& output_index, const Alias& alias) { + const char* kind = alias.kind == AliasKind::kUserAlias ? 
"USER" : "SYSTEM"; pieces.push_back(absl::StrFormat( - " OutputIndex %s is aliased with parameter %lld at %s:", - output_index.ToString(), param_number, param_index.ToString())); + " OutputIndex %s is aliased (kind=%s) with parameter %lld at %s:", + output_index.ToString(), kind, alias.parameter_number, + alias.parameter_index.ToString())); }); - return absl::StrJoin(pieces, "\n"); } -bool HloInputOutputAliasConfig::ParameterHasAlias( +HloInputOutputAliasConfig::AliasKind +HloInputOutputAliasConfig::ParameterAliasKind( int64 param_number, const ShapeIndex& param_index) const { - bool output = false; + AliasKind kind = AliasKind::kNoAlias; alias_.ForEachElement( - [&](const xla::ShapeIndex&, - absl::optional> alias) { - if (alias && alias->first == param_number && - alias->second == param_index) { - output = true; + [&](const xla::ShapeIndex&, absl::optional alias) { + if (alias && alias->parameter_number == param_number && + alias->parameter_index == param_index) { + kind = alias->kind; } }); - return output; + return kind; } absl::optional HloInputOutputAliasConfig::GetAliasedOutput( int64 param_number, const ShapeIndex& param_index) const { absl::optional output; alias_.ForEachElement( - [&](const xla::ShapeIndex& output_index, - absl::optional> alias) { - if (alias && alias->first == param_number && - alias->second == param_index) { + [&](const xla::ShapeIndex& output_index, absl::optional alias) { + if (alias && alias->parameter_number == param_number && + alias->parameter_index == param_index) { output = output_index; } }); return output; } -absl::optional> +absl::optional HloInputOutputAliasConfig::GetAliasedParameter( const ShapeIndex& output_index) const { CHECK(ShapeUtil::IndexIsValid(alias_.shape(), output_index)); @@ -128,10 +152,9 @@ HloInputOutputAliasConfig::GetAliasedParameter( void HloInputOutputAliasConfig::ForEachAlias(AliasFn fn) const { alias_.ForEachElement( - [&](const ShapeIndex& output_index, - absl::optional> aliased) { + [&](const 
ShapeIndex& output_index, absl::optional aliased) { if (aliased) { - fn(output_index, aliased->first, aliased->second); + fn(output_index, *aliased); } }); } @@ -139,10 +162,9 @@ void HloInputOutputAliasConfig::ForEachAlias(AliasFn fn) const { Status HloInputOutputAliasConfig::ForEachAliasWithStatus( AliasFnWithStatus fn) const { return alias_.ForEachElementWithStatus( - [&](const ShapeIndex& output_index, - absl::optional> aliased) { + [&](const ShapeIndex& output_index, absl::optional aliased) { if (aliased) { - TF_RETURN_IF_ERROR(fn(output_index, aliased->first, aliased->second)); + TF_RETURN_IF_ERROR(fn(output_index, *aliased)); } return Status::OK(); }); @@ -158,20 +180,19 @@ Status HloInputOutputAliasConfig::Verify( param_has_seen.emplace_back(param->shape()); } return ForEachAliasWithStatus([&](const ShapeIndex& output_index, - int64 param_number, - const ShapeIndex& param_index) -> Status { + const Alias& alias) -> Status { const HloInstruction* root = entry->root_instruction(); - TF_RET_CHECK(0 <= param_number); - TF_RET_CHECK(entry->num_parameters() > param_number); + TF_RET_CHECK(0 <= alias.parameter_number); + TF_RET_CHECK(entry->num_parameters() > alias.parameter_number); const Shape& param_shape = - entry->parameter_instruction(param_number)->shape(); + entry->parameter_instruction(alias.parameter_number)->shape(); const Shape& output_shape = root->shape(); - TF_RET_CHECK(ShapeUtil::IndexIsValid(param_shape, param_index)); + TF_RET_CHECK(ShapeUtil::IndexIsValid(param_shape, alias.parameter_index)); TF_RET_CHECK(ShapeUtil::IndexIsValid(output_shape, output_index)); const Shape& param_subshape = - ShapeUtil::GetSubshape(param_shape, param_index); + ShapeUtil::GetSubshape(param_shape, alias.parameter_index); const Shape& output_subshape = ShapeUtil::GetSubshape(output_shape, output_index); TF_RET_CHECK(LayoutUtil::IsDenseArray(param_subshape)); @@ -182,19 +203,20 @@ Status HloInputOutputAliasConfig::Verify( "Expected aliased input %lld at index %s and 
output at index %s to " "have the same size. Input sub-shape is %s with size %lld, output " "sub-shape is %s with size %lld", - param_number, param_index.ToString(), output_index.ToString(), + alias.parameter_number, alias.parameter_index.ToString(), + output_index.ToString(), ShapeUtil::HumanStringWithLayout(param_subshape), size_func(param_subshape), ShapeUtil::HumanStringWithLayout(output_subshape), size_func(output_subshape)); } - // Check each param_number and param_index pair only show up once. No - // input can be aliased with output buffers. - TF_RET_CHECK(param_has_seen[param_number].element(param_index) == false); - - *(param_has_seen[param_number].mutable_element(param_index)) = true; - + // Check each alias.parameter_number and alias.parameter_index pair only + // show up once. No input can be aliased with output buffers. + TF_RET_CHECK(param_has_seen[alias.parameter_number].element( + alias.parameter_index) == false); + *(param_has_seen[alias.parameter_number].mutable_element( + alias.parameter_index)) = true; return Status::OK(); }); } diff --git a/tensorflow/compiler/xla/service/hlo_input_output_alias_config.h b/tensorflow/compiler/xla/service/hlo_input_output_alias_config.h index 439676b1546c4af7f781fb80bccffd5248309b0f..cd13c7a3ac7afe03fb99ed3114bdc6ac0f8ad6a7 100644 --- a/tensorflow/compiler/xla/service/hlo_input_output_alias_config.h +++ b/tensorflow/compiler/xla/service/hlo_input_output_alias_config.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "absl/container/flat_hash_set.h" #include "absl/types/optional.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/shape_tree.h" @@ -31,21 +32,54 @@ class HloModule; // parameter index in the entry computation. class HloInputOutputAliasConfig { public: + // The kind of aliases which can be set. A kUserAlias is one setup at + // compilation time by the user, and has to be respected. 
A kSystemAlias one + // might be setup by the compiler, if it decides it is convenient to do so. + enum AliasKind { + kNoAlias, + kUserAlias, + kSystemAlias, + }; + + // Defines the alias information for a given output buffer. A given output + // buffer shape index can refer only to one parameter+index. + struct Alias { + Alias(AliasKind kind, int64 parameter_number, ShapeIndex parameter_index) + : kind(kind), + parameter_number(parameter_number), + parameter_index(std::move(parameter_index)) {} + + AliasKind kind; + int64 parameter_number; + ShapeIndex parameter_index; + }; + HloInputOutputAliasConfig() = default; - explicit HloInputOutputAliasConfig(Shape shape) : alias_(shape) {} + explicit HloInputOutputAliasConfig(Shape output_shape) + : alias_(output_shape) {} virtual ~HloInputOutputAliasConfig() = default; // Sets up alias config from `output_index` to `param_index` at // `param_number`. Status SetUpAlias(const ShapeIndex& output_index, int64 param_number, - const ShapeIndex& param_index); + const ShapeIndex& param_index, AliasKind kind); + + // Returns the kind of alias for the given parameter number and parameter + // index. If no alias exists, AliasKind::kNoAlias is returned. + AliasKind ParameterAliasKind(int64 param_number, + const ShapeIndex& param_index) const; // Returns true if the given parameter is aliased with one of the output // buffers. bool ParameterHasAlias(int64 param_number, - const ShapeIndex& param_index) const; + const ShapeIndex& param_index) const { + return ParameterAliasKind(param_number, param_index) != AliasKind::kNoAlias; + } + + // Checks whether the provided output index has already been aliased. + bool OutputHasAlias(const ShapeIndex& output_index) const; // (De)Serializes an HloInputOutoutAliasConfig to/from an // HloInputOutoutAliasProto. @@ -63,19 +97,17 @@ class HloInputOutputAliasConfig { // Returns the number of parameter and index of the parameter buffer that the // given output buffer index is aliased with. 
A nullopt is returned if there // is no parameter is aliased with the specific output. - absl::optional> GetAliasedParameter( + absl::optional GetAliasedParameter( const ShapeIndex& output_index) const; using AliasFn = - std::function; + std::function; // Iterates through each aliased output and input. void ForEachAlias(AliasFn fn) const; using AliasFnWithStatus = - std::function; + std::function; // Verifies that the given config is valid for the given module. // Specifically, the config's input and output should be in-bound and size of @@ -90,9 +122,10 @@ class HloInputOutputAliasConfig { private: // A ShapeTree which indicates the list of buffers that's expected to be // aliased. The key on this shape tree represents the output index. The value - // is a pair of parameter number and index into the buffer. If the value is - // nullopt, it means there is no parameter aliasing for this output. - ShapeTree>> alias_; + // is an Alias data structure which defines the input parameter coordinates. + // If the value is nullopt, it means there is no parameter aliasing for this + // output. 
+ ShapeTree> alias_; }; std::ostream& operator<<(std::ostream& out, diff --git a/tensorflow/compiler/xla/service/hlo_input_output_alias_config_test.cc b/tensorflow/compiler/xla/service/hlo_input_output_alias_config_test.cc index aeb9b0fdc8b6cca87731a2d4aae25120af6c3215..a46a107723de30176241aae01b268a8c10d991d3 100644 --- a/tensorflow/compiler/xla/service/hlo_input_output_alias_config_test.cc +++ b/tensorflow/compiler/xla/service/hlo_input_output_alias_config_test.cc @@ -45,11 +45,12 @@ class HloInputOutputAliasConfigTest : public HloTestBase { EXPECT_TRUE(aliased_output); EXPECT_EQ(aliased_output.value(), output_index); - absl::optional> aliased_param = + absl::optional aliased_param = config.GetAliasedParameter(output_index); EXPECT_TRUE(aliased_param); - EXPECT_EQ(aliased_param.value(), std::make_pair(param_number, param_index)); + EXPECT_EQ(aliased_param->parameter_number, param_number); + EXPECT_EQ(aliased_param->parameter_index, param_index); } void expect_not_aliased(const ShapeIndex& output_index, int64 param_number, @@ -60,11 +61,12 @@ class HloInputOutputAliasConfigTest : public HloTestBase { EXPECT_FALSE(aliased_output && aliased_output == output_index); - absl::optional> aliased_param = + absl::optional aliased_param = config.GetAliasedParameter(output_index); - EXPECT_FALSE(aliased_param && aliased_param->first == param_number && - aliased_param->second == param_index); + EXPECT_FALSE(aliased_param && + aliased_param->parameter_number == param_number && + aliased_param->parameter_index == param_index); } }; @@ -84,8 +86,10 @@ ENTRY main { HloInputOutputAliasConfig config( module->entry_computation()->root_instruction()->shape()); - TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{0}, /*param_number=*/1, - /*param_index=*/{})); + TF_ASSERT_OK(config.SetUpAlias( + /*output_index=*/{0}, /*param_number=*/1, + /*param_index=*/{}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); expect_aliased(/*output_index=*/{0}, /*param_number=*/1, 
/*param_index=*/{}, config); @@ -114,11 +118,15 @@ ENTRY main { HloInputOutputAliasConfig config( module->entry_computation()->root_instruction()->shape()); - TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{0}, /*param_number=*/0, - /*param_index=*/{0})); + TF_ASSERT_OK(config.SetUpAlias( + /*output_index=*/{0}, /*param_number=*/0, + /*param_index=*/{0}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); - TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{1}, /*param_number=*/0, - /*param_index=*/{1})); + TF_ASSERT_OK(config.SetUpAlias( + /*output_index=*/{1}, /*param_number=*/0, + /*param_index=*/{1}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); expect_aliased(/*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}, config); @@ -149,11 +157,15 @@ ENTRY main { HloInputOutputAliasConfig config( module->entry_computation()->root_instruction()->shape()); - TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{0}, /*param_number=*/0, - /*param_index=*/{})); + TF_ASSERT_OK(config.SetUpAlias( + /*output_index=*/{0}, /*param_number=*/0, + /*param_index=*/{}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); - TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{1}, /*param_number=*/0, - /*param_index=*/{})); + TF_ASSERT_OK(config.SetUpAlias( + /*output_index=*/{1}, /*param_number=*/0, + /*param_index=*/{}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); ASSERT_IS_NOT_OK(config.Verify(*module, [](const Shape& shape) { return ShapeUtil::ByteSizeOf(shape); @@ -176,8 +188,10 @@ ENTRY main { HloInputOutputAliasConfig config( module->entry_computation()->root_instruction()->shape()); - TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{1}, /*param_number=*/0, - /*param_index=*/{})); + TF_ASSERT_OK(config.SetUpAlias( + /*output_index=*/{1}, /*param_number=*/0, + /*param_index=*/{}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); ASSERT_IS_NOT_OK(config.Verify(*module, [](const Shape& shape) { return 
ShapeUtil::ByteSizeOf(shape); @@ -200,11 +214,15 @@ ENTRY main { HloInputOutputAliasConfig config( module->entry_computation()->root_instruction()->shape()); - TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{0}, /*param_number=*/0, - /*param_index=*/{})); + TF_ASSERT_OK(config.SetUpAlias( + /*output_index=*/{0}, /*param_number=*/0, + /*param_index=*/{}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); - ASSERT_IS_NOT_OK(config.SetUpAlias(/*output_index=*/{0}, /*param_number=*/1, - /*param_index=*/{})); + ASSERT_IS_NOT_OK(config.SetUpAlias( + /*output_index=*/{0}, /*param_number=*/1, + /*param_index=*/{}, + /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias)); } } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 21b1dbc1676cccd2fe5b331a1f9d6ff5e3a73fcd..1b677bc25c1d3dfd0205c3e0dfbf1fd30c646fd4 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -31,6 +31,7 @@ limitations under the License. 
#include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/protobuf_util.h" @@ -82,15 +83,14 @@ StatusOr> HloInstruction::CreateFromProto( return computation_map.at(proto.called_computation_ids(index)); }; - TF_RET_CHECK(std::all_of( - proto.operand_ids().begin(), proto.operand_ids().end(), - [&instruction_map](int64 id) { return instruction_map.contains(id); })) + TF_RET_CHECK( + absl::c_all_of(proto.operand_ids(), + [&](int64 id) { return instruction_map.contains(id); })) << proto.name() << " instruction contains invalid operand id(s)"; - TF_RET_CHECK(std::all_of( - proto.called_computation_ids().begin(), - proto.called_computation_ids().end(), - [&computation_map](int64 id) { return computation_map.contains(id); })) + TF_RET_CHECK( + absl::c_all_of(proto.called_computation_ids(), + [&](int64 id) { return computation_map.contains(id); })) << proto.name() << " instruction references invalid computation id(s)"; Shape shape(proto.shape()); @@ -311,7 +311,7 @@ StatusOr> HloInstruction::CreateFromProto( shape, operands(0), proto.exponent_bits(), proto.mantissa_bits()); break; case HloOpcode::kInfeed: { - TF_RET_CHECK(ShapeUtil::IsTuple(shape) && + TF_RET_CHECK(shape.IsTuple() && (ShapeUtil::TupleElementCount(shape) == 2)) << "Infeed should have a tuple shape with 2 operands, but has: " << shape; @@ -333,20 +333,20 @@ StatusOr> HloInstruction::CreateFromProto( proto.outfeed_config()); break; } - case HloOpcode::kCrossReplicaSum: { + case HloOpcode::kAllReduce: { TF_RET_CHECK(proto.called_computation_ids_size() == 1) - << "CrossReplicaSum should have 1 called computation but sees " + << "AllReduce should have 1 called computation but sees " << proto.called_computation_ids_size(); absl::optional all_reduce_id; if (proto.all_reduce_id() > 0) { all_reduce_id = 
proto.all_reduce_id(); } - instruction = CreateCrossReplicaSum( + instruction = CreateAllReduce( shape, all_operands(), computations(0), /*replica_groups=*/ std::vector(proto.replica_groups().begin(), proto.replica_groups().end()), - /*barrier=*/proto.cross_replica_sum_barrier(), + /*barrier=*/proto.all_reduce_barrier(), /*all_reduce_id=*/all_reduce_id); break; } @@ -372,6 +372,13 @@ StatusOr> HloInstruction::CreateFromProto( CreateCollectivePermute(shape, operands(0), source_target_pairs); break; } + case HloOpcode::kReplicaId: { + TF_RET_CHECK(proto.operand_ids_size() == 0) + << "ReplicaId instruction should have 0 operand but sees " + << proto.operand_ids_size(); + instruction = CreateReplicaId(); + break; + } case HloOpcode::kConvolution: { TF_RET_CHECK(proto.operand_ids_size() == 2) << "Convolution instruction should have 2 operands but sees " @@ -383,7 +390,8 @@ StatusOr> HloInstruction::CreateFromProto( proto.operand_ids_size(), PrecisionConfig::DEFAULT); instruction = CreateConvolve( shape, operands(0), operands(1), - std::max(proto.feature_group_count(), 1), proto.window(), + std::max(proto.feature_group_count(), 1), + std::max(proto.batch_group_count(), 1), proto.window(), proto.convolution_dimension_numbers(), precision_config); break; } @@ -438,6 +446,9 @@ StatusOr> HloInstruction::CreateFromProto( static_cast(instruction.get()) ->set_feature_group_count( std::max(static_cast(proto.feature_group_count()), 1LL)); + static_cast(instruction.get()) + ->set_batch_group_count( + std::max(static_cast(proto.batch_group_count()), 1LL)); break; case HloOpcode::kPad: TF_RET_CHECK(proto.operand_ids_size() == 2) @@ -448,13 +459,43 @@ StatusOr> HloInstruction::CreateFromProto( CreatePad(shape, operands(0), operands(1), proto.padding_config()); break; case HloOpcode::kDynamicSlice: { - TF_RET_CHECK(proto.operand_ids_size() == 2) - << "DynamicSlice instruction should have 2 operands but sees " - << proto.operand_ids_size(); std::vector 
slice_sizes(proto.dynamic_slice_sizes_size()); absl::c_copy(proto.dynamic_slice_sizes(), slice_sizes.begin()); + TF_RET_CHECK(proto.operand_ids_size() >= 1) + << "DynamicSlice instruction should have at least 1 operands but " + "sees " + << proto.operand_ids_size(); + // TODO(b/118437727): Old form, make the check unconditional. + if (proto.operand_ids_size() != 2 || operands(1)->shape().rank() != 1) { + auto expected_operands = 1 + operands(0)->shape().rank(); + TF_RET_CHECK(proto.operand_ids_size() == expected_operands) + << "DynamicSlice instruction should have " << expected_operands + << " operands, but has " << proto.operand_ids_size(); + } + const auto& operand_vector = all_operands(); + instruction = CreateDynamicSlice( + shape, operands(0), absl::MakeSpan(operand_vector).subspan(1), + slice_sizes); + break; + } + case HloOpcode::kDynamicUpdateSlice: { + TF_RET_CHECK(proto.operand_ids_size() >= 2) + << "DynamicUpdateSlice instruction should have at least 2 operands " + "but sees " + << proto.operand_ids_size(); + // TODO(b/118437727): Old form, make the check unconditional. 
+ if (proto.operand_ids_size() != 3 || operands(2)->shape().rank() != 1) { + auto expected_operands = 2 + operands(0)->shape().rank(); + TF_RET_CHECK(proto.operand_ids_size() == expected_operands) + << "DynamicUpdateSlice instruction should have " + << expected_operands << " operands, but has " + << proto.operand_ids_size(); + } + const auto& operand_vector = all_operands(); instruction = - CreateDynamicSlice(shape, operands(0), operands(1), slice_sizes); + CreateDynamicUpdateSlice(shape, operands(0), operands(1), + absl::MakeSpan(operand_vector).subspan(2)); + break; } case HloOpcode::kGather: { @@ -569,6 +610,11 @@ StatusOr> HloInstruction::CreateFromProto( instruction->SetAndSanitizeName(proto.name()); instruction->metadata_ = proto.metadata(); instruction->backend_config_ = proto.backend_config(); + + TF_RET_CHECK(proto.id() >= 0) + << "Instruction with negative id: " << proto.id(); + TF_RET_CHECK(proto.id() <= INT_MAX) + << "Instruction with id > INT_MAX: " << proto.id(); instruction->unique_id_ = proto.id(); if (proto.has_sharding()) { @@ -619,7 +665,7 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, absl::Span operands) { if (opcode == HloOpcode::kCopy) { // It is impossible to copy an opaque shape, we don't know how big it is. 
- CHECK(!ShapeUtil::IsOpaque(shape)); + CHECK(!shape.IsOpaque()); } auto instruction = absl::WrapUnique(new HloInstruction(opcode, shape)); for (auto operand : operands) { @@ -729,12 +775,12 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, /* static */ std::unique_ptr HloInstruction::CreateConvolve( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, - int64 feature_group_count, const Window& window, + int64 feature_group_count, int64 batch_group_count, const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, const PrecisionConfig& precision_config) { return absl::make_unique( - shape, lhs, rhs, feature_group_count, window, dimension_numbers, - precision_config); + shape, lhs, rhs, feature_group_count, batch_group_count, window, + dimension_numbers, precision_config); } /* static */ std::unique_ptr HloInstruction::CreateFft( @@ -761,8 +807,7 @@ HloInstruction::CreateReducePrecision(const Shape& shape, shape, operand, exponent_bits, mantissa_bits); } -/* static */ std::unique_ptr -HloInstruction::CreateCrossReplicaSum( +/* static */ std::unique_ptr HloInstruction::CreateAllReduce( const Shape& shape, absl::Span operands, HloComputation* reduce_computation, const std::vector& replica_groups, absl::string_view barrier, @@ -787,6 +832,11 @@ HloInstruction::CreateCollectivePermute( shape, operand, source_target_pairs); } +/* static */ std::unique_ptr HloInstruction::CreateReplicaId() { + return absl::WrapUnique( + new HloInstruction(HloOpcode::kReplicaId, ShapeUtil::MakeShape(U32, {}))); +} + /* static */ std::unique_ptr HloInstruction::CreateInfeed( const Shape& infeed_shape, HloInstruction* token_operand, const string& config) { @@ -903,23 +953,19 @@ HloInstruction::CreateAddDependency(HloInstruction* data_operand, } /* static */ std::unique_ptr HloInstruction::CreateDynamicSlice( - const Shape& shape, HloInstruction* operand, HloInstruction* start_indices, + const Shape& shape, HloInstruction* operand, + absl::Span 
start_indices, absl::Span slice_sizes) { return absl::make_unique( shape, operand, start_indices, slice_sizes); } /* static */ std::unique_ptr -HloInstruction::CreateDynamicUpdateSlice(const Shape& shape, - HloInstruction* operand, - HloInstruction* update, - HloInstruction* start_indices) { - auto instruction = absl::WrapUnique( - new HloInstruction(HloOpcode::kDynamicUpdateSlice, shape)); - instruction->AppendOperand(operand); - instruction->AppendOperand(update); - instruction->AppendOperand(start_indices); - return instruction; +HloInstruction::CreateDynamicUpdateSlice( + const Shape& shape, HloInstruction* operand, HloInstruction* update, + absl::Span start_indices) { + return absl::make_unique( + shape, operand, update, start_indices); } /* static */ std::unique_ptr HloInstruction::CreateConcatenate( @@ -1035,7 +1081,7 @@ HloInstruction::CreateBroadcastSequence( const std::function)>& adder) { CHECK(ShapeUtil::IsScalar(operand->shape()) || - ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(output_shape)); + operand->shape().rank() == output_shape.rank()); Shape broadcast_shape = ShapeUtil::ChangeElementType( output_shape, operand->shape().element_type()); // Do explicit broadcast for scalar. @@ -1051,7 +1097,7 @@ HloInstruction::CreateBroadcastSequence( // Do explicit broadcast for degenerate broadcast. 
std::vector broadcast_dimensions; std::vector reshaped_dimensions; - for (int i = 0; i < ShapeUtil::Rank(operand->shape()); i++) { + for (int i = 0; i < operand->shape().rank(); i++) { if (operand->shape().dimensions(i) == output_shape.dimensions(i)) { broadcast_dimensions.push_back(i); reshaped_dimensions.push_back(operand->shape().dimensions(i)); @@ -1128,7 +1174,7 @@ HloInstruction::CreateBroadcastSequence( void HloInstruction::set_single_sharding(const HloSharding& sharding) { CHECK(!sharding.IsTuple()) << sharding; - if (ShapeUtil::IsTuple(shape())) { + if (shape().IsTuple()) { set_sharding(HloSharding::Tuple(sharding.GetAsShapeTree(shape()))); } else { set_sharding(sharding); @@ -1160,7 +1206,7 @@ bool HloInstruction::HasSideEffectNoRecurse() const { case HloOpcode::kOutfeed: case HloOpcode::kTrace: return true; - case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllReduce: return all_reduce_id().has_value(); default: return false; @@ -1283,7 +1329,7 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kParameter: case HloOpcode::kGetTupleElement: case HloOpcode::kReducePrecision: - case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllReduce: case HloOpcode::kAllToAll: case HloOpcode::kCollectivePermute: case HloOpcode::kInfeed: @@ -1378,9 +1424,8 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( clone = CreateReshape(shape, new_operands[0]); break; case HloOpcode::kDynamicUpdateSlice: - CHECK_EQ(new_operands.size(), 3); clone = CreateDynamicUpdateSlice(shape, new_operands[0], new_operands[1], - new_operands[2]); + new_operands.subspan(2)); break; case HloOpcode::kTuple: clone = CreateTuple(new_operands); @@ -1408,6 +1453,10 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( CHECK_EQ(new_operands.size(), 2); clone = CreateAddDependency(new_operands[0], new_operands[1]); break; + case HloOpcode::kReplicaId: + CHECK_EQ(new_operands.size(), 0); + clone = CreateReplicaId(); + break; } // SetupDerivedInstruction will setup 
the precision_config_ field. SetupDerivedInstruction(clone.get()); @@ -1542,12 +1591,10 @@ HloInstruction::InstructionVector HloInstruction::unique_operands() const { Status HloInstruction::AddControlDependencyTo(HloInstruction* instruction) { TF_RET_CHECK(instruction->parent() == parent()); - if (std::find(control_successors_.begin(), control_successors_.end(), - instruction) == control_successors_.end()) { + if (!absl::c_linear_search(control_successors_, instruction)) { control_successors_.push_back(instruction); - TF_RET_CHECK(std::find(instruction->control_predecessors_.begin(), - instruction->control_predecessors_.end(), - this) == instruction->control_predecessors_.end()); + TF_RET_CHECK( + !absl::c_linear_search(instruction->control_predecessors_, this)); instruction->control_predecessors_.push_back(this); } return Status::OK(); @@ -1679,6 +1726,7 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kReal: case HloOpcode::kRemainder: case HloOpcode::kReshape: + case HloOpcode::kReplicaId: case HloOpcode::kRoundNearestAfz: case HloOpcode::kSelect: case HloOpcode::kShiftLeft: @@ -1740,7 +1788,7 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kReducePrecision: case HloOpcode::kInfeed: case HloOpcode::kOutfeed: - case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllReduce: case HloOpcode::kAllToAll: case HloOpcode::kCollectivePermute: case HloOpcode::kConvolution: @@ -1760,7 +1808,12 @@ bool HloInstruction::IdenticalSlowPath( return false; } -uint64 HloInstruction::Hash() const { +static uint64 HashOperand(const HloInstruction* hlo) { + return ShapeUtil::Hash(hlo->shape()); +} + +uint64 HloInstruction::Hash( + const std::function& hash_operand) const { using tensorflow::Hash64Combine; uint64 hash_value = Hash64Combine(0, static_cast(opcode())); @@ -1769,7 +1822,7 @@ uint64 HloInstruction::Hash() const { if (!IsCrossModuleAllReduce()) { if (!operands().empty()) { for (size_t i = 0; i < operands().size(); ++i) { - hash_value = 
Hash64Combine(hash_value, operand(i)->Hash()); + hash_value = Hash64Combine(hash_value, hash_operand(operand(i))); } } } @@ -1778,6 +1831,11 @@ uint64 HloInstruction::Hash() const { return hash_value; } +uint64 HloInstruction::Hash() const { + // Use HashOperand as an argument to prevent non-termination. + return Hash(HashOperand); +} + uint64 HloInstruction::InnerHash() const { return 13; } void HloInstruction::RemoveUser(HloInstruction* user) { @@ -1786,7 +1844,7 @@ void HloInstruction::RemoveUser(HloInstruction* user) { user_set_.erase(set_it); // This is linear in the number of the users, but a vector provides a stable // iteration order and much faster traversal. - auto vec_it = std::find(users_.begin(), users_.end(), user); + auto vec_it = absl::c_find(users_, user); CHECK(vec_it != users_.end()); users_.erase(vec_it); } @@ -1804,8 +1862,7 @@ Status HloInstruction::ReplaceUseWith(HloInstruction* user, RemoveUser(user); - TF_RET_CHECK( - std::count(user->operands_.begin(), user->operands_.end(), this) >= 0); + TF_RET_CHECK(absl::c_count(user->operands_, this) >= 0); std::replace(user->operands_.begin(), user->operands_.end(), this, new_producer); new_producer->AddUser(user); @@ -1818,6 +1875,16 @@ Status HloInstruction::ReplaceUseWith(HloInstruction* user, Status HloInstruction::ReplaceOperandWith(int64 operand_num, HloInstruction* new_operand) { + auto old_operand = operand(operand_num); + TF_RET_CHECK(ShapeUtil::CompatibleIgnoringFpPrecision(old_operand->shape(), + new_operand->shape())) + << old_operand->shape() << " is not compatible with " + << new_operand->shape(); + return ReplaceOperandWithDifferentShape(operand_num, new_operand); +} + +Status HloInstruction::ReplaceOperandWithDifferentShape( + int64 operand_num, HloInstruction* new_operand) { TF_RET_CHECK(operand_num >= 0); TF_RET_CHECK(operand_num < operand_count()); HloInstruction* old_operand = mutable_operand(operand_num); @@ -1825,17 +1892,12 @@ Status HloInstruction::ReplaceOperandWith(int64 
operand_num, return Status::OK(); } - TF_RET_CHECK(ShapeUtil::CompatibleIgnoringFpPrecision(old_operand->shape(), - new_operand->shape())) - << old_operand->shape() << " is not compatible with " - << new_operand->shape(); operands_[operand_num] = new_operand; VLOG(3) << "Replacing operand " << operand_num << " of " << name() << " with " << new_operand->name() << ", was " << old_operand->name(); - if (std::find(operands_.begin(), operands_.end(), old_operand) == - operands_.end()) { + if (!absl::c_linear_search(operands_, old_operand)) { old_operand->RemoveUser(this); } new_operand->AddUser(this); @@ -1843,6 +1905,14 @@ Status HloInstruction::ReplaceOperandWith(int64 operand_num, } Status HloInstruction::ReplaceAllUsesWith(HloInstruction* new_producer) { + TF_RET_CHECK( + ShapeUtil::CompatibleIgnoringFpPrecision(shape(), new_producer->shape())) + << shape() << " is not compatible with " << new_producer->shape(); + return ReplaceAllUsesWithDifferentShape(new_producer); +} + +Status HloInstruction::ReplaceAllUsesWithDifferentShape( + HloInstruction* new_producer) { bool new_producer_is_user = false; for (HloInstruction* user : users()) { if (user == new_producer) { @@ -1867,7 +1937,8 @@ Status HloInstruction::ReplaceAllUsesWith(HloInstruction* new_producer) { AddUser(new_producer); } if (parent_ && parent_->root_instruction() == this) { - parent_->set_root_instruction(new_producer); + parent_->set_root_instruction(new_producer, + /*accept_different_shape=*/true); } return Status::OK(); @@ -1879,7 +1950,7 @@ HloComputation* HloInstruction::to_apply() const { case HloOpcode::kMap: case HloOpcode::kReduceWindow: case HloOpcode::kReduce: - case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllReduce: case HloOpcode::kScatter: CHECK_EQ(called_computations_.size(), 1); return called_computations_[0]; @@ -1898,7 +1969,7 @@ void HloInstruction::set_to_apply(HloComputation* computation) { case HloOpcode::kMap: case HloOpcode::kReduceWindow: case HloOpcode::kReduce: - case 
HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllReduce: case HloOpcode::kScatter: CHECK_EQ(called_computations_.size(), 1); called_computations_[0] = computation; @@ -2056,7 +2127,11 @@ bool HloInstruction::IsElementwiseImpl( } bool HloInstruction::IsCrossModuleAllReduce() const { - return opcode() == HloOpcode::kCrossReplicaSum && all_reduce_id(); + return opcode() == HloOpcode::kAllReduce && all_reduce_id(); +} + +bool HloInstruction::IsCrossReplicaAllReduce() const { + return opcode() == HloOpcode::kAllReduce && !all_reduce_id(); } string HloInstruction::ToStringWithCanonicalNameMap( @@ -2167,7 +2242,7 @@ std::vector HloInstruction::ExtraAttributesToString( } else if (opcode() == HloOpcode::kCall || opcode() == HloOpcode::kMap || opcode() == HloOpcode::kReduceWindow || opcode() == HloOpcode::kReduce || - opcode() == HloOpcode::kCrossReplicaSum || + opcode() == HloOpcode::kAllReduce || opcode() == HloOpcode::kScatter) { extra.push_back( StrCat("to_apply=", PrintName(to_apply()->name(), options))); @@ -2203,7 +2278,7 @@ std::vector HloInstruction::ExtraAttributesToString( case HloOpcode::kMap: case HloOpcode::kReduceWindow: case HloOpcode::kReduce: - case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllReduce: case HloOpcode::kScatter: extra.push_back( StrCat("to_apply=\n", to_apply()->ToString(new_options))); @@ -2400,12 +2475,14 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleConvolution(this); case HloOpcode::kFft: return visitor->HandleFft(this); - case HloOpcode::kCrossReplicaSum: - return visitor->HandleCrossReplicaSum(this); + case HloOpcode::kAllReduce: + return visitor->HandleAllReduce(this); case HloOpcode::kAllToAll: return visitor->HandleAllToAll(this); case HloOpcode::kCollectivePermute: return visitor->HandleCollectivePermute(this); + case HloOpcode::kReplicaId: + return visitor->HandleReplicaId(this); case HloOpcode::kTuple: return visitor->HandleTuple(this); case HloOpcode::kMap: @@ -2806,7 +2883,7 @@ 
HloInstruction::UseKind HloInstruction::OperandElementUse(int64 i) const { } return UseKind::kReuse; case HloOpcode::kDynamicUpdateSlice: - // Dynamic-update-slice reuses only operand 2 (start_indices). + // Dynamic-update-slice reuses only start_indices. if (i == 0 || i == 1) { return UseKind::kUse; } @@ -2859,10 +2936,10 @@ StatusOr StringToFusionKind( string PaddingConfigToString(const PaddingConfig& padding) { bool has_interior_padding = - std::any_of(padding.dimensions().begin(), padding.dimensions().end(), - [](const PaddingConfig::PaddingConfigDimension& dim) { - return dim.interior_padding() != 0; - }); + absl::c_any_of(padding.dimensions(), + [](const PaddingConfig::PaddingConfigDimension& dim) { + return dim.interior_padding() != 0; + }); return StrJoin( padding.dimensions(), "x", [&](string* out, const PaddingConfig::PaddingConfigDimension& dim) { @@ -3256,13 +3333,12 @@ HloInstruction::source_target_pairs() const { return Cast(this)->source_target_pairs(); } -string HloInstruction::cross_replica_sum_barrier() const { - return Cast(this)->cross_replica_sum_barrier(); +string HloInstruction::all_reduce_barrier() const { + return Cast(this)->all_reduce_barrier(); } -void HloInstruction::set_cross_replica_sum_barrier(const string& barrier) { - return Cast(this)->set_cross_replica_sum_barrier( - barrier); +void HloInstruction::set_all_reduce_barrier(const string& barrier) { + return Cast(this)->set_all_reduce_barrier(barrier); } absl::optional HloInstruction::all_reduce_id() const { @@ -3308,6 +3384,18 @@ void HloInstruction::set_feature_group_count(int64 feature_group_count) { feature_group_count); } +int64 HloInstruction::batch_group_count() const { + if (auto convolution = DynCast(this)) { + return convolution->batch_group_count(); + } + return Cast(this)->batch_group_count(); +} + +void HloInstruction::set_batch_group_count(int64 batch_group_count) { + Cast(this)->set_batch_group_count( + batch_group_count); +} + HloComputation* HloInstruction::select() 
const { return Cast(this)->select(); } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index a54716217d6bbc5c0601f5d9ff7bf4072a6b30f5..c11d29d33e918a363a7df5c4ec4e53dbf407e71e 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -426,7 +426,7 @@ class HloInstruction { // and window describes how the filter is applied to lhs. static std::unique_ptr CreateConvolve( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, - int64 feature_group_count, const Window& window, + int64 feature_group_count, int64 batch_group_count, const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, const PrecisionConfig& precision_config); @@ -462,9 +462,7 @@ class HloInstruction { // `all_reduce_id`: for Allreduce nodes from different modules, if they have // the same all_reduce_id, they will be 'Allreduce'd. If empty, Allreduce will // not be applied cross modules. - // - // TODO(b/117564385): Rename this to AllReduce. - static std::unique_ptr CreateCrossReplicaSum( + static std::unique_ptr CreateAllReduce( const Shape& shape, absl::Span operands, HloComputation* reduce_computation, const std::vector& replica_groups, @@ -496,6 +494,9 @@ class HloInstruction { const Shape& shape, HloInstruction* operand, const std::vector>& source_target_pairs); + // Creates an instruction that returns a U32 replica ID. + static std::unique_ptr CreateReplicaId(); + // Creates a conversion instruction, where operand is the data to convert and // shape is the target shape for the conversion. static std::unique_ptr CreateConvert(const Shape& shape, @@ -560,13 +561,14 @@ class HloInstruction { // 'slice_sizes'. 
static std::unique_ptr CreateDynamicSlice( const Shape& shape, HloInstruction* operand, - HloInstruction* start_indices, absl::Span slice_sizes); + absl::Span start_indices, + absl::Span slice_sizes); // Creates a dynamic update slice instruction, which updates a slice // of 'operand' with 'update' and 'start_indices'. static std::unique_ptr CreateDynamicUpdateSlice( const Shape& shape, HloInstruction* operand, HloInstruction* update, - HloInstruction* start_indices); + absl::Span start_indices); // Creates a concatenate instruction, where the operands are concatenated on // the provided dimension. @@ -909,6 +911,14 @@ class HloInstruction { // information on opcode, shape, operands, and typically a root instruction. // This function returns the same hash value for equivalent HLO instructions, // with respect to HloInstruction::Identical() method. + // + // Uses hash_operand function to compute hash values of its operands. + // At the very top level, hash_operand should be non-recursive to prevent + // non-termination. + uint64 Hash( + const std::function& hash_operand) const; + + // Calls the above method with non-recursive hash_operand function. uint64 Hash() const; // Returns whether the instruction has a constant operand. @@ -922,11 +932,16 @@ class HloInstruction { // operands of it which could be created due to this replacement. Status ReplaceUseWith(HloInstruction* user, HloInstruction* new_producer); - // Replaces the specified operand with new_operand. + // Replaces the specified operand with new_operand. The old and new operands + // must have compatible shapes ignoring floating-point precision. // // This function does NOT remove duplicated operands even if this instruction // is a fusion, so that the existing operand numbers do not change. 
- Status ReplaceOperandWith(int64 operand_no, HloInstruction* new_operand); + Status ReplaceOperandWith(int64 operand_num, HloInstruction* new_operand); + + // Same as ReplaceOperandWith(), but new_operand can have a different shape. + Status ReplaceOperandWithDifferentShape(int64 operand_num, + HloInstruction* new_operand); // Replaces all uses of this instruction with the new producer. If // new_producer is a user of this instruction then new_producer remains a use @@ -935,10 +950,16 @@ class HloInstruction { // If this instruction is the root of its computation, sets the computation's // root to new_producer. // + // The new producer must have a compatible shape ignoring floating-point + // precision. + // // If a user is a fusion instruction, this function will remove any duplicated // operands of it which could be created due to this replacement. Status ReplaceAllUsesWith(HloInstruction* new_producer); + // Same as ReplaceAllUsesWith, but new_producer can have a different shape. + Status ReplaceAllUsesWithDifferentShape(HloInstruction* new_producer); + // Performs a postorder DFS visit using this node as the root. If // call_finish_visit is true, then DfsHloVisitor::FinishVisit is called when // complete. If ignore_control_predecessors is true, instructions only @@ -1174,9 +1195,12 @@ class HloInstruction { // Returns true if this instruction is elementwise on all its operands. bool IsElementwise() const; - // Returns true if this is an cross module all-reduce instrucion. + // Returns true if this is a cross module all-reduce instruction. bool IsCrossModuleAllReduce() const; + // Returns true if this is a cross-replica all-reduce instruction. + bool IsCrossReplicaAllReduce() const; + // Returns true if this elementwise instruction implicitly broadcasts operand // `operand_idx`. // @@ -1448,9 +1472,9 @@ class HloInstruction { // Delegates to HloCollectivePermuteInstruction::source_target_pairs. 
const std::vector>& source_target_pairs() const; - // Delegates to HloAllReduceInstruction::cross_replica_sum_barrier. - string cross_replica_sum_barrier() const; - void set_cross_replica_sum_barrier(const string& barrier); + // Delegates to HloAllReduceInstruction::all_reduce_barrier. + string all_reduce_barrier() const; + void set_all_reduce_barrier(const string& barrier); // Delegates to HloAllReduceInstruction::all_reduce_id. absl::optional all_reduce_id() const; @@ -1484,6 +1508,11 @@ class HloInstruction { void set_feature_group_count(int64 feature_group_count); + // The number of batch groups. Must be a divisor of the input batch dimension + int64 batch_group_count() const; + + void set_batch_group_count(int64 batch_group_count); + // Delegates to HloSelectAndScatterInstruction::select. HloComputation* select() const; diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index 8048e332cb57747286758b75773b29ba154aa888..35f031f29a7aca8db7ebe2fbcfdcebb7a778d703 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -20,6 +20,7 @@ limitations under the License. 
#include #include +#include "absl/container/flat_hash_map.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" @@ -55,13 +56,13 @@ class OpAndUserCollectingVisitor : public DfsHloVisitorWithDefault { } Status HandleParameter(HloInstruction* parameter) override { - EXPECT_EQ(0, count_.count(parameter)); + EXPECT_FALSE(count_.contains(parameter)); count_[parameter] = GetCountsForNode(parameter); return Status::OK(); } Status HandleConstant(HloInstruction* constant) override { - EXPECT_EQ(0, count_.count(constant)); + EXPECT_FALSE(count_.contains(constant)); count_[constant] = GetCountsForNode(constant); return Status::OK(); } @@ -69,25 +70,25 @@ class OpAndUserCollectingVisitor : public DfsHloVisitorWithDefault { Status HandleAdd(HloInstruction* add) override { auto lhs = add->operand(0); auto rhs = add->operand(1); - EXPECT_EQ(0, count_.count(add)); - EXPECT_GT(count_.count(lhs), 0); - EXPECT_GT(count_.count(rhs), 0); + EXPECT_FALSE(count_.contains(add)); + EXPECT_TRUE(count_.contains(lhs)); + EXPECT_TRUE(count_.contains(rhs)); count_[add] = GetCountsForNode(add); return Status::OK(); } Status HandleNegate(HloInstruction* negate) override { auto operand = negate->operand(0); - EXPECT_EQ(0, count_.count(negate)); - EXPECT_GT(count_.count(operand), 0); + EXPECT_FALSE(count_.contains(negate)); + EXPECT_TRUE(count_.contains(operand)); count_[negate] = GetCountsForNode(negate); return Status::OK(); } Status HandleMap(HloInstruction* map) override { - EXPECT_EQ(0, count_.count(map)); + EXPECT_FALSE(count_.contains(map)); for (HloInstruction* arg : map->operands()) { - EXPECT_GT(count_.count(arg), 0); + EXPECT_TRUE(count_.contains(arg)); } count_[map] = GetCountsForNode(map); return Status::OK(); @@ -96,9 +97,9 @@ class OpAndUserCollectingVisitor : public DfsHloVisitorWithDefault { Status HandleReduce(HloInstruction* reduce) override { auto arg = 
reduce->operand(0); auto init_value = reduce->operand(1); - EXPECT_EQ(0, count_.count(reduce)); - EXPECT_GT(count_.count(arg), 0); - EXPECT_GT(count_.count(init_value), 0); + EXPECT_FALSE(count_.contains(reduce)); + EXPECT_TRUE(count_.contains(arg)); + EXPECT_TRUE(count_.contains(init_value)); count_[reduce] = GetCountsForNode(reduce); return Status::OK(); } @@ -128,7 +129,7 @@ class OpAndUserCollectingVisitor : public DfsHloVisitorWithDefault { } // Counters for HLOs. Maps HLO to a NumOpsAndUsers. - std::unordered_map count_; + absl::flat_hash_map count_; }; TEST_F(HloInstructionTest, BasicProperties) { @@ -137,7 +138,7 @@ TEST_F(HloInstructionTest, BasicProperties) { EXPECT_EQ(HloOpcode::kParameter, parameter->opcode()); EXPECT_TRUE(ShapeUtil::IsScalarWithElementType(parameter->shape(), F32)); EXPECT_FALSE(ShapeUtil::IsScalarWithElementType(parameter->shape(), S32)); - EXPECT_EQ(0, parameter->operand_count()); + EXPECT_FALSE(parameter->operand_count()); } TEST_F(HloInstructionTest, UserWithTwoOperands) { @@ -981,9 +982,9 @@ TEST_F(HloInstructionTest, FunctionVisitor) { module->AddEntryComputation(builder.Build()); int visit_num = 0; - std::unordered_map visit_order; + absl::flat_hash_map visit_order; EXPECT_IS_OK(add->Accept([&visit_num, &visit_order](HloInstruction* inst) { - EXPECT_EQ(0, visit_order.count(inst)); + EXPECT_FALSE(visit_order.contains(inst)); visit_order[inst] = visit_num; visit_num++; return Status::OK(); diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 1ea02cf9c03866a598bec0e5356f0eb31ad27755..785206bf7753abaef5788365fe10217b8b74ccc6 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -42,11 +42,9 @@ using absl::StrJoin; bool IsInstructionElementwiseOnOperand(const HloInstruction* instruction, const HloInstruction* operand) { std::vector operand_indices = instruction->OperandIndices(operand); - 
return std::all_of( - operand_indices.begin(), operand_indices.end(), - [instruction](int64 operand_index) { - return instruction->IsElementwiseOnOperand(operand_index); - }); + return absl::c_all_of(operand_indices, [instruction](int64 operand_index) { + return instruction->IsElementwiseOnOperand(operand_index); + }); } string PrecisionConfigToString(const PrecisionConfig& precision_config) { @@ -363,9 +361,9 @@ HloAllReduceInstruction::HloAllReduceInstruction( HloComputation* reduce_computation, const std::vector& replica_groups, absl::string_view barrier, const absl::optional& all_reduce_id) - : HloCollectiveInstruction(HloOpcode::kCrossReplicaSum, shape, operands, + : HloCollectiveInstruction(HloOpcode::kAllReduce, shape, operands, replica_groups), - cross_replica_sum_barrier_(barrier), + all_reduce_barrier_(barrier), all_reduce_id_(all_reduce_id) { AppendComputation(reduce_computation); } @@ -381,16 +379,25 @@ HloInstructionProto HloAllReduceInstruction::ToProto() const { if (all_reduce_id_) { proto.set_all_reduce_id(*all_reduce_id_); } - proto.set_cross_replica_sum_barrier(cross_replica_sum_barrier_); + proto.set_all_reduce_barrier(all_reduce_barrier_); return proto; } +bool HloAllReduceInstruction::IsNoop() const { + for (auto replica_group : replica_groups()) { + if (replica_group.replica_ids().size() != 1) { + return false; + } + } + return !all_reduce_id(); +} + std::vector HloAllReduceInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { std::vector result = HloCollectiveInstruction::ExtraAttributesToStringImpl(options); - if (!cross_replica_sum_barrier().empty()) { - result.push_back(StrCat("barrier=\"", cross_replica_sum_barrier(), "\"")); + if (!all_reduce_barrier().empty()) { + result.push_back(StrCat("barrier=\"", all_reduce_barrier(), "\"")); } if (all_reduce_id_) { result.push_back(StrCat("all_reduce_id=", *all_reduce_id_)); @@ -405,8 +412,7 @@ bool HloAllReduceInstruction::IdenticalSlowPath( const auto& casted_other = 
static_cast(other); return HloCollectiveInstruction::IdenticalSlowPath(other, eq_computations) && eq_computations(to_apply(), casted_other.to_apply()) && - cross_replica_sum_barrier() == - casted_other.cross_replica_sum_barrier() && + all_reduce_barrier() == casted_other.all_reduce_barrier() && all_reduce_id() == casted_other.all_reduce_id(); } @@ -415,8 +421,8 @@ HloAllReduceInstruction::CloneWithNewOperandsImpl( const Shape& shape, absl::Span new_operands, HloCloneContext* /*context*/) const { return absl::make_unique( - shape, new_operands, to_apply(), replica_groups(), - cross_replica_sum_barrier(), all_reduce_id()); + shape, new_operands, to_apply(), replica_groups(), all_reduce_barrier(), + all_reduce_id()); } HloAllToAllInstruction::HloAllToAllInstruction( @@ -735,7 +741,7 @@ HloMapInstruction::HloMapInstruction(const Shape& shape, AppendComputation(map_computation); // TODO(b/65689298) Remove code below once Map is generalized to accept // arbitrary map dimensions. - dimensions_.resize(ShapeUtil::Rank(shape)); + dimensions_.resize(shape.rank()); std::iota(dimensions_.begin(), dimensions_.end(), 0); } @@ -815,8 +821,7 @@ std::vector HloSliceInstruction::ExtraAttributesToStringImpl( std::vector bounds; bounds.reserve(slice_starts_.size()); const bool omit_stride = - std::all_of(slice_strides_.begin(), slice_strides_.end(), - [](int64 stride) { return stride == 1; }); + absl::c_all_of(slice_strides_, [](int64 stride) { return stride == 1; }); for (int i = 0; i < slice_starts_.size(); ++i) { string stride_str = omit_stride ? 
"" : StrCat(":", slice_strides_[i]); bounds.push_back( @@ -867,7 +872,7 @@ void HloConstantInstruction::RelayoutConstant(const Layout& new_layout, const ShapeIndex& shape_index) { Shape* mutable_array_subshape = ShapeUtil::GetMutableSubshape(mutable_shape(), shape_index); - CHECK(ShapeUtil::IsArray(*mutable_array_subshape)); + CHECK(mutable_array_subshape->IsArray()); // Normally array_subshape will always have a layout, but this invariant is // temporarily broken in LayoutAssignment::AssignLayouts. @@ -901,11 +906,11 @@ string HloConstantInstruction::OperandsToStringWithCanonicalNameMap( string operands; // For constants, show the actual value in place of an empty operand list. if (literal_.has_value() && - ((ShapeUtil::IsArray(shape()) && ShapeUtil::ElementsIn(shape()) <= 10) || + ((shape().IsArray() && ShapeUtil::ElementsIn(shape()) <= 10) || options.print_large_constants())) { // Literal::ToString emits multidimensional arrays over multiple // lines. Compact this into one line by stripping out white space. - string tmp = literal().ToString(); + string tmp = literal().ToStringWithoutShape(); std::replace(tmp.begin(), tmp.end(), '\n', ' '); std::vector v = absl::StrSplit(tmp, ' '); bool first = true; @@ -1052,8 +1057,7 @@ HloInstruction* HloFusionInstruction::AddFusionOperand( void HloFusionInstruction::MergeFusionInstruction( HloFusionInstruction* instruction_to_merge) { - CHECK(std::find(operands().begin(), operands().end(), instruction_to_merge) != - operands().end()); + CHECK(absl::c_linear_search(operands(), instruction_to_merge)); // Clone the instruction from which to merge fused instructions. std::unique_ptr cloned = instruction_to_merge->Clone(); HloFusionInstruction* cloned_fusion = @@ -1220,8 +1224,8 @@ HloInstruction* HloFusionInstruction::CloneAndFuseInternal( // corresponding fused parameter instruction. Renumber parameters as // necessary to make parameter numbers consistent with their index in the // fused_parameter_ vector. 
- bool in_operand_list = std::find(operands().begin(), operands().end(), - instruction_to_fuse) != operands().end(); + bool in_operand_list = + absl::c_linear_search(operands(), instruction_to_fuse); CHECK(add_output || in_operand_list); if (instruction_to_fuse->opcode() == HloOpcode::kTuple) { // We assume all uses of a kTuple operation are GTE ops, not another @@ -1325,7 +1329,7 @@ HloInstruction* HloFusionInstruction::CloneAndFuseInternal( if (newly_created_tuple_instr) { HloInstruction* new_instr = parent()->AddInstruction( HloInstruction::CreateGetTupleElement(fused_root->shape(), this, 0)); - TF_CHECK_OK(ReplaceAllUsesWith(new_instr)); + TF_CHECK_OK(ReplaceAllUsesWithDifferentShape(new_instr)); } int64 index = tuple_elements.size(); if (instruction_to_fuse->opcode() == HloOpcode::kTuple) { @@ -1372,8 +1376,14 @@ bool HloFusionInstruction::IdenticalSlowPath( other.fused_instructions_computation()); } +static uint64 HashOperandRecursive(const HloInstruction* hlo) { + return hlo->Hash(HashOperandRecursive); +} + uint64 HloFusionInstruction::InnerHash() const { - return fused_instructions_computation()->Hash(); + // Use HashOperandRecursive to recursively compute hash on inner operands. 
+ return fused_instructions_computation()->root_instruction()->Hash( + HashOperandRecursive); } std::unique_ptr HloFusionInstruction::CloneWithNewOperandsImpl( @@ -1649,11 +1659,12 @@ std::unique_ptr HloOutfeedInstruction::CloneWithNewOperandsImpl( HloConvolutionInstruction::HloConvolutionInstruction( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, - int64 feature_group_count, const Window& window, + int64 feature_group_count, int64 batch_group_count, const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, const PrecisionConfig& precision_config) : HloInstruction(HloOpcode::kConvolution, shape), feature_group_count_(feature_group_count), + batch_group_count_(batch_group_count), window_(window), convolution_dimension_numbers_(dimension_numbers), precision_config_(precision_config) { @@ -1684,6 +1695,7 @@ HloInstructionProto HloConvolutionInstruction::ToProto() const { *proto.mutable_convolution_dimension_numbers() = convolution_dimension_numbers_; proto.set_feature_group_count(feature_group_count_); + proto.set_batch_group_count(batch_group_count_); *proto.mutable_precision_config() = precision_config_; return proto; } @@ -1700,6 +1712,10 @@ std::vector HloConvolutionInstruction::ExtraAttributesToStringImpl( extra.push_back(StrCat("feature_group_count=", feature_group_count_)); } + if (batch_group_count_ != 1) { + extra.push_back(StrCat("batch_group_count=", batch_group_count_)); + } + string precision_config_string = PrecisionConfigToString(precision_config_); if (!precision_config_string.empty()) { extra.push_back(precision_config_string); @@ -1717,6 +1733,9 @@ bool HloConvolutionInstruction::IdenticalSlowPath( if (feature_group_count_ != other.feature_group_count()) { return false; } + if (batch_group_count_ != other.batch_group_count()) { + return false; + } return protobuf_util::ProtobufEquals(window(), casted_other.window()) && protobuf_util::ProtobufEquals( convolution_dimension_numbers(), @@ -1731,8 +1750,9 @@ 
HloConvolutionInstruction::CloneWithNewOperandsImpl( HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 2); return absl::make_unique( - shape, new_operands[0], new_operands[1], feature_group_count_, window(), - convolution_dimension_numbers_, precision_config_); + shape, new_operands[0], new_operands[1], feature_group_count_, + batch_group_count_, window(), convolution_dimension_numbers_, + precision_config_); } HloReduceWindowInstruction::HloReduceWindowInstruction( @@ -1834,6 +1854,7 @@ HloCustomCallInstruction::HloCustomCallInstruction( custom_call_target_(custom_call_target.begin(), custom_call_target.end()), opaque_(opaque.begin(), opaque.end()), feature_group_count_(1), + batch_group_count_(1), layout_constrained_(false) { for (auto operand : operands) { AppendOperand(operand); @@ -1848,6 +1869,7 @@ HloCustomCallInstruction::HloCustomCallInstruction( custom_call_target_(custom_call_target.begin(), custom_call_target.end()), opaque_(opaque.begin(), opaque.end()), feature_group_count_(1), + batch_group_count_(1), layout_constrained_(true), operand_shapes_with_layout_(operand_shapes_with_layout.begin(), operand_shapes_with_layout.end()) { @@ -1868,6 +1890,7 @@ HloInstructionProto HloCustomCallInstruction::ToProto() const { proto.set_custom_call_target(custom_call_target_); proto.set_custom_call_opaque(opaque_); proto.set_feature_group_count(feature_group_count_); + proto.set_batch_group_count(batch_group_count_); if (layout_constrained()) { proto.set_constrain_layout(true); for (const Shape& shape : operand_shapes_with_layout_) { @@ -1891,6 +1914,9 @@ std::vector HloCustomCallInstruction::ExtraAttributesToStringImpl( if (feature_group_count_ != 1) { extra.push_back(StrCat("feature_group_count=", feature_group_count_)); } + if (batch_group_count_ != 1) { + extra.push_back(StrCat("batch_group_count=", batch_group_count_)); + } // By contract, we print the custom call target even if // options.print_subcomputation_mode() == kOff, because the call 
target is not // an HloComputation. @@ -1934,6 +1960,9 @@ bool HloCustomCallInstruction::IdenticalSlowPath( if (feature_group_count_ != casted_other.feature_group_count_) { return false; } + if (batch_group_count_ != casted_other.batch_group_count_) { + return false; + } return custom_call_target_ == casted_other.custom_call_target_ && opaque_ == casted_other.opaque_; } @@ -1951,6 +1980,7 @@ HloCustomCallInstruction::CloneWithNewOperandsImpl( cloned->set_convolution_dimension_numbers(*convolution_dimension_numbers_); } cloned->set_feature_group_count(feature_group_count_); + cloned->set_batch_group_count(batch_group_count_); return std::move(cloned); } @@ -1994,12 +2024,44 @@ std::unique_ptr HloPadInstruction::CloneWithNewOperandsImpl( HloDynamicSliceInstruction::HloDynamicSliceInstruction( const Shape& shape, HloInstruction* operand, HloInstruction* start_indices, absl::Span slice_sizes) - : HloInstruction(HloOpcode::kDynamicSlice, shape), + : HloDynamicIndexInstruction(HloOpcode::kDynamicSlice, shape), + dynamic_slice_sizes_(slice_sizes.begin(), slice_sizes.end()) { + AppendOperand(operand); + AppendOperand(start_indices); +} + +HloDynamicSliceInstruction::HloDynamicSliceInstruction( + const Shape& shape, HloInstruction* operand, + absl::Span start_indices, + absl::Span slice_sizes) + : HloDynamicIndexInstruction(HloOpcode::kDynamicSlice, shape), dynamic_slice_sizes_(slice_sizes.begin(), slice_sizes.end()) { AppendOperand(operand); + for (HloInstruction* index : start_indices) { + AppendOperand(index); + } +} + +HloDynamicUpdateSliceInstruction::HloDynamicUpdateSliceInstruction( + const Shape& shape, HloInstruction* operand, HloInstruction* update, + HloInstruction* start_indices) + : HloDynamicIndexInstruction(HloOpcode::kDynamicUpdateSlice, shape) { + AppendOperand(operand); + AppendOperand(update); AppendOperand(start_indices); } +HloDynamicUpdateSliceInstruction::HloDynamicUpdateSliceInstruction( + const Shape& shape, HloInstruction* operand, HloInstruction* 
update, + absl::Span start_indices) + : HloDynamicIndexInstruction(HloOpcode::kDynamicUpdateSlice, shape) { + AppendOperand(operand); + AppendOperand(update); + for (HloInstruction* index : start_indices) { + AppendOperand(index); + } +} + HloInstructionProto HloDynamicSliceInstruction::ToProto() const { HloInstructionProto proto = HloInstruction::ToProto(); for (int64 slice_size : dynamic_slice_sizes_) { @@ -2025,9 +2087,14 @@ std::unique_ptr HloDynamicSliceInstruction::CloneWithNewOperandsImpl( const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { - CHECK_EQ(new_operands.size(), 2); - return absl::make_unique( - shape, new_operands[0], new_operands[1], dynamic_slice_sizes_); + if (new_operands.size() == 2 && new_operands[1]->shape().rank() == 1) { + // TODO(b/118437727): Old form, remove this path. + return absl::make_unique( + shape, new_operands[0], new_operands[1], dynamic_slice_sizes_); + } else { + return absl::make_unique( + shape, new_operands[0], new_operands.subspan(1), dynamic_slice_sizes_); + } } HloGatherInstruction::HloGatherInstruction( diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index b5c28137a145667a977d39c9d3c40c6d36a8436e..1b4a94753cda8aba8d50836b9d51b7c3fd5807f6 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -242,14 +242,10 @@ class HloAllReduceInstruction : public HloCollectiveInstruction { const std::vector& replica_groups, absl::string_view barrier, const absl::optional& all_reduce_id); - // Returns the barrier config used for the CrossReplicaSum implementation of + // Returns the barrier config used for the AllReduce implementation of // each backend. 
- string cross_replica_sum_barrier() const { - return cross_replica_sum_barrier_; - } - void set_cross_replica_sum_barrier(string barrier) { - cross_replica_sum_barrier_ = barrier; - } + string all_reduce_barrier() const { return all_reduce_barrier_; } + void set_all_reduce_barrier(string barrier) { all_reduce_barrier_ = barrier; } absl::optional all_reduce_id() const { return all_reduce_id_; } void set_all_reduce_id(const absl::optional& all_reduce_id); @@ -257,6 +253,10 @@ class HloAllReduceInstruction : public HloCollectiveInstruction { // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; + // Returns true if the AllReduce does no communication, so it's equivalent + // to a mem copy. + bool IsNoop() const; + private: std::vector ExtraAttributesToStringImpl( const HloPrintOptions& options) const override; @@ -270,8 +270,8 @@ class HloAllReduceInstruction : public HloCollectiveInstruction { const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; - // The string representation of the barrier config used for CrossReplicaSum. - string cross_replica_sum_barrier_; + // The string representation of the barrier config used for AllReduce. + string all_reduce_barrier_; // For Allreduce nodes from different modules, if they have the same // all_reduce_id, they will be 'Allreduce'd. If empty, Allreduce will not be @@ -933,7 +933,7 @@ class HloConvolutionInstruction : public HloInstruction { public: explicit HloConvolutionInstruction( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, - int64 feature_group_count, const Window& window, + int64 feature_group_count, int64 batch_group_count, const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, const PrecisionConfig& precision_config); const Window& window() const override { return window_; } @@ -949,6 +949,10 @@ class HloConvolutionInstruction : public HloInstruction { // dimension and output feature dimension. 
int64 feature_group_count() const { return feature_group_count_; } + // The number of feature groups. Must be a divisor of the input batch + // dimension. + int64 batch_group_count() const { return batch_group_count_; } + // Returns the information used to tell the implementation information about // what sort of precision is requested. The meaning of the field is backend // specific. At the moment, it is only supported for kConvolution and kDot. @@ -977,6 +981,9 @@ class HloConvolutionInstruction : public HloInstruction { // The number of feature groups. Must be a divisor of the input feature // dimension and output feature dimension. int64 feature_group_count_; + // The number of feature groups. Must be a divisor of the input batch + // dimension. + int64 batch_group_count_; // Describes the window used for a convolution. Window window_; // Describes the dimension numbers used for a convolution. @@ -1099,7 +1106,11 @@ class HloCustomCallInstruction : public HloInstruction { void set_feature_group_count(int64 feature_group_count) { feature_group_count_ = feature_group_count; } + void set_batch_group_count(int64 batch_group_count) { + batch_group_count_ = batch_group_count; + } int64 feature_group_count() const { return feature_group_count_; } + int64 batch_group_count() const { return batch_group_count_; } // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; @@ -1134,6 +1145,7 @@ class HloCustomCallInstruction : public HloInstruction { std::unique_ptr convolution_dimension_numbers_; // The number of feature groups. This is used for grouped convolutions. int64 feature_group_count_; + int64 batch_group_count_; // Whether the result and operand layouts are constrained. 
bool layout_constrained_; // For layout-constrained custom calls, this vector holds the shape with @@ -1171,12 +1183,38 @@ class HloPadInstruction : public HloInstruction { PaddingConfig padding_config_; }; -class HloDynamicSliceInstruction : public HloInstruction { +class HloDynamicIndexInstruction : public HloInstruction { + public: + explicit HloDynamicIndexInstruction(HloOpcode opcode, const Shape& shape) + : HloInstruction(opcode, shape) {} + virtual int64 first_index_operand_number() const = 0; + + // Returns a subspan of operands which represent the start indices. + absl::Span index_operands() const { + return absl::MakeSpan(operands()).subspan(first_index_operand_number()); + } + + // Returns the shapes of the index operands. + std::vector index_shapes() const { + std::vector shapes; + auto indices = index_operands(); + for (const HloInstruction* index : indices) { + shapes.push_back(index->shape()); + } + return shapes; + } +}; + +class HloDynamicSliceInstruction : public HloDynamicIndexInstruction { public: explicit HloDynamicSliceInstruction(const Shape& shape, HloInstruction* operand, HloInstruction* start_indices, absl::Span slice_sizes); + explicit HloDynamicSliceInstruction( + const Shape& shape, HloInstruction* operand, + absl::Span start_indices, + absl::Span slice_sizes); // Old methods kept for smooth subclassing transition END. // Returns the size of the slice in the given dimension for a dynamic // slice node. @@ -1189,6 +1227,8 @@ class HloDynamicSliceInstruction : public HloInstruction { // Returns a serialized representation of this instruction. 
HloInstructionProto ToProto() const override; + int64 first_index_operand_number() const override { return 1; } + private: std::vector ExtraAttributesToStringImpl( const HloPrintOptions& options) const override; @@ -1206,6 +1246,19 @@ class HloDynamicSliceInstruction : public HloInstruction { std::vector dynamic_slice_sizes_; }; +class HloDynamicUpdateSliceInstruction : public HloDynamicIndexInstruction { + public: + explicit HloDynamicUpdateSliceInstruction(const Shape& shape, + HloInstruction* operand, + HloInstruction* update, + HloInstruction* start_indices); + explicit HloDynamicUpdateSliceInstruction( + const Shape& shape, HloInstruction* operand, HloInstruction* update, + absl::Span start_indices); + + int64 first_index_operand_number() const override { return 2; } +}; + class HloGatherInstruction : public HloInstruction { public: explicit HloGatherInstruction( diff --git a/tensorflow/compiler/xla/service/hlo_lexer.cc b/tensorflow/compiler/xla/service/hlo_lexer.cc index 1390537101e95a08e4ba4eef7ae8d6059a40e916..c1a642dfea7e464aaf93ffde1e26e07c1a4b73cd 100644 --- a/tensorflow/compiler/xla/service/hlo_lexer.cc +++ b/tensorflow/compiler/xla/service/hlo_lexer.cc @@ -17,8 +17,10 @@ limitations under the License. #include +#include "absl/strings/ascii.h" #include "absl/strings/escaping.h" #include "absl/strings/numbers.h" +#include "absl/strings/str_split.h" #include "absl/types/optional.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -36,8 +38,8 @@ constexpr int kError = -2; // [a-zA-Z0-9_.-] bool IsIdentifierChar(char c) { - return isalnum(static_cast(c)) || c == '-' || c == '.' || - c == '_'; + return absl::ascii_isalnum(static_cast(c)) || c == '-' || + c == '.' 
|| c == '_'; } } // namespace @@ -82,15 +84,29 @@ tensorflow::RegexpStringPiece HloLexer::RegexpStringPieceFromPointers( return tensorflow::RegexpStringPiece(begin, end - begin); } +TokKind HloLexer::LookAhead() { + if (GetKind() == TokKind::kEof || GetKind() == TokKind::kError) { + return GetKind(); + } + + const char* old_current_ptr = current_ptr_; + TokenState old_token_state = token_state_; + Lex(); + TokKind kind = GetKind(); + token_state_ = old_token_state; + current_ptr_ = old_current_ptr; + return kind; +} + TokKind HloLexer::LexToken() { while (true) { - token_start_ = current_ptr_; + token_state_.token_start = current_ptr_; int current_char = GetNextChar(); switch (current_char) { default: // [a-zA-Z_] - if (isalpha(static_cast(current_char)) || + if (absl::ascii_isalpha(static_cast(current_char)) || current_char == '_') { return LexIdentifier(); } @@ -125,6 +141,12 @@ TokKind HloLexer::LexToken() { return LexNumberOrPattern(); case '=': return TokKind::kEqual; + case '<': + if (current_char == '<' && PeekCurrentChar() == '=') { + current_ptr_++; + return TokKind::kLeq; + } + return TokKind::kError; case ',': return TokKind::kComma; case '%': @@ -190,6 +212,15 @@ TokKind HloLexer::LexToken() { // A lone '/' is an error. return TokKind::kError; } + case '.': + if (PeekCurrentChar() == '.') { + current_ptr_++; + if (PeekCurrentChar() == '.') { + current_ptr_++; + return TokKind::kDots; + } + } + return TokKind::kError; case '"': return LexString(); } @@ -206,43 +237,37 @@ TokKind HloLexer::LexToken() { // dim_labels_pattern ::= [0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,} // identifiers ::= other cases that match [a-zA-Z_][a-zA-Z0-9_.-]* TokKind HloLexer::LexIdentifier() { - { - auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end()); - // 'consumable' will be advanced iff its prefix matches the pattern. 
- static LazyRE2 shape_pattern = { - R"(^(\w*\d*)\[([\d,\s]*)\](?:(dense|sparse)?{([\d,\s]+)})?)"}; - if (RE2::Consume(&consumable, *shape_pattern)) { - auto status_or_shape = ShapeUtil::ParseShapeString( - StringPieceFromPointers(token_start_, consumable.begin())); - if (status_or_shape.ok()) { - // This is a shape string. - shape_val_ = status_or_shape.ValueOrDie(); - current_ptr_ = consumable.begin(); - return TokKind::kShape; - } - } - } - while (IsIdentifierChar(PeekCurrentChar())) { current_ptr_++; } // If followed by ':', it's a name. if (PeekCurrentChar() == ':') { - str_val_.assign(token_start_, current_ptr_); + token_state_.str_val.assign(token_state_.token_start, current_ptr_); current_ptr_++; // skip ':' return TokKind::kName; } // If followed by '=', it's a attribute name. if (PeekCurrentChar() == '=') { - str_val_.assign(token_start_, current_ptr_); + token_state_.str_val.assign(token_state_.token_start, current_ptr_); current_ptr_++; // skip '=' return TokKind::kAttributeName; } absl::string_view identifier = - StringPieceFromPointers(token_start_, current_ptr_); + StringPieceFromPointers(token_state_.token_start, current_ptr_); + + // Primitive type strings are reserved words. The exception is 'tuple' whose + // type is represented using nested parentheses without the string 'tuple'. + if (primitive_util::IsPrimitiveTypeName(identifier)) { + PrimitiveType primitive_type = + primitive_util::StringToPrimitiveType(identifier).ValueOrDie(); + if (primitive_type != TUPLE) { + token_state_.primitive_type_val = primitive_type; + return TokKind::kPrimitiveType; + } + } // See if this is a keyword. 
#define KEYWORD(STR) \ @@ -261,21 +286,23 @@ TokKind HloLexer::LexIdentifier() { KEYWORD(ROOT); KEYWORD(maximal); KEYWORD(replicated); + KEYWORD(sparse); #undef KEYWORD { - auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end()); + auto consumable = + RegexpStringPieceFromPointers(token_state_.token_start, buf_.end()); static LazyRE2 dim_labels_pattern = { R"([0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,})"}; if (RE2::Consume(&consumable, *dim_labels_pattern)) { current_ptr_ = consumable.begin(); - str_val_.assign(token_start_, current_ptr_); + token_state_.str_val.assign(token_state_.token_start, current_ptr_); return TokKind::kDimLabels; } } - str_val_ = string(identifier); + token_state_.str_val = string(identifier); return TokKind::kIdent; } @@ -283,13 +310,13 @@ TokKind HloLexer::LexIdentifier() { // name ::= [a-zA-Z_][a-zA-Z0-9_.-]* TokKind HloLexer::LexPercent() { const char* name_start = current_ptr_; - if (isalpha(static_cast(PeekCurrentChar())) || + if (absl::ascii_isalpha(static_cast(PeekCurrentChar())) || PeekCurrentChar() == '_') { current_ptr_++; while (IsIdentifierChar(PeekCurrentChar())) { current_ptr_++; } - str_val_.assign(name_start, current_ptr_); + token_state_.str_val.assign(name_start, current_ptr_); return TokKind::kName; } return TokKind::kError; @@ -307,12 +334,14 @@ TokKind HloLexer::LexPercent() { // int ::= [-]?[0-9]+ // negative inf ::= '-inf' TokKind HloLexer::LexNumberOrPattern() { - auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end()); + auto consumable = + RegexpStringPieceFromPointers(token_state_.token_start, buf_.end()); static LazyRE2 float_pattern = { R"([-]?((\d+|\d+[.]\d*|\d*[.]\d+)([eE][+-]?\d+))|[-]?(\d+[.]\d*|\d*[.]\d+))"}; if (RE2::Consume(&consumable, *float_pattern)) { current_ptr_ = consumable.begin(); - CHECK(absl::SimpleAtod(string(token_start_, current_ptr_), &decimal_val_)); + CHECK(absl::SimpleAtod(string(token_state_.token_start, current_ptr_), + &token_state_.decimal_val)); return 
TokKind::kDecimal; } @@ -324,27 +353,28 @@ TokKind HloLexer::LexNumberOrPattern() { if (RE2::Consume(&consumable, *dim_labels_pattern)) { current_ptr_ = consumable.begin(); - str_val_.assign(token_start_, current_ptr_); + token_state_.str_val.assign(token_state_.token_start, current_ptr_); return TokKind::kDimLabels; } if (RE2::Consume(&consumable, *dxd_pattern)) { current_ptr_ = consumable.begin(); - str_val_.assign(token_start_, current_ptr_); + token_state_.str_val.assign(token_state_.token_start, current_ptr_); return TokKind::kDxD; } if (RE2::Consume(&consumable, *pad_pattern)) { current_ptr_ = consumable.begin(); - str_val_.assign(token_start_, current_ptr_); + token_state_.str_val.assign(token_state_.token_start, current_ptr_); return TokKind::kPad; } static LazyRE2 int_pattern = {R"([-]?\d+)"}; if (RE2::Consume(&consumable, *int_pattern)) { current_ptr_ = consumable.begin(); - auto slice = StringPieceFromPointers(token_start_, current_ptr_); - if (absl::SimpleAtoi(slice, &int64_val_)) { + auto slice = + StringPieceFromPointers(token_state_.token_start, current_ptr_); + if (absl::SimpleAtoi(slice, &token_state_.int64_val)) { return TokKind::kInt; } LOG(ERROR) << "Failed to parse int literal: " << slice; @@ -403,16 +433,17 @@ absl::string_view HloLexer::GetLine(LocTy loc) const { } // Lexes quoted string with escaping characters. If matched, the quoted string -// will be unescaped and stored to str_val_. +// will be unescaped and stored to token_state_.str_val. 
TokKind HloLexer::LexString() { - auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end()); + auto consumable = + RegexpStringPieceFromPointers(token_state_.token_start, buf_.end()); static LazyRE2 escaping_pattern = {R"("([^"\\]|\\.)*")"}; if (RE2::Consume(&consumable, *escaping_pattern)) { current_ptr_ = consumable.begin(); absl::string_view raw = - StringPieceFromPointers(token_start_ + 1, current_ptr_ - 1); + StringPieceFromPointers(token_state_.token_start + 1, current_ptr_ - 1); string error; - if (!absl::CUnescape(raw, &str_val_, &error)) { + if (!absl::CUnescape(raw, &token_state_.str_val, &error)) { LOG(ERROR) << "Failed unescaping string: " << raw << ". error: " << error; return TokKind::kError; } @@ -447,6 +478,8 @@ string TokKindToString(TokKind kind) { return "kRparen"; case TokKind::kArrow: return "kArrow"; + case TokKind::kLeq: + return "kLeq"; case TokKind::kw_HloModule: return "kw_HloModule"; case TokKind::kw_ENTRY: @@ -467,6 +500,10 @@ string TokKindToString(TokKind kind) { return "kw_inf"; case TokKind::kNegInf: return "kNegInf"; + case TokKind::kw_sparse: + return "kw_sparse"; + case TokKind::kPrimitiveType: + return "kPrimitiveType"; case TokKind::kName: return "kName"; case TokKind::kAttributeName: @@ -481,12 +518,12 @@ string TokKindToString(TokKind kind) { return "kIdent"; case TokKind::kString: return "kString"; - case TokKind::kShape: - return "kShape"; case TokKind::kInt: return "kInt"; case TokKind::kDecimal: return "kDecimal"; + case TokKind::kDots: + return "kDots"; } } diff --git a/tensorflow/compiler/xla/service/hlo_lexer.h b/tensorflow/compiler/xla/service/hlo_lexer.h index d6a2b292a3916b2ff85f278cf5cb9f1567df88fa..16eed21617bc7254b67090d2b5acf9ccbd82f4ea 100644 --- a/tensorflow/compiler/xla/service/hlo_lexer.h +++ b/tensorflow/compiler/xla/service/hlo_lexer.h @@ -19,7 +19,6 @@ limitations under the License. 
#include #include "absl/strings/string_view.h" -#include "tensorflow/compiler/xla/service/hlo_token.h" #include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -29,6 +28,59 @@ limitations under the License. namespace xla { +// Defines different kinds of tokens used by the HLO lexer. +// +// You shouldn't need to use this directly unless you're using HloLexer +// directly, and you probably don't need to do that. Use hlo_parser instead. +enum class TokKind { + // Markers + kEof, + kError, + + // Tokens with no info. + kEqual, // = + kComma, // , + kColon, // : + kLsquare, + kRsquare, // [ ] + kLbrace, + kRbrace, // { } + kLparen, + kRparen, // ( ) + kDots, // ... + + kArrow, // -> + kLeq, // <= + + // Keywords + kw_HloModule, + kw_ENTRY, + kw_ROOT, + kw_true, + kw_false, + kw_maximal, + kw_replicated, + kw_nan, + kw_inf, + kw_sparse, + + kNegInf, // -inf + + // Typed tokens. + kPrimitiveType, // F32, PRED, etc. + kName, // %foo + kAttributeName, // dimensions= + kDimLabels, // [0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,} + kDxD, // [0-9]+(x[0-9]+)+ + kPad, // [0-9]+_[0-9]+(_[0-9]+)?(x[0-9]+_[0-9]+(_[0-9]+)?)* + kIdent, // other identifiers + kString, // "abcd\"\n" + kInt, // 42 + kDecimal, // 4.2 +}; + +string TokKindToString(TokKind kind); + // Lexer for the HloModule::ToString() format text. // // This class is meant to be used by hlo_parser.cc. 
You shouldn't need to use @@ -39,9 +91,9 @@ class HloLexer { current_ptr_ = buf_.begin(); } - TokKind Lex() { return current_kind_ = LexToken(); } + TokKind Lex() { return token_state_.current_kind = LexToken(); } - TokKind GetKind() const { return current_kind_; } + TokKind GetKind() const { return token_state_.current_kind; } string GetStrVal() const { switch (GetKind()) { case TokKind::kName: @@ -51,28 +103,28 @@ class HloLexer { case TokKind::kPad: case TokKind::kString: case TokKind::kIdent: - return str_val_; + return token_state_.str_val; default: LOG(FATAL) << "This token does not have string value"; } } - Shape GetShapeVal() const { - CHECK(GetKind() == TokKind::kShape); - return shape_val_; - } tensorflow::int64 GetInt64Val() const { CHECK(GetKind() == TokKind::kInt); - return int64_val_; + return token_state_.int64_val; } double GetDecimalVal() const { CHECK(GetKind() == TokKind::kDecimal); - return decimal_val_; + return token_state_.decimal_val; + } + PrimitiveType GetPrimitiveTypeVal() const { + CHECK(GetKind() == TokKind::kPrimitiveType); + return token_state_.primitive_type_val; } typedef const char* LocTy; // Returns the location of the current token. - LocTy GetLoc() const { return token_start_; } + LocTy GetLoc() const { return token_state_.token_start; } // Returns the line and column of a location in the buffer. std::pair GetLineAndColumn(LocTy location) const; @@ -80,6 +132,9 @@ class HloLexer { // Returns the whole line given the location. absl::string_view GetLine(LocTy loc) const; + // Looks ahead one token and returns it. Lexer state is unchanged. + TokKind LookAhead(); + private: // Returns the current character. If it's neither the end of input buffer nor // an invalid character, moves the pointer forward. @@ -112,12 +167,15 @@ class HloLexer { const char* current_ptr_; // Information about the current token. 
- const char* token_start_ = nullptr; - TokKind current_kind_; - string str_val_; - Shape shape_val_; - tensorflow::int64 int64_val_; - double decimal_val_; + struct TokenState { + const char* token_start = nullptr; + TokKind current_kind; + string str_val; + tensorflow::int64 int64_val; + double decimal_val; + PrimitiveType primitive_type_val; + }; + TokenState token_state_; struct LineNoCacheTy { const char* last_query; diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc b/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc index 5bf055f3c012fef687cdc275d62efdf2d4cd5e5c..e14bcfa7f67e736a4d04f5b236fb2df02cf150e0 100644 --- a/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/map_util.h" @@ -36,11 +37,11 @@ namespace xla { namespace { using Worklist = std::deque; -using Workset = std::unordered_set; +using Workset = absl::flat_hash_set; void AddToWorklist(const HloInstruction* instruction, Worklist* worklist, Workset* workset) { - if (workset->count(instruction) == 0) { + if (!workset->contains(instruction)) { worklist->push_back(instruction); workset->insert(instruction); VLOG(3) << "ADD instruction: " << instruction->name(); diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc index e0ae1173c6114f0bc6ef18b2cfff9d54ccfe2faf..436cccb1fb9ecf6f4efad772c700c611b28ce628 100644 --- a/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc @@ -403,9 +403,9 @@ TEST_F(HloLivenessAnalysisTest, WhileWithOutfeed) { HloModule OutfeedLoop WhileBody { body_param = (s32[]) parameter(0) - token = token[] after-all() + token0 = token[] 
after-all() constant.2 = s32[] constant(2) - outfeed_tuple = (s32[]) outfeed(constant.2, token) + outfeed_tuple = (s32[]) outfeed(constant.2, token0) get-tuple-element.1 = s32[] get-tuple-element(body_param), index=0 constant.1 = s32[] constant(1) add = s32[] add(get-tuple-element.1, constant.1) @@ -436,9 +436,9 @@ TEST_F(HloLivenessAnalysisTest, NestedWhileWithOutfeed) { HloModule OutfeedLoop InnerWhileBody { body_param = (s32[]) parameter(0) - token = token[] after-all() + token0 = token[] after-all() constant.2 = s32[] constant(2) - outfeed_tuple = (s32[]) outfeed(constant.2, token) + outfeed_tuple = (s32[]) outfeed(constant.2, token0) get-tuple-element.1 = s32[] get-tuple-element(body_param), index=0 constant.1 = s32[] constant(1) add = s32[] add(get-tuple-element.1, constant.1) diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h index 235efb19ce4ed28a5cd9fe5ca52ae5d8e9e5ba3d..67488a6a9a0c9cba7f576f9036c3a0cbe1900fff 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.h +++ b/tensorflow/compiler/xla/service/hlo_matchers.h @@ -178,7 +178,7 @@ HLO_MATCHER(Constant); HLO_MATCHER(Convert); HLO_MATCHER(Convolution); HLO_MATCHER(Copy); -HLO_MATCHER(CrossReplicaSum); +HLO_MATCHER(AllReduce); HLO_MATCHER(CollectivePermute); HLO_MATCHER(Divide); HLO_MATCHER(Domain); @@ -312,8 +312,8 @@ inline ::testing::Matcher Shape( } inline ::testing::Matcher Shape( absl::string_view shape) { - return ::testing::MakeMatcher(new ::xla::testing::HloShapeMatcher( - ShapeUtil::ParseShapeString(shape).ValueOrDie())); + return ::testing::MakeMatcher( + new ::xla::testing::HloShapeMatcher(ParseShape(shape).ValueOrDie())); } inline ::testing::Matcher ShapeWithLayout( const class Shape& shape) { @@ -323,7 +323,7 @@ inline ::testing::Matcher ShapeWithLayout( inline ::testing::Matcher ShapeWithLayout( absl::string_view shape) { return ::testing::MakeMatcher(new ::xla::testing::HloShapeAndLayoutMatcher( - 
ShapeUtil::ParseShapeString(shape).ValueOrDie())); + ParseShape(shape).ValueOrDie())); } // Verifies the value of the HloSharing against the provided sharding object. diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler.h b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h index 7227bfb27c74758d2b79e404afc9eb97a1ca894d..76cc29cbb7848eb424d07abf11a95ffd59e9eed6 100644 --- a/tensorflow/compiler/xla/service/hlo_memory_scheduler.h +++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h @@ -118,7 +118,7 @@ class HloTrivialScheduler : public HloModulePass { }; // A trivial pass which clears the schedule currently set on the -// HloModule. After this pass runs HloModudle::has_schedule will return false. +// HloModule. After this pass runs HloModule::has_schedule will return false. class HloDescheduler : public HloModulePass { public: HloDescheduler() = default; diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index fe8371384c0fa3900a9022f101ff0b296439cf16..759ce6541ef144ad3f84bcb87ddabf507a034305 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -107,11 +107,10 @@ HloComputation* HloModule::AddEntryComputation( } Status HloModule::RemoveEmbeddedComputation(HloComputation* to_remove) { - auto it = - std::find_if(computations_.begin(), computations_.end(), - [&to_remove](const std::unique_ptr& comp) { - return comp.get() == to_remove; - }); + auto it = absl::c_find_if( + computations_, [&to_remove](const std::unique_ptr& comp) { + return comp.get() == to_remove; + }); TF_RET_CHECK(it->get() == to_remove); computations_.erase(it); return Status::OK(); @@ -251,7 +250,7 @@ HloModuleProto HloModule::ToProto() const { StatusOr> HloModule::CreateFromProto( const HloModuleProto& proto, const HloModuleConfig& module_config) { VLOG(2) << "CreateFromProto()"; - XLA_VLOG_LINES(2, proto.DebugString()); + XLA_VLOG_LINES(3, proto.DebugString()); 
// The ProgramShape in the passed in module config must match the shapes of // the entry parameters and root. @@ -304,11 +303,10 @@ StatusOr> HloModule::CreateFromProto( auto module = absl::make_unique(proto.name(), module_config); // Sort the computations in the proto id's order. - std::sort(computations.begin(), computations.end(), - [&](const std::unique_ptr& a, - const std::unique_ptr& b) { - return to_proto_id[a.get()] < to_proto_id[b.get()]; - }); + absl::c_sort(computations, [&](const std::unique_ptr& a, + const std::unique_ptr& b) { + return to_proto_id[a.get()] < to_proto_id[b.get()]; + }); // Add sorted computations to the module. for (auto& computation : computations) { @@ -392,15 +390,12 @@ namespace { // Returns whether `hlo` is used outside the given subcomputation. // `instructions_in_subcomputation` is the instruction set of the given // subcomputation. -bool IsUsedOutsideSubcomputation( - const HloInstruction& hlo, - const std::unordered_set& instructions_in_subcomputation) { - for (HloInstruction* user : hlo.users()) { - if (!instructions_in_subcomputation.count(user)) { - return true; - } - } - return false; +bool IsUsedOutsideSubcomputation(const HloInstruction& hlo, + const absl::flat_hash_set& + instructions_in_subcomputation) { + return absl::c_any_of(hlo.users(), [&](HloInstruction* user) { + return !instructions_in_subcomputation.contains(user); + }); } } // anonymous namespace @@ -411,9 +406,9 @@ HloInstruction* HloModule::OutlineExpressionFromComputation( // A map from original instructions to their counterparts in the new outlined // function. - std::unordered_map outlined_instructions; + absl::flat_hash_map outlined_instructions; // A set that contains all instructions to be outlined. 
- std::unordered_set instruction_set_to_outline( + absl::flat_hash_set instruction_set_to_outline( instructions_to_outline.begin(), instructions_to_outline.end()); std::vector arguments; std::vector outputs; @@ -502,7 +497,7 @@ std::vector HloModule::MakeComputationPostOrder() const { // First determine all root computations by building a set of nonroot // computations (computations which are called by an instruction in the // module). - std::set nonroot_computations; + absl::flat_hash_set nonroot_computations; for (auto& computation : computations_) { for (auto* instruction : computation->instructions()) { for (HloComputation* called_computation : @@ -515,19 +510,19 @@ std::vector HloModule::MakeComputationPostOrder() const { // Keep track of computations which have already been added to the post // order. This prevents duplication as an embedded computation may be called // from two different root computations. - std::set added_computations; + absl::flat_hash_set added_computations; std::vector post_order; for (auto& computation : computations_) { - if (nonroot_computations.count(computation.get()) == 0) { + if (!nonroot_computations.contains(computation.get())) { for (HloComputation* embedded_computation : computation->MakeEmbeddedComputationsList()) { - if (added_computations.count(embedded_computation) == 0) { + if (!added_computations.contains(embedded_computation)) { post_order.push_back(embedded_computation); added_computations.insert(embedded_computation); } } // Root computations should only be encountered once. 
- CHECK_EQ(0, added_computations.count(computation.get())); + CHECK(!added_computations.contains(computation.get())); post_order.push_back(computation.get()); added_computations.insert(computation.get()); } diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index 7b9cbf9a53a2201b1312405bbd7ed2b88f65c9be..f1310e4b270898a21dbb4f86123edde4ba8993d0 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -136,7 +136,9 @@ class HloModule { // information on opcode, shape, operands, and typically a root instruction. // This function returns the same hash value for equivalent HLO modules, // with respect to HloInstruction::Identical() method. - uint64 Hash() const { return entry_computation()->Hash(); } + uint64 Hash() const { + return entry_computation()->root_instruction()->Hash(); + } // Gets the computations in this module. // diff --git a/tensorflow/compiler/xla/service/hlo_module_dce.cc b/tensorflow/compiler/xla/service/hlo_module_dce.cc index 31d26cc51e8217234526bbfeb83510aadf2c27b5..6b72ba128664d27c51aa8dcfa61fe959a0160c73 100644 --- a/tensorflow/compiler/xla/service/hlo_module_dce.cc +++ b/tensorflow/compiler/xla/service/hlo_module_dce.cc @@ -49,7 +49,7 @@ StatusOr RunWhileDCE(HloModule* module, HloLivenessAnalysis* liveness) { auto* while_body_param = while_body_comp->parameter_instruction(0); auto* while_body_root = while_body_comp->root_instruction(); - if (!ShapeUtil::IsTuple(xla_while->shape()) || + if (!xla_while->shape().IsTuple() || while_body_root->opcode() != HloOpcode::kTuple) { // Only run DCE on tuple-shaped while loops where body root is Tuple, // with no I/O instructions. 
diff --git a/tensorflow/compiler/xla/service/hlo_module_dce_test.cc b/tensorflow/compiler/xla/service/hlo_module_dce_test.cc index bf66cc6bc37a5e11c9ecfc07a62ba0ea5ca11a03..f6e2866204955ac024c2b6f972de449cc3df4c15 100644 --- a/tensorflow/compiler/xla/service/hlo_module_dce_test.cc +++ b/tensorflow/compiler/xla/service/hlo_module_dce_test.cc @@ -38,9 +38,7 @@ class HloModuleDceTest : public HloTestBase { // Returns whether the given instruction exists in the given computation. bool HasInstruction(const HloComputation& computation, const HloInstruction* instruction) { - return std::find(computation.instructions().begin(), - computation.instructions().end(), - instruction) != computation.instructions().end(); + return absl::c_linear_search(computation.instructions(), instruction); } // Returns whether the while instruction with name 'while_name' in @@ -373,9 +371,9 @@ TEST_F(HloModuleDceTest, WhileWithOutfeed) { HloModule OutfeedLoop WhileBody { body_param = (s32[]) parameter(0) - token = token[] after-all() + token0 = token[] after-all() constant.2 = s32[] constant(2) - outfeed_tuple = (s32[]) outfeed(constant.2, token) + outfeed_tuple = (s32[]) outfeed(constant.2, token0) get-tuple-element.1 = s32[] get-tuple-element(body_param), index=0 constant.1 = s32[] constant(1) add = s32[] add(get-tuple-element.1, constant.1) diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc index b4aac4c8076cb69647d42c6243bc969d06d0709e..47734bc55cc00d605f4e318400be88639450343c 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc @@ -79,36 +79,36 @@ Status HloModuleGroupMetadata::Build() { return Status::OK(); } - std::vector peers; - if (IsChannelInstruction(hlo)) { - peers.push_back(PeerComputation(hlo)); - } else if (hlo->IsCrossModuleAllReduce()) { - for (HloInstruction* instr : 
GetAllReduceGroup(*hlo->all_reduce_id())) { - if (instr == hlo) { - continue; + if (IsChannelInstruction(hlo) || hlo->IsCrossModuleAllReduce()) { + std::vector peers; + if (IsChannelInstruction(hlo)) { + peers.push_back(PeerComputation(hlo)); + } else if (hlo->IsCrossModuleAllReduce()) { + for (HloInstruction* instr : GetAllReduceGroup(*hlo->all_reduce_id())) { + if (instr == hlo) { + continue; + } + peers.push_back(instr->parent()); } - peers.push_back(instr->parent()); } - } - - // Add the parent computation of this channel (or all-reduce) instruction - // and its peer computation(s) (both must be while computations) as - // companions. - for (HloComputation* peer_computation : peers) { - const TrackedInstruction* peer_tracked = - GetTrackedInstruction(peer_computation); - TF_RET_CHECK(peer_tracked != nullptr) - << "Peer instruction is not a possible companion"; - TF_RET_CHECK(*tracked == *peer_tracked) - << "Peer instruction does not match the computation kind"; - TF_RETURN_IF_ERROR( - AddCompanion(tracked->instruction(), peer_tracked->instruction())); - tracked_instructions_comms_[tracked->instruction()].push_back(hlo); - } - // Add the parents of companion instructions (they must be all of the same - // kind of instructions, opcode wise) as companions. - if (IsCompanionInstruction(hlo)) { + // Add the parent computation of this channel (or all-reduce) instruction + // and its peer computation(s) (both must be while computations) as + // companions. 
+ for (HloComputation* peer_computation : peers) { + const TrackedInstruction* peer_tracked = + GetTrackedInstruction(peer_computation); + TF_RET_CHECK(peer_tracked != nullptr) + << "Peer instruction is not a possible companion"; + TF_RET_CHECK(*tracked == *peer_tracked) + << "Peer instruction does not match the computation kind"; + TF_RETURN_IF_ERROR( + AddCompanion(tracked->instruction(), peer_tracked->instruction())); + tracked_instructions_comms_[tracked->instruction()].push_back(hlo); + } + } else if (IsCompanionInstruction(hlo)) { + // Add the parents of companion instructions (they must be all of the same + // kind of instructions, opcode wise) as companions. for (HloInstruction* companion : Companions(hlo)) { const TrackedInstruction* companion_tracked = GetTrackedInstruction(companion->parent()); @@ -118,6 +118,7 @@ Status HloModuleGroupMetadata::Build() { companion_tracked->instruction())); } } + return Status::OK(); }; @@ -198,7 +199,7 @@ bool HloModuleGroupMetadata::IsChannelInstruction( } bool HloModuleGroupMetadata::IsCompanionInstruction(HloInstruction* hlo) const { - return companion_set_index_.count(hlo) > 0; + return companion_set_index_.contains(hlo); } bool HloModuleGroupMetadata::InstructionCommunicates( @@ -509,7 +510,7 @@ Status HloModuleGroupMetadata::CheckCommunicatingInstruction( HloComputation* computation = instruction->parent(); const HloModule* module = computation->parent(); if (module->entry_computation() == computation || - tracked_instructions_.count(computation) > 0) { + tracked_instructions_.contains(computation)) { return Status::OK(); } return FailedPrecondition("channel is used in disallowed computation"); diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h index 928df0f5a7444ad877961a5de970c752e1d024da..3ed95c10504141139d83eb8679a0b8144b15ad0d 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h +++ 
b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h @@ -38,7 +38,7 @@ namespace xla { // Class for bookkeeping the information on the given modules, in particular on // the interaction between computations. // -// Companion instructions are one of the information collected as we build the +// Companion instructions are one piece of information collected as we build the // metadata. For example, for each While instruction, companion instructions // refer to a set of While instructions in other computations that communicate // with each other. @@ -51,6 +51,13 @@ namespace xla { // } While_4() { Recv(0) } // } // +// Each instruction can belong to at most one companion set: While_0 and While_5 +// are in the same set even though they don't communicate with each other, +// because they both communicate with While_2. +// +// A send and the matching recv must both have the same level of nesting of +// companion instructions. +// // Companion instructions are used to detect cycles in the graph and also for // global scheduling. class HloModuleGroupMetadata { @@ -171,7 +178,7 @@ class HloModuleGroupMetadata { // Precondition: IsCompanionWhile(instruction) is true. const std::vector& Companions( const HloInstruction* instruction) const { - CHECK_EQ(companion_set_index_.count(instruction), 1); + CHECK(companion_set_index_.contains(instruction)); return companion_set(companion_set_index_.at(instruction)); } @@ -215,11 +222,8 @@ class HloModuleGroupMetadata { // * Each channel has all 4 instructions (Send, Recv, SendDone, RecvDone). // * The shape of channel instructions match. // * The nest level of channel instructions match. - // * Channel instructions are used in allowed computations; i.e., in the + // * Channel instructions are used in allowed computations, i.e., in the // entry computation of the module or condition/body of While computations. 
- // - // TODO(b/62064342): Currently, HloModuleGroupScheduler checks if there is a - // cycle in the graph, but it would be good to verify here. Status VerifyChannelInstructions(); // Adds metadata that the given two instructions are companions. @@ -231,8 +235,8 @@ class HloModuleGroupMetadata { Status CheckCommunicatingInstruction(HloInstruction* instruction) const; // Performs a consistency check on the companion sets built for the input - // modules. Check that a companion set does not include instructions from the - // same module/device. + // modules. Checks that each instruction in a companion set is in a different + // module/device. Status VerifyCompanionSets() const; // Retrieves a pointer to the stored TrackedInstruction associated with a diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.cc b/tensorflow/compiler/xla/service/hlo_module_group_util.cc index fddeb5f0a27a43ff9ca8b2b5d314bcfe91aaf0e6..91417bd2d9a6ca8a5192a37302e6a91e49a94d77 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_util.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_util.cc @@ -198,6 +198,8 @@ std::vector HloModuleGroupUtil::RootInstructions( for (HloComputation* computation : computations) { for (HloInstruction* instruction : computation->instructions()) { if (GlobalSuccessors(instruction).empty()) { + // An instruction that has no successors, e.g., an unused instruction, + // is in roots, even though it's not the ROOT of its computation. roots.push_back(instruction); } } diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.h b/tensorflow/compiler/xla/service/hlo_module_group_util.h index f21b44bcd98d77b831de5d8a6afa4f9ddd91d15d..862666b48c9aa423ba4eeea3052c17fcc1064fd2 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_util.h +++ b/tensorflow/compiler/xla/service/hlo_module_group_util.h @@ -49,7 +49,7 @@ class HloModuleGroupUtil { // Returns all unique successors of the instruction. 
This includes: // * successors in the same computation: users and control successors // * Send is a successor of Recv - // * RecvDone is a predecessor of Send + // * RecvDone is a successor of Send // * successors of companions (if the instruction is a companion while) // * successors' companions (for any successor that is a companion while) std::vector GlobalSuccessors(HloInstruction* instruction); diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index 127cfd165a5d8229cac3035f56a66f1bcfa734f3..bf9b3c811704870d9e0a36de5c38a013fba6dfe4 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -49,6 +49,7 @@ namespace xla { V(kAdd, "add") \ V(kAddDependency, "add-dependency") \ V(kAfterAll, "after-all", kHloOpcodeIsVariadic) \ + V(kAllReduce, "all-reduce") \ V(kAllToAll, "all-to-all") \ V(kAtan2, "atan2") \ V(kBatchNormGrad, "batch-norm-grad") \ @@ -70,7 +71,6 @@ namespace xla { V(kConvolution, "convolution") \ V(kCopy, "copy") \ V(kCos, "cosine") \ - V(kCrossReplicaSum, "cross-replica-sum") \ V(kCustomCall, "custom-call") \ V(kDivide, "divide") \ V(kDomain, "domain") \ @@ -139,7 +139,8 @@ namespace xla { V(kTranspose, "transpose") \ V(kTuple, "tuple", kHloOpcodeIsVariadic) \ V(kTupleSelect, "tuple-select") \ - V(kWhile, "while") + V(kWhile, "while") \ + V(kReplicaId, "replica-id") enum class HloOpcode { #define DECLARE_ENUM(enum_name, opcode_name, ...) 
enum_name, diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc index ca6a154809be46d6a0305c29e2b89219de408019..0cec61c257bb84e467290fb52ec9063a32ed558d 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering.cc @@ -367,7 +367,7 @@ bool SequentialHloOrdering::ExecutesBeforeInSameComputation( const HloInstruction* a, const HloInstruction* b) const { CHECK_EQ(a->parent(), b->parent()); // If either instruction is not in the order, then 'a' and 'b' are unordered. - if (order_position_.count(a) == 0 || order_position_.count(b) == 0) { + if (!order_position_.contains(a) || !order_position_.contains(b)) { return false; } return order_position_.at(a) < order_position_.at(b); diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index 9b5bb5d0bd6af104ef62eaa5d3e53cedbe0213d3..dd8e8ff3a52a1cdf99a2b07b83e2891f90cf85bb 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_parser.h" +#include #include "absl/algorithm/container.h" #include "absl/memory/memory.h" @@ -21,10 +22,13 @@ limitations under the License. 
#include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" +#include "absl/types/variant.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/hlo_lexer.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h" @@ -43,8 +47,6 @@ using absl::StrCat; using absl::StrFormat; using absl::StrJoin; -const double kF16max = 65504; - // Creates and returns a schedule created using the order of the instructions in // the HloComputation::instructions() vectors in the module. HloSchedule ScheduleFromInstructionOrder(HloModule* module) { @@ -59,6 +61,11 @@ HloSchedule ScheduleFromInstructionOrder(HloModule* module) { return schedule; } +// Some functions accept either a linear index or a multi-dimensional index +// (used for indexing into sparse literals). +using LinearOrMultiIndex = + absl::variant>; + // Parser for the HloModule::ToString() format text. class HloParser { public: @@ -74,6 +81,7 @@ class HloParser { string GetError() const { return StrJoin(error_, "\n"); } // Stand alone parsing utils for various aggregate data types. + StatusOr ParseShapeOnly(); StatusOr ParseShardingOnly(); StatusOr ParseWindowOnly(); StatusOr ParseConvolutionDimensionNumbersOnly(); @@ -100,7 +108,7 @@ class HloParser { // Parse a single instruction worth of text. bool ParseSingleInstruction(HloModule* module); - // ParseXXX returns false if an error occurred. + // Parses a module, returning false if an error occurred. 
bool ParseHloModule(HloModule* module); bool ParseComputations(HloModule* module); @@ -116,21 +124,30 @@ class HloParser { bool ParseNonTupleLiteral(Literal* literal, const Shape& shape); bool ParseDenseLiteral(Literal* literal, const Shape& shape); bool ParseSparseLiteral(Literal* literal, const Shape& shape); - template - bool ParseSparseLiteralHelper(Literal* literal, const Shape& shape); - // Sets the sub-value of literal at the given index to the given value. The - // literal's shape must have the default layout. - bool SetValueInLiteral(tensorflow::int64 value, - tensorflow::int64 linear_index, Literal* literal); - bool SetValueInLiteral(double value, tensorflow::int64 linear_index, + // Sets the sub-value of literal at the given linear or sparse index to the + // given value. If the literal is dense, it must have the default layout. + // + // `loc` should be the source location of the value. + bool SetValueInLiteral(LocTy loc, tensorflow::int64 value, + LinearOrMultiIndex index, Literal* literal); + bool SetValueInLiteral(LocTy loc, double value, LinearOrMultiIndex index, Literal* literal); - bool SetValueInLiteral(bool value, tensorflow::int64 linear_index, + bool SetValueInLiteral(LocTy loc, bool value, LinearOrMultiIndex index, Literal* literal); + bool SetValueInLiteral(LocTy loc, std::complex value, + LinearOrMultiIndex index, Literal* literal); + // `loc` should be the source location of the value. template - bool SetValueInLiteralHelper(ParsedElemT value, - tensorflow::int64 linear_index, - Literal* literal); + bool SetValueInLiteralHelper(LocTy loc, ParsedElemT value, + LinearOrMultiIndex index, Literal* literal); + + // Checks whether the given value is within the range of LiteralNativeT. + // `loc` should be the source location of the value.
+ template + bool CheckParsedValueIsInRange(LocTy loc, ParsedElemT value); + template + bool CheckParsedValueIsInRange(LocTy loc, std::complex value); bool ParseOperands(std::vector* operands); // Fills parsed operands into 'operands' and expects a certain number of @@ -255,7 +272,10 @@ class HloParser { bool ParseName(string* result); bool ParseAttributeName(string* result); bool ParseString(string* result); + bool ParseDimensionSizes(std::vector* dimension_sizes, + std::vector* dynamic_dimensions); bool ParseShape(Shape* result); + bool ParseLayout(Layout* layout); bool ParseOpcode(HloOpcode* result); bool ParseFftType(FftType* result); bool ParseFusionKind(HloInstruction::FusionKind* result); @@ -263,6 +283,7 @@ class HloParser { bool ParsePrecision(PrecisionConfig::Precision* result); bool ParseInt64(tensorflow::int64* result); bool ParseDouble(double* result); + bool ParseComplex(std::complex* result); bool ParseBool(bool* result); bool ParseToken(TokKind kind, const string& msg); @@ -279,9 +300,6 @@ class HloParser { // If the current token is 'kind', eats it (i.e. lexes the next token) and // returns true. bool EatIfPresent(TokKind kind); - // Parses a shape, and returns true if the result is compatible with the given - // shape. - bool EatShapeAndCheckCompatible(const Shape& shape); // Adds the instruction to the pool. Returns false and emits an error if the // instruction already exists. 
@@ -641,8 +659,14 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, tensorflow::int64 parameter_number; if (!ParseToken(TokKind::kLparen, "expects '(' before parameter number") || - !ParseInt64(&parameter_number) || - !ParseToken(TokKind::kRparen, "expects ')' after parameter number") || + !ParseInt64(&parameter_number)) { + return false; + } + if (parameter_number < 0) { + Error(lexer_.GetLoc(), "parameter number must be >= 0"); + return false; + } + if (!ParseToken(TokKind::kRparen, "expects ')' after parameter number") || !ParseAttributes(attrs)) { return false; } @@ -766,7 +790,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, HloInstruction::CreateBitcastConvert(shape, operands[0])); break; } - case HloOpcode::kCrossReplicaSum: { + case HloOpcode::kAllReduce: { optional>> tmp_groups; optional to_apply; optional> replica_group_ids; @@ -786,10 +810,9 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, if (tmp_groups) { replica_groups = CreateReplicaGroups(*tmp_groups); } - instruction = - builder->AddInstruction(HloInstruction::CreateCrossReplicaSum( - shape, operands, *to_apply, replica_groups, - barrier ? *barrier : "", all_reduce_id)); + instruction = builder->AddInstruction(HloInstruction::CreateAllReduce( + shape, operands, *to_apply, replica_groups, barrier ?
*barrier : "", + all_reduce_id)); break; } case HloOpcode::kAllToAll: { @@ -829,6 +852,17 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, HloInstruction::CreateCollectivePermute(shape, operands[0], pairs)); break; } + case HloOpcode::kReplicaId: { + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { + return false; + } + if (!operands.empty()) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreateReplicaId()); + break; + } case HloOpcode::kReshape: { if (!ParseOperands(&operands, /*expected_size=*/1) || !ParseAttributes(attrs)) { @@ -1006,11 +1040,14 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, optional window; optional dnums; optional feature_group_count; + optional batch_group_count; attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window}; attrs["dim_labels"] = {/*required=*/true, AttrTy::kConvolutionDimensionNumbers, &dnums}; attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64, &feature_group_count}; + attrs["batch_group_count"] = {/*required=*/false, AttrTy::kInt64, + &batch_group_count}; optional> operand_precision; attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList, &operand_precision}; @@ -1024,6 +1061,9 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, if (!feature_group_count) { feature_group_count = 1; } + if (!batch_group_count) { + batch_group_count = 1; + } PrecisionConfig precision_config; if (operand_precision) { *precision_config.mutable_operand_precision() = { @@ -1034,7 +1074,8 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, } instruction = builder->AddInstruction(HloInstruction::CreateConvolve( shape, /*lhs=*/operands[0], /*rhs=*/operands[1], - feature_group_count.value(), *window, *dnums, precision_config)); + feature_group_count.value(), batch_group_count.value(), *window, + *dnums, precision_config)); break; } case HloOpcode::kFft: { @@ 
-1163,24 +1204,39 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, optional> dynamic_slice_sizes; attrs["dynamic_slice_sizes"] = { /*required=*/true, AttrTy::kBracedInt64List, &dynamic_slice_sizes}; - if (!ParseOperands(&operands, /*expected_size=*/2) || - !ParseAttributes(attrs)) { + LocTy loc = lexer_.GetLoc(); + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } + if (operands.empty()) { + return Error(loc, "Expected at least one operand."); + } + if (!(operands.size() == 2 && operands[1]->shape().rank() == 1) && + operands.size() != 1 + operands[0]->shape().rank()) { + return Error(loc, "Wrong number of operands."); + } instruction = builder->AddInstruction(HloInstruction::CreateDynamicSlice( - shape, /*operand=*/operands[0], /*start_indices=*/operands[1], + shape, /*operand=*/operands[0], + /*start_indices=*/absl::MakeSpan(operands).subspan(1), *dynamic_slice_sizes)); break; } case HloOpcode::kDynamicUpdateSlice: { - if (!ParseOperands(&operands, /*expected_size=*/3) || - !ParseAttributes(attrs)) { + LocTy loc = lexer_.GetLoc(); + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } + if (operands.size() < 2) { + return Error(loc, "Expected at least two operands."); + } + if (!(operands.size() == 3 && operands[2]->shape().rank() == 1) && + operands.size() != 2 + operands[0]->shape().rank()) { + return Error(loc, "Wrong number of operands."); + } instruction = builder->AddInstruction(HloInstruction::CreateDynamicUpdateSlice( shape, /*operand=*/operands[0], /*update=*/operands[1], - /*start_indices=*/operands[2])); + /*start_indices=*/absl::MakeSpan(operands).subspan(2))); break; } case HloOpcode::kTranspose: { @@ -1280,7 +1336,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, // the infeed instruction. 
ShapeUtil::GetTupleElementShape will check fail // if the shape is not a non-empty tuple, so add guard so an error message // can be emitted instead of a check fail - if (!ShapeUtil::IsTuple(shape) && !ShapeUtil::IsEmptyTuple(shape)) { + if (!shape.IsTuple() && !ShapeUtil::IsEmptyTuple(shape)) { return Error(lexer_.GetLoc(), "infeed must have a non-empty tuple shape"); } @@ -1352,6 +1408,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, optional window; optional dnums; optional feature_group_count; + optional batch_group_count; optional> operand_layout_constraints; attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString, &custom_call_target}; @@ -1361,6 +1418,8 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, AttrTy::kConvolutionDimensionNumbers, &dnums}; attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64, &feature_group_count}; + attrs["batch_group_count"] = {/*required=*/false, AttrTy::kInt64, + &batch_group_count}; attrs["operand_layout_constraints"] = { /*required=*/false, AttrTy::kShapeList, &operand_layout_constraints}; if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { @@ -1416,6 +1475,9 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, if (feature_group_count.has_value()) { instruction->set_feature_group_count(*feature_group_count); } + if (batch_group_count.has_value()) { + instruction->set_batch_group_count(*batch_group_count); + } break; } case HloOpcode::kDot: { @@ -1697,11 +1759,6 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, } break; } - case TokKind::kShape: - // TODO(b/112302613): Left here for backward compatibility to ignore the - // removed tile shape data. 
- lexer_.Lex(); - break; case TokKind::kRbrace: break; default: @@ -1798,142 +1855,149 @@ bool HloParser::ParseInstructionNames( "expects '}' at the end of instruction name list"); } -bool HloParser::SetValueInLiteral(tensorflow::int64 value, - tensorflow::int64 linear_index, - Literal* literal) { +bool HloParser::SetValueInLiteral(LocTy loc, tensorflow::int64 value, + LinearOrMultiIndex index, Literal* literal) { const Shape& shape = literal->shape(); switch (shape.element_type()) { case S8: - return SetValueInLiteralHelper(value, linear_index, + return SetValueInLiteralHelper(loc, value, index, literal); case S16: - return SetValueInLiteralHelper(value, linear_index, + return SetValueInLiteralHelper(loc, value, index, literal); case S32: - return SetValueInLiteralHelper(value, linear_index, + return SetValueInLiteralHelper(loc, value, index, literal); case S64: - return SetValueInLiteralHelper(value, linear_index, + return SetValueInLiteralHelper(loc, value, index, literal); case U8: - return SetValueInLiteralHelper(value, linear_index, + return SetValueInLiteralHelper(loc, value, index, literal); case U16: - return SetValueInLiteralHelper(value, linear_index, + return SetValueInLiteralHelper(loc, value, index, literal); case U32: - return SetValueInLiteralHelper(value, linear_index, + return SetValueInLiteralHelper(loc, value, index, literal); case U64: - return SetValueInLiteralHelper(value, linear_index, + return SetValueInLiteralHelper(loc, value, index, literal); case PRED: // Bool type literals with rank >= 1 are printed in 0s and 1s. 
- return SetValueInLiteralHelper(static_cast(value), - linear_index, literal); + return SetValueInLiteralHelper(loc, static_cast(value), index, + literal); default: LOG(FATAL) << "unknown integral primitive type " << PrimitiveType_Name(shape.element_type()); } } -bool HloParser::SetValueInLiteral(double value, tensorflow::int64 linear_index, - Literal* literal) { +bool HloParser::SetValueInLiteral(LocTy loc, double value, + LinearOrMultiIndex index, Literal* literal) { const Shape& shape = literal->shape(); switch (shape.element_type()) { case F16: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(loc, value, index, literal); case BF16: - return SetValueInLiteralHelper(value, linear_index, + return SetValueInLiteralHelper(loc, value, index, literal); case F32: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(loc, value, index, literal); case F64: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(loc, value, index, literal); default: LOG(FATAL) << "unknown floating point primitive type " << PrimitiveType_Name(shape.element_type()); } } -bool HloParser::SetValueInLiteral(bool value, tensorflow::int64 linear_index, - Literal* literal) { +bool HloParser::SetValueInLiteral(LocTy loc, bool value, + LinearOrMultiIndex index, Literal* literal) { const Shape& shape = literal->shape(); switch (shape.element_type()) { case PRED: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(loc, value, index, literal); default: LOG(FATAL) << PrimitiveType_Name(shape.element_type()) << " is not PRED type"; } } +bool HloParser::SetValueInLiteral(LocTy loc, std::complex value, + LinearOrMultiIndex index, Literal* literal) { + const Shape& shape = literal->shape(); + switch (shape.element_type()) { + case C64: + return SetValueInLiteralHelper>(loc, value, index, + literal); + case C128: + return 
SetValueInLiteralHelper>(loc, value, index, + literal); + default: + LOG(FATAL) << PrimitiveType_Name(shape.element_type()) + << " is not a complex type type"; + } +} + +template +string StringifyValue(T val) { + return StrCat(val); +} +template <> +string StringifyValue(std::complex val) { + return StrFormat("(%f, %f)", std::real(val), std::imag(val)); +} + template -bool HloParser::SetValueInLiteralHelper(ParsedElemT value, - tensorflow::int64 linear_index, +bool HloParser::SetValueInLiteralHelper(LocTy loc, ParsedElemT value, + LinearOrMultiIndex index, Literal* literal) { - // Check that linear_index is in range. - if (linear_index >= ShapeUtil::ElementsIn(literal->shape())) { - return TokenError( - StrCat("trys to set value ", value, " to a literal in shape ", - ShapeUtil::HumanString(literal->shape()), " at linear index ", - linear_index, ", but the index is out of range")); + if (!CheckParsedValueIsInRange(loc, value)) { + return false; } - if (std::isnan(value) || - (std::numeric_limits::has_infinity && - (std::numeric_limits::infinity() == value || - -std::numeric_limits::infinity() == value))) { - // Skip range checking for non-finite value. 
- } else if (literal->shape().element_type() == F16 || - literal->shape().element_type() == BF16) { - if (value > kF16max || value < -kF16max) { - return TokenError(StrCat( - "value ", value, " is out of range for literal's primitive type ", - PrimitiveType_Name(literal->shape().element_type()))); - } - } else if (std::is_unsigned::value) { - CHECK((std::is_same::value || - std::is_same::value)) - << "Unimplemented checking for ParsedElemT"; - - ParsedElemT upper_bound; - if (sizeof(LiteralNativeT) >= sizeof(ParsedElemT)) { - upper_bound = std::numeric_limits::max(); - } else { - upper_bound = - static_cast(std::numeric_limits::max()); + // Check that the index is in range and assign into the literal + if (auto* linear_index = absl::get_if(&index)) { + if (*linear_index >= ShapeUtil::ElementsIn(literal->shape())) { + return Error(loc, StrCat("trys to set value ", StringifyValue(value), + " to a literal in shape ", + ShapeUtil::HumanString(literal->shape()), + " at linear index ", *linear_index, + ", but the index is out of range")); } - if (value > upper_bound || value < 0) { - // Value is out of range for LiteralNativeT. - return TokenError(StrCat( - "value ", value, " is out of range for literal's primitive type ", - PrimitiveType_Name(literal->shape().element_type()))); - } - } else if (value > static_cast( - std::numeric_limits::max()) || - value < static_cast( - std::numeric_limits::lowest())) { - // Value is out of range for LiteralNativeT. - return TokenError(StrCat( - "value ", value, " is out of range for literal's primitive type ", - PrimitiveType_Name(literal->shape().element_type()))); - } + literal->data().at(*linear_index) = + static_cast(value); + } else { + auto* multi_index = absl::get_if>(&index); + CHECK(multi_index != nullptr); - literal->data().at(linear_index) = - static_cast(value); - return true; -} + auto invalid_idx = [&](string msg) { + return Error(loc, StrFormat("Invalid sparse index [%s]. 
%s", + absl::StrJoin(*multi_index, ", "), msg)); + }; -bool HloParser::EatShapeAndCheckCompatible(const Shape& shape) { - Shape new_shape; - if (!ParseShape(&new_shape)) { - return TokenError(StrCat("expects shape ", ShapeUtil::HumanString(shape))); - } - if (!ShapeUtil::Compatible(shape, new_shape)) { - return TokenError(StrCat( - "expects shape ", ShapeUtil::HumanString(shape), - ", but sees a different shape: ", ShapeUtil::HumanString(new_shape))); + const auto& shape = literal->shape(); + if (shape.rank() != multi_index->size()) { + return invalid_idx( + StrFormat("Has rank %d, but constant has shape %s, which has rank %d", + multi_index->size(), shape.ToString(), shape.rank())); + } + for (int64 i = 0; i < shape.rank(); ++i) { + auto idx = (*multi_index)[i]; + if (idx < 0) { + return invalid_idx(StrFormat( + "Sub-index value at %d, namely %d, cannot be negative.", i, idx)); + } + if (idx >= shape.dimensions(i)) { + return invalid_idx( + StrFormat("Sub-index at %d, namely %d, doesn't fit within shape " + "dimension %d in %s", + i, idx, shape.dimensions(i), shape.ToString())); + } + } + literal->AppendSparseElement(*multi_index, + static_cast(value)); } return true; } @@ -1942,8 +2006,8 @@ bool HloParser::EatShapeAndCheckCompatible(const Shape& shape) { // ::= tuple // ::= non_tuple bool HloParser::ParseLiteral(Literal* literal, const Shape& shape) { - return ShapeUtil::IsTuple(shape) ? ParseTupleLiteral(literal, shape) - : ParseNonTupleLiteral(literal, shape); + return shape.IsTuple() ? 
ParseTupleLiteral(literal, shape) + : ParseNonTupleLiteral(literal, shape); } // tuple @@ -1952,10 +2016,6 @@ bool HloParser::ParseLiteral(Literal* literal, const Shape& shape) { // ::= /*empty*/ // ::= literal (',' literal)* bool HloParser::ParseTupleLiteral(Literal* literal, const Shape& shape) { - if (!EatShapeAndCheckCompatible(shape)) { - return TokenError(StrCat("expects tuple constant in shape ", - ShapeUtil::HumanString(shape))); - } if (!ParseToken(TokKind::kLparen, "expects '(' in front of tuple elements")) { return false; } @@ -1990,15 +2050,15 @@ bool HloParser::ParseNonTupleLiteral(Literal* literal, const Shape& shape) { return ParseSparseLiteral(literal, shape); } - CHECK(LayoutUtil::IsDenseArray(shape)); + CHECK(LayoutUtil::IsDenseArray(shape)) << shape.ToString(true); return ParseDenseLiteral(literal, shape); } bool HloParser::ParseDenseLiteral(Literal* literal, const Shape& shape) { - const tensorflow::int64 rank = ShapeUtil::Rank(shape); - if (rank > 1 && !EatShapeAndCheckCompatible(shape)) { - return false; - } + // Cast `rank` to int because we call shape.dimensions(int rank) below, and if + // `rank` is an int64, that's an implicit narrowing conversion, which is + // implementation-defined behavior. + const int rank = static_cast(shape.rank()); // Create a literal with the given shape in default layout. 
*literal = LiteralUtil::CreateFromDimensions( @@ -2023,6 +2083,24 @@ bool HloParser::ParseDenseLiteral(Literal* literal, const Shape& shape) { }), "]"); }; + + auto add_one_elem_seen = [&] { + if (rank > 0) { + if (nest_level != rank) { + return TokenError(absl::StrFormat( + "expects nested array in rank %d, but sees %d", rank, nest_level)); + } + elems_seen_per_dim[rank - 1]++; + if (elems_seen_per_dim[rank - 1] > shape.dimensions(rank - 1)) { + return TokenError(absl::StrFormat( + "expects %d elements on the minor-most dimension, but " + "sees more", + shape.dimensions(rank - 1))); + } + } + return true; + }; + do { switch (lexer_.GetKind()) { default: @@ -2058,6 +2136,31 @@ bool HloParser::ParseDenseLiteral(Literal* literal, const Shape& shape) { lexer_.Lex(); break; } + case TokKind::kLparen: { + if (!primitive_util::IsComplexType(shape.element_type())) { + return TokenError( + absl::StrFormat("unexpected '(' in literal. Parens are only " + "valid for complex literals")); + } + + std::complex value; + LocTy loc = lexer_.GetLoc(); + if (!add_one_elem_seen() || !ParseComplex(&value) || + !SetValueInLiteral(loc, value, linear_index++, literal)) { + return false; + } + break; + } + case TokKind::kDots: { + if (nest_level != 1) { + return TokenError(absl::StrFormat( + "expects `...` at nest level 1, but sees it at nest level %d", + nest_level)); + } + elems_seen_per_dim[0] = shape.dimensions(0); + lexer_.Lex(); + break; + } case TokKind::kComma: // Skip. 
lexer_.Lex(); @@ -2069,23 +2172,11 @@ bool HloParser::ParseDenseLiteral(Literal* literal, const Shape& shape) { case TokKind::kw_nan: case TokKind::kw_inf: case TokKind::kNegInf: { - if (rank > 0) { - if (nest_level != rank) { - return TokenError( - absl::StrFormat("expects nested array in rank %d, but sees %d", - rank, nest_level)); - } - elems_seen_per_dim[rank - 1]++; - if (elems_seen_per_dim[rank - 1] > shape.dimensions(rank - 1)) { - return TokenError(absl::StrFormat( - "expects %d elements on the minor-most dimension, but " - "sees more", - shape.dimensions(rank - 1))); - } - } + add_one_elem_seen(); if (lexer_.GetKind() == TokKind::kw_true || lexer_.GetKind() == TokKind::kw_false) { - if (!SetValueInLiteral(lexer_.GetKind() == TokKind::kw_true, + if (!SetValueInLiteral(lexer_.GetLoc(), + lexer_.GetKind() == TokKind::kw_true, linear_index++, literal)) { return false; } @@ -2098,7 +2189,7 @@ bool HloParser::ParseDenseLiteral(Literal* literal, const Shape& shape) { return Error(loc, StrCat("expects integer for primitive type: ", PrimitiveType_Name(shape.element_type()))); } - if (!SetValueInLiteral(value, linear_index++, literal)) { + if (!SetValueInLiteral(loc, value, linear_index++, literal)) { return false; } } else if (primitive_util::IsFloatingPointType(shape.element_type())) { @@ -2109,7 +2200,7 @@ bool HloParser::ParseDenseLiteral(Literal* literal, const Shape& shape) { loc, StrCat("expect floating point value for primitive type: ", PrimitiveType_Name(shape.element_type()))); } - if (!SetValueInLiteral(value, linear_index++, literal)) { + if (!SetValueInLiteral(loc, value, linear_index++, literal)) { return false; } } else { @@ -2126,52 +2217,7 @@ bool HloParser::ParseDenseLiteral(Literal* literal, const Shape& shape) { } bool HloParser::ParseSparseLiteral(Literal* literal, const Shape& shape) { - if (!EatShapeAndCheckCompatible(shape)) { - return false; - } - - switch (shape.element_type()) { - case PRED: - return ParseSparseLiteralHelper(literal, 
shape); - case S8: - return ParseSparseLiteralHelper(literal, shape); - case S16: - return ParseSparseLiteralHelper(literal, shape); - case S32: - return ParseSparseLiteralHelper(literal, shape); - case S64: - return ParseSparseLiteralHelper(literal, shape); - case U8: - return ParseSparseLiteralHelper(literal, shape); - case U16: - return ParseSparseLiteralHelper(literal, shape); - case U32: - return ParseSparseLiteralHelper(literal, shape); - case U64: - return ParseSparseLiteralHelper(literal, shape); - case F16: - return ParseSparseLiteralHelper(literal, shape); - case F32: - return ParseSparseLiteralHelper(literal, shape); - case BF16: - return ParseSparseLiteralHelper(literal, shape); - case F64: - return ParseSparseLiteralHelper(literal, shape); - default: - return Error(lexer_.GetLoc(), - StrCat("invalid primitive type for sparse literal: ", - PrimitiveType_Name(shape.element_type()))); - } -} - -template -bool HloParser::ParseSparseLiteralHelper(Literal* literal, const Shape& shape) { - std::vector index; - - tensorflow::int64 rank = ShapeUtil::Rank(shape); - *literal = Literal(shape); - if (!ParseToken(TokKind::kLbrace, "expects '{' at the beginning of a sparse literal")) { return false; @@ -2183,61 +2229,66 @@ bool HloParser::ParseSparseLiteralHelper(Literal* literal, const Shape& shape) { break; } - LocTy index_loc = lexer_.GetLoc(); - index.clear(); + std::vector index; if (lexer_.GetKind() == TokKind::kInt) { tensorflow::int64 single_index = lexer_.GetInt64Val(); lexer_.Lex(); - if (rank != 1) { - return Error( - index_loc, - StrCat("invalid single-dimensional index for shape with rank ", - rank, ": ", single_index)); - } index.push_back(single_index); } else { if (!ParseInt64List(TokKind::kLsquare, TokKind::kRsquare, TokKind::kComma, &index)) { return false; } - if (index.size() != rank) { - return Error( - index_loc, - StrCat("invalid multi-dimension index for shape with rank ", rank, - ": [", StrJoin(index, ", "), "]")); - } } if 
(!ParseToken(TokKind::kColon, "expects ':' after after the sparse array index and before " "the sparse array value")) { return false; } + LocTy value_loc = lexer_.GetLoc(); - LiteralNativeT value; if (lexer_.GetKind() == TokKind::kw_true || lexer_.GetKind() == TokKind::kw_false) { - value = static_cast(lexer_.GetKind() == TokKind::kw_true); + bool value = lexer_.GetKind() == TokKind::kw_true; + if (!SetValueInLiteral(lexer_.GetLoc(), value, index, literal)) { + return false; + } lexer_.Lex(); } else if (primitive_util::IsIntegralType(shape.element_type())) { - tensorflow::int64 value_s64; - if (!ParseInt64(&value_s64)) { + tensorflow::int64 value; + if (!ParseInt64(&value)) { return Error(value_loc, StrCat("expects integer for primitive type: ", PrimitiveType_Name(shape.element_type()))); } - value = static_cast(value_s64); + if (!SetValueInLiteral(value_loc, value, index, literal)) { + return false; + } } else if (primitive_util::IsFloatingPointType(shape.element_type())) { - double value_f64; - if (!ParseDouble(&value_f64)) { + double value; + if (!ParseDouble(&value)) { return Error(value_loc, StrCat("expects floating point value for primitive type: ", PrimitiveType_Name(shape.element_type()))); } - value = static_cast(value_f64); + if (!SetValueInLiteral(value_loc, value, index, literal)) { + return false; + } + } else if (primitive_util::IsComplexType(shape.element_type())) { + std::complex value; + if (!ParseComplex(&value)) { + return Error(value_loc, + StrCat("expects complex value for primitive type: ", + PrimitiveType_Name(shape.element_type()))); + } + if (!SetValueInLiteral(value_loc, value, index, literal)) { + return false; + } } else { LOG(FATAL) << "Unexpected element type: " << PrimitiveType_Name(shape.element_type()); } + if (lexer_.GetKind() != TokKind::kRbrace && !ParseToken(TokKind::kComma, "expects ',' separator between sparse array elements")) { @@ -2251,14 +2302,114 @@ bool HloParser::ParseSparseLiteralHelper(Literal* literal, const Shape& 
shape) { StrCat("number of sparse elements exceeds maximum for layout: ", ShapeUtil::HumanStringWithLayout(shape))); } - - literal->AppendSparseElement(index, value); } literal->SortSparseElements(); return true; } +// MinMaxFiniteValue is a type-traits helper used by +// HloParser::CheckParsedValueIsInRange. +template +struct MinMaxFiniteValue { + static T max() { return std::numeric_limits::max(); } + static T min() { return std::numeric_limits::lowest(); } +}; + +template <> +struct MinMaxFiniteValue { + static double max() { + // Sadly this is not constexpr, so this forces `max` to be a method. + return static_cast(Eigen::NumTraits::highest()); + } + static double min() { return -max(); } +}; + +template <> +struct MinMaxFiniteValue { + static double max() { return static_cast(bfloat16::highest()); } + static double min() { return -max(); } +}; + +template +bool HloParser::CheckParsedValueIsInRange(LocTy loc, ParsedElemT value) { + PrimitiveType literal_ty = + primitive_util::NativeToPrimitiveType(); + if (std::isnan(value) || + (std::numeric_limits::has_infinity && + (std::numeric_limits::infinity() == value || + -std::numeric_limits::infinity() == value))) { + // Skip range checking for non-finite value. + } else if (std::is_unsigned::value) { + CHECK((std::is_same::value || + std::is_same::value)) + << "Unimplemented checking for ParsedElemT"; + + ParsedElemT upper_bound; + if (sizeof(LiteralNativeT) >= sizeof(ParsedElemT)) { + upper_bound = std::numeric_limits::max(); + } else { + upper_bound = + static_cast(std::numeric_limits::max()); + } + if (value > upper_bound || value < 0) { + // Value is out of range for LiteralNativeT. + return Error(loc, StrCat("value ", value, + " is out of range for literal's primitive type ", + PrimitiveType_Name(literal_ty), " namely [0, ", + upper_bound, "].")); + } + } else if (value > MinMaxFiniteValue::max() || + value < MinMaxFiniteValue::min()) { + // Value is out of range for LiteralNativeT. 
+ return Error(loc, StrCat("value ", value, + " is out of range for literal's primitive type ", + PrimitiveType_Name(literal_ty), " namely [", + MinMaxFiniteValue::min(), ", ", + MinMaxFiniteValue::max(), "].")); + } + return true; +} + +template +bool HloParser::CheckParsedValueIsInRange(LocTy loc, + std::complex value) { + // e.g. `float` for std::complex + using LiteralComplexComponentT = + decltype(std::real(std::declval())); + + // We could do simply + // + // return CheckParsedValueIsInRange(std::real(value)) && + // CheckParsedValueIsInRange(std::imag(value)); + // + // but this would give bad error messages on failure. + + auto check_component = [&](absl::string_view name, double v) { + if (std::isnan(v) || v == std::numeric_limits::infinity() || + v == -std::numeric_limits::infinity()) { + // Skip range-checking for non-finite values. + return true; + } + + double min = MinMaxFiniteValue::min(); + double max = MinMaxFiniteValue::max(); + if (v < min || v > max) { + // Value is out of range for LiteralComplexComponentT. + return Error( + loc, + StrCat(name, " part ", v, + " is out of range for literal's primitive type ", + PrimitiveType_Name( + primitive_util::NativeToPrimitiveType()), + ", namely [", min, ", ", max, "].")); + } + return true; + }; + return check_component("real", std::real(value)) && + check_component("imaginary", std::imag(value)); +} + // operands ::= '(' operands1 ')' // operands1 // ::= /*empty*/ @@ -2753,7 +2904,7 @@ bool HloParser::ParseConvolutionDimensionNumbers( } auto is_unique = [](string str) -> bool { - std::sort(str.begin(), str.end()); + absl::c_sort(str); return std::unique(str.begin(), str.end()) == str.end(); }; @@ -2994,6 +3145,50 @@ bool HloParser::ParseParamList() { return ParseToken(TokKind::kRparen, "expects ')' at the end of param list"); } +// dimension_sizes ::= '[' dimension_list ']' +// dimension_list +// ::= /*empty*/ +// ::= <=? 
int64 (',' <=? int64)* +// A leading '<=' (kLeq token) marks that dimension as dynamic. +bool HloParser::ParseDimensionSizes(std::vector* dimension_sizes, + std::vector* dynamic_dimensions) { + auto parse_and_add_item = [&]() { + tensorflow::int64 i; + bool is_dynamic = false; + if (lexer_.GetKind() == TokKind::kLeq) { + is_dynamic = true; + lexer_.Lex(); + } + if (!ParseInt64(&i)) { + return false; + } + dimension_sizes->push_back(i); + dynamic_dimensions->push_back(is_dynamic); + return true; + }; + return ParseList(TokKind::kLsquare, TokKind::kRsquare, TokKind::kComma, + parse_and_add_item); +} + +// layout ::= '{' int64_list '}' +bool HloParser::ParseLayout(Layout* layout) { + std::vector minor_to_major; + auto parse_and_add_item = [&]() { + tensorflow::int64 i; + if (!ParseInt64(&i)) { + return false; + } + minor_to_major.push_back(i); + return true; + }; + if (!ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma, + parse_and_add_item)) { + return false; + } + *layout = LayoutUtil::MakeLayout(minor_to_major); + return true; +} + // shape ::= shape_val_ // shape ::= '(' tuple_elements ')' // tuple_elements @@ -3017,19 +3212,67 @@ bool HloParser::ParseShape(Shape* result) { return ParseToken(TokKind::kRparen, "expects ')' at the end of tuple."); } - if (lexer_.GetKind() != TokKind::kShape) { - return TokenError(absl::StrCat("expected shape, saw ", + if (lexer_.GetKind() != TokKind::kPrimitiveType) { + return TokenError(absl::StrCat("expected primitive type, saw ", TokKindToString(lexer_.GetKind()))); } - *result = lexer_.GetShapeVal(); + PrimitiveType primitive_type = lexer_.GetPrimitiveTypeVal(); lexer_.Lex(); + + // Each element contains a dimension size and a bool indicating whether this + // is a dynamic dimension. 
+ std::vector dimension_sizes; + std::vector dynamic_dimensions; + if (!ParseDimensionSizes(&dimension_sizes, &dynamic_dimensions)) { + return false; + } + result->set_element_type(primitive_type); + for (int i = 0; i < dimension_sizes.size(); ++i) { + result->add_dimensions(dimension_sizes[i]); + result->set_dynamic_dimension(i, dynamic_dimensions[i]); + } + LayoutUtil::SetToDefaultLayout(result); + + if (lexer_.GetKind() == TokKind::kw_sparse) { + lexer_.Lex(); + const string message = + "expects a brace-bracketed integer for sparse layout"; + tensorflow::int64 max_sparse_elements; + if (!ParseToken(TokKind::kLbrace, message) || + !ParseInt64(&max_sparse_elements) || + !ParseToken(TokKind::kRbrace, message)) { + return false; + } + *result->mutable_layout() = + LayoutUtil::MakeSparseLayout(max_sparse_elements); + return true; + } + + // We need to lookahead to see if a following open brace is the start of a + // layout. The specific problematic case is: + // + // ENTRY %foo (x: f32[42]) -> f32[123] { + // ... + // } + // + // The open brace could either be the start of a computation or the start of a + // layout for the f32[123] shape. We consider it the start of a layout if the + // next token after the open brace is an integer + if (lexer_.GetKind() == TokKind::kLbrace && + lexer_.LookAhead() == TokKind::kInt) { + Layout layout; + if (!ParseLayout(&layout)) { + return false; + } + *result->mutable_layout() = layout; + } return true; } bool HloParser::CanBeShape() { - // A non-tuple shape starts with a kShape token; a tuple shape starts with - // '('. - return lexer_.GetKind() == TokKind::kShape || + // A non-tuple shape starts with a kPrimitiveType token; a tuple shape starts + // with '('. 
+ return lexer_.GetKind() == TokKind::kPrimitiveType || lexer_.GetKind() == TokKind::kLparen; } @@ -3261,9 +3504,18 @@ bool HloParser::ParseInt64(tensorflow::int64* result) { bool HloParser::ParseDouble(double* result) { switch (lexer_.GetKind()) { - case TokKind::kDecimal: - *result = lexer_.GetDecimalVal(); + case TokKind::kDecimal: { + double val = lexer_.GetDecimalVal(); + // If GetDecimalVal returns +/-inf, that means that we overflowed + // `double`. + if (std::isinf(val)) { + return TokenError(StrCat("Constant is out of range for double (+/-", + std::numeric_limits::max(), + ") and so is unparsable.")); + } + *result = val; break; + } case TokKind::kInt: *result = static_cast(lexer_.GetInt64Val()); break; @@ -3283,6 +3535,42 @@ bool HloParser::ParseDouble(double* result) { return true; } +bool HloParser::ParseComplex(std::complex* result) { + if (lexer_.GetKind() != TokKind::kLparen) { + return TokenError("expects '(' before complex number"); + } + lexer_.Lex(); + + double real; + LocTy loc = lexer_.GetLoc(); + if (!ParseDouble(&real)) { + return Error(loc, + "expect floating-point value for real part of complex number"); + } + + if (lexer_.GetKind() != TokKind::kComma) { + return TokenError( + absl::StrFormat("expect comma after real part of complex literal")); + } + lexer_.Lex(); + + double imag; + loc = lexer_.GetLoc(); + if (!ParseDouble(&imag)) { + return Error( + loc, + "expect floating-point value for imaginary part of complex number"); + } + + if (lexer_.GetKind() != TokKind::kRparen) { + return TokenError(absl::StrFormat("expect ')' after complex number")); + } + + *result = std::complex(real, imag); + lexer_.Lex(); + return true; +} + bool HloParser::ParseBool(bool* result) { if (lexer_.GetKind() != TokKind::kw_true && lexer_.GetKind() != TokKind::kw_false) { @@ -3332,6 +3620,18 @@ bool HloParser::AddComputation(const string& name, HloComputation* computation, return true; } +StatusOr HloParser::ParseShapeOnly() { + lexer_.Lex(); + Shape shape; + 
if (!ParseShape(&shape)) { + return InvalidArgument("Syntax error:\n%s", GetError()); + } + if (lexer_.GetKind() != TokKind::kEof) { + return InvalidArgument("Syntax error:\nExtra content after shape"); + } + return shape; +} + StatusOr HloParser::ParseShardingOnly() { lexer_.Lex(); OpSharding op_sharding; @@ -3475,4 +3775,9 @@ StatusOr ParsePaddingConfig(absl::string_view str) { return parser.ParsePaddingConfigOnly(); } +StatusOr ParseShape(absl::string_view str) { + HloParser parser(str); + return parser.ParseShapeOnly(); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_parser.h b/tensorflow/compiler/xla/service/hlo_parser.h index d830fa61438239005875f785f85cf2486123ebc9..450a54c54c156c2ae27475d145a8e83dc841b431 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.h +++ b/tensorflow/compiler/xla/service/hlo_parser.h @@ -60,6 +60,9 @@ StatusOr ParseConvolutionDimensionNumbers( // Parses the result of PaddingConfigToString(), e.g. "0_0x1_1". StatusOr ParsePaddingConfig(absl::string_view str); +// Parses and returns a Shape::ToString-format string. 
+StatusOr ParseShape(absl::string_view str); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PARSER_H_ diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index ab71f011ac9d77d00ddfb41aca7a224d26d416b7..b525f66f9bc837f720531b8828436ebb9c1f6b31 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -82,7 +82,7 @@ ENTRY %constant_pred () -> pred[] { R"(HloModule module ENTRY %constant_pred_array () -> pred[2,3] { - ROOT %constant = pred[2,3]{1,0} constant(pred[2,3] { { 0, 1, 0 }, { 1, 0, 1 } }) + ROOT %constant = pred[2,3]{1,0} constant({ { 0, 1, 0 }, { 1, 0, 1 } }) } )" @@ -128,7 +128,7 @@ ENTRY %ConstantF32Empty.v4 () -> f32[0] { R"(HloModule ConstantF32R4Empty_module ENTRY %ConstantF32R4Empty.v4 () -> f32[2,0,4,3] { - ROOT %constant = f32[2,0,4,3]{3,2,1,0} constant(f32[2,0,4,3] { { /*i0=0*/ }, { /*i0=1*/ } }) + ROOT %constant = f32[2,0,4,3]{3,2,1,0} constant({ { /*i0=0*/ }, { /*i0=1*/ } }) } )" @@ -139,7 +139,7 @@ ENTRY %ConstantF32R4Empty.v4 () -> f32[2,0,4,3] { R"(HloModule Small_3x2x1x1_module ENTRY %Small_3x2x1x1.v1 () -> f32[3,2,1,1] { - ROOT %constant = f32[3,2,1,1]{3,2,1,0} constant(f32[3,2,1,1] { { /*i0=0*/ { /*i1=0*/ {-1} }, { /*i1=1*/ {4.1} } }, { /*i0=1*/ { /*i1=0*/ {2} }, { /*i1=1*/ {4.1} } }, { /*i0=2*/ { /*i1=0*/ {5} }, { /*i1=1*/ {4.4} } } }) + ROOT %constant = f32[3,2,1,1]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ {-1} }, { /*i1=1*/ {4.1} } }, { /*i0=1*/ { /*i1=0*/ {2} }, { /*i1=1*/ {4.1} } }, { /*i0=2*/ { /*i1=0*/ {5} }, { /*i1=1*/ {4.4} } } }) } )" @@ -196,7 +196,7 @@ ENTRY %add_constants () -> f32[] { R"(HloModule TupleConstant_module ENTRY %TupleConstant.v1 () -> (f32[2,1], f32[2]) { - ROOT %constant = (f32[2,1]{1,0}, f32[2]{0}) constant((f32[2,1], f32[2]) ( f32[2,1] { {1}, {2} }, {2, 42} )) + ROOT %constant = (f32[2,1]{1,0}, f32[2]{0}) constant(( { {1}, {2} }, {2, 42} )) } )" @@ -295,11 
+295,11 @@ ENTRY %WhileWithScalarS32Result.v2 () -> s32[] { R"(HloModule TwoSendRecvBothWayRecvFist_module ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> (f32[], token[]) { - %token = token[] after-all() - %recv = (f32[], u32[], token[]) recv(token[] %token), channel_id=15, sharding={maximal device=1} + %token0 = token[] after-all() + %recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15, sharding={maximal device=1} ROOT %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15, sharding={maximal device=1} %constant = f32[] constant(2.1), sharding={maximal device=0} - %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token), channel_id=16, sharding={maximal device=0}, control-predecessors={%recv} + %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0), channel_id=16, sharding={maximal device=0}, control-predecessors={%recv} %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16, sharding={maximal device=0} } @@ -310,11 +310,11 @@ ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> (f32[], token[]) { R"(HloModule HostTransferSendRecv_module ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> (f32[], token[]) { - %token = token[] after-all() - %recv = (f32[], u32[], token[]) recv(token[] %token), channel_id=15, is_host_transfer=true + %token0 = token[] after-all() + %recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15, is_host_transfer=true ROOT %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15, is_host_transfer=true %constant = f32[] constant(2.1), sharding={maximal device=0} - %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token), channel_id=16, is_host_transfer=true + %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0), channel_id=16, is_host_transfer=true %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16, is_host_transfer=true } @@ -327,7 +327,7 @@ R"(HloModule 
GetTupleElement_module ENTRY %GetTupleElement.v4 () -> s32[2,3] { %constant = f32[3]{0} constant({1, 2, 3}) - %constant.1 = s32[2,3]{1,0} constant(s32[2,3] { { 1, 2, 3 }, { 4, 5, 6 } }) + %constant.1 = s32[2,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 } }) %tuple = (f32[3]{0}, s32[2,3]{1,0}) tuple(f32[3]{0} %constant, s32[2,3]{1,0} %constant.1) ROOT %get-tuple-element = s32[2,3]{1,0} get-tuple-element((f32[3]{0}, s32[2,3]{1,0}) %tuple), index=1, sharding={maximal device=0} } @@ -434,7 +434,7 @@ ENTRY %ConvolveBackward (input: f32[128,7,7,512], filter: f32[3,3,512,512]) -> f R"(HloModule Reverse4DFloatArrayOnDim01_module ENTRY %Reverse4DFloatArrayOnDim01.v2 () -> f32[4,3,2,1] { - %constant = f32[4,3,2,1]{0,1,2,3} constant(f32[4,3,2,1] { { /*i0=0*/ { /*i1=0*/ {1}, {2} }, { /*i1=1*/ {3}, {4} }, { /*i1=2*/ {5}, {6} } }, { /*i0=1*/ { /*i1=0*/ {7}, {8} }, { /*i1=1*/ {9}, {10} }, { /*i1=2*/ {11}, {12} } }, { /*i0=2*/ { /*i1=0*/ {13}, {14} }, { /*i1=1*/ {15}, {16} }, { /*i1=2*/ {17}, {18} } }, { /*i0=3*/ { /*i1=0*/ {19}, {20} }, { /*i1=1*/ {21}, {22} }, { /*i1=2*/ {23}, {24} } } }) + %constant = f32[4,3,2,1]{0,1,2,3} constant({ { /*i0=0*/ { /*i1=0*/ {1}, {2} }, { /*i1=1*/ {3}, {4} }, { /*i1=2*/ {5}, {6} } }, { /*i0=1*/ { /*i1=0*/ {7}, {8} }, { /*i1=1*/ {9}, {10} }, { /*i1=2*/ {11}, {12} } }, { /*i0=2*/ { /*i1=0*/ {13}, {14} }, { /*i1=1*/ {15}, {16} }, { /*i1=2*/ {17}, {18} } }, { /*i0=3*/ { /*i1=0*/ {19}, {20} }, { /*i1=1*/ {21}, {22} }, { /*i1=2*/ {23}, {24} } } }) ROOT %reverse = f32[4,3,2,1]{0,1,2,3} reverse(f32[4,3,2,1]{0,1,2,3} %constant), dimensions={0,1} } @@ -446,8 +446,8 @@ ENTRY %Reverse4DFloatArrayOnDim01.v2 () -> f32[4,3,2,1] { R"(HloModule Concat2x3With2x5_module ENTRY %Concat2x3With2x5.v3 () -> f32[2,8] { - %constant = f32[2,3]{1,0} constant(f32[2,3] { { 0, 1, 2 }, { 1000, 1001, 1002 } }) - %constant.1 = f32[2,5]{1,0} constant(f32[2,5] { { 64, 65, 66, 67, 68 }, { 1064, 1065, 1066, 1067, 1068 } }) + %constant = f32[2,3]{1,0} constant({ { 0, 1, 2 }, { 1000, 1001, 
1002 } }) + %constant.1 = f32[2,5]{1,0} constant({ { 64, 65, 66, 67, 68 }, { 1064, 1065, 1066, 1067, 1068 } }) ROOT %concatenate = f32[2,8]{1,0} concatenate(f32[2,3]{1,0} %constant, f32[2,5]{1,0} %constant.1), dimensions={1} } @@ -471,8 +471,8 @@ R"(HloModule R4F32OverlapSmall_module } ENTRY %R4F32OverlapSmall.v4 () -> f32[4,5,1,1] { - %constant = f32[4,5,1,1]{3,2,1,0} constant(f32[4,5,1,1] { { /*i0=0*/ { /*i1=0*/ {7} }, { /*i1=1*/ {2} }, { /*i1=2*/ {5} }, { /*i1=3*/ {3} }, { /*i1=4*/ {8} } }, { /*i0=1*/ { /*i1=0*/ {3} }, { /*i1=1*/ {8} }, { /*i1=2*/ {9} }, { /*i1=3*/ {3} }, { /*i1=4*/ {4} } }, { /*i0=2*/ { /*i1=0*/ {1} }, { /*i1=1*/ {5} }, { /*i1=2*/ {7} }, { /*i1=3*/ {5} }, { /*i1=4*/ {6} } }, { /*i0=3*/ { /*i1=0*/ {0} }, { /*i1=1*/ {6} }, { /*i1=2*/ {2} }, { /*i1=3*/ {10} }, { /*i1=4*/ {2} } } }) - %constant.1 = f32[2,2,1,1]{3,2,1,0} constant(f32[2,2,1,1] { { /*i0=0*/ { /*i1=0*/ {2} }, { /*i1=1*/ {6} } }, { /*i0=1*/ { /*i1=0*/ {3} }, { /*i1=1*/ {1} } } }) + %constant = f32[4,5,1,1]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ {7} }, { /*i1=1*/ {2} }, { /*i1=2*/ {5} }, { /*i1=3*/ {3} }, { /*i1=4*/ {8} } }, { /*i0=1*/ { /*i1=0*/ {3} }, { /*i1=1*/ {8} }, { /*i1=2*/ {9} }, { /*i1=3*/ {3} }, { /*i1=4*/ {4} } }, { /*i0=2*/ { /*i1=0*/ {1} }, { /*i1=1*/ {5} }, { /*i1=2*/ {7} }, { /*i1=3*/ {5} }, { /*i1=4*/ {6} } }, { /*i0=3*/ { /*i1=0*/ {0} }, { /*i1=1*/ {6} }, { /*i1=2*/ {2} }, { /*i1=3*/ {10} }, { /*i1=4*/ {2} } } }) + %constant.1 = f32[2,2,1,1]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ {2} }, { /*i1=1*/ {6} } }, { /*i0=1*/ { /*i1=0*/ {3} }, { /*i1=1*/ {1} } } }) %constant.2 = f32[] constant(0) ROOT %select-and-scatter = f32[4,5,1,1]{3,2,1,0} select-and-scatter(f32[4,5,1,1]{3,2,1,0} %constant, f32[2,2,1,1]{3,2,1,0} %constant.1, f32[] %constant.2), window={size=2x3x1x1 stride=2x2x1x1}, select=%ge_F32.v3, scatter=%add_F32.v3 } @@ -523,7 +523,7 @@ ENTRY %slice.v2 (p0: f32[3,3,4,4]) -> f32[3,3,2,4] { R"(HloModule Slice3x3x3_To_1x3x3_F32_module ENTRY 
%Slice3x3x3_To_1x3x3_F32.v2 () -> f32[1,3,3] { - %constant = f32[3,3,3]{2,1,0} constant(f32[3,3,3] { { { 0, 1, 2 }, { 3, 4, 5 }, { 6, 7, 8 } }, { { 9, 10, 11 }, { 12, 13, 14 }, { 15, 16, 17 } }, { { 18, 19, 20 }, { 21, 22, 23 }, { 24, 25, 26 } } }) + %constant = f32[3,3,3]{2,1,0} constant({ { { 0, 1, 2 }, { 3, 4, 5 }, { 6, 7, 8 } }, { { 9, 10, 11 }, { 12, 13, 14 }, { 15, 16, 17 } }, { { 18, 19, 20 }, { 21, 22, 23 }, { 24, 25, 26 } } }) ROOT %slice = f32[1,3,3]{2,1,0} slice(f32[3,3,3]{2,1,0} %constant), slice={[0:1], [0:3], [0:3]} } @@ -547,10 +547,21 @@ ENTRY %SliceR0.v2 () -> s32[] { R"(HloModule Transpose_module ENTRY %Transpose.v2 () -> s32[1,2,3] { - %constant = s32[1,2,3]{2,1,0} constant(s32[1,2,3] { { { 1, 2, 3 }, { 4, 5, 6 } } }) + %constant = s32[1,2,3]{2,1,0} constant({ { { 1, 2, 3 }, { 4, 5, 6 } } }) ROOT %transpose = s32[1,2,3]{2,1,0} transpose(s32[1,2,3]{2,1,0} %constant), dimensions={0,1,2} } +)" +}, +{ +"TransposeC128", +R"(HloModule TransposeC128_module + +ENTRY %Transpose.v3 (input: c128[1,2,3]) -> c128[1,2,3] { + %input = c128[1,2,3]{2,1,0} parameter(0) + ROOT %transpose = c128[1,2,3]{2,1,0} transpose(c128[1,2,3]{2,1,0} %input), dimensions={0,1,2} +} + )" }, // Dynamic slice @@ -566,12 +577,26 @@ ENTRY %DynamicSlice.v5 (original_parameter: s32[2,2,258], start_index: s32[1]) - ROOT %dynamic-slice = s32[2,2,258]{2,1,0} dynamic-slice(s32[2,2,258]{2,1,0} %original_parameter, s32[3]{0} %concatenate), dynamic_slice_sizes={2,2,258} } +)" +}, +// Dynamic slice with scalar indices +{ +"DynamicSliceScalarIndices", +R"(HloModule DynamicSlice_module + +ENTRY %DynamicSlice.v5 (original_parameter: s32[2,2,258], start_index: s32[]) -> s32[2,2,258] { + %original_parameter = s32[2,2,258]{2,1,0} parameter(0) + %constant = s32[] constant(0) + %start_index = s32[] parameter(1) + ROOT %dynamic-slice = s32[2,2,258]{2,1,0} dynamic-slice(s32[2,2,258]{2,1,0} %original_parameter, s32[] %constant, s32[] %constant, s32[] %start_index), dynamic_slice_sizes={2,2,258} +} + )" }, 
// Dynamic update slice { "DynamicUpdateSlice", -R"(HloModule DynamicUpdateSlice_module +R"(HloModule DynamicSlice_module ENTRY %DynamicUpdateSlice.v4 (input: s32[1,1,25,1], update: s32[1,1,2,1], start_indices: s32[4]) -> s32[1,1,25,1] { %input = s32[1,1,25,1]{3,2,1,0} parameter(0) @@ -580,6 +605,23 @@ ENTRY %DynamicUpdateSlice.v4 (input: s32[1,1,25,1], update: s32[1,1,2,1], start_ ROOT %dynamic-update-slice = s32[1,1,25,1]{3,2,1,0} dynamic-update-slice(s32[1,1,25,1]{3,2,1,0} %input, s32[1,1,2,1]{3,2,1,0} %update, s32[4]{0} %start_indices) } +)" +}, +// Dynamic update slice with scalar indices +{ +"DynamicUpdateSliceScalarIndex", +R"(HloModule DynamicUpdateSlice_module + +ENTRY %DynamicUpdateSlice.v4 (input: s32[1,1,25,1], update: s32[1,1,2,1], start_index.0: s32[], start_index.1: s32[], start_index.2: s32[], start_index.3: s32[]) -> s32[1,1,25,1] { + %input = s32[1,1,25,1]{3,2,1,0} parameter(0) + %update = s32[1,1,2,1]{3,2,1,0} parameter(1) + %start_index.0 = s32[] parameter(2) + %start_index.1 = s32[] parameter(3) + %start_index.2 = s32[] parameter(4) + %start_index.3 = s32[] parameter(5) + ROOT %dynamic-update-slice = s32[1,1,25,1]{3,2,1,0} dynamic-update-slice(s32[1,1,25,1]{3,2,1,0} %input, s32[1,1,2,1]{3,2,1,0} %update, s32[] %start_index.0, s32[] %start_index.1, s32[] %start_index.2, s32[] %start_index.3) +} + )" }, // batch norm training @@ -588,7 +630,7 @@ ENTRY %DynamicUpdateSlice.v4 (input: s32[1,1,25,1], update: s32[1,1,2,1], start_ R"(HloModule BasicTraining_module ENTRY %BasicTraining.v4 () -> (f32[2,2,1,2], f32[2], f32[2]) { - %constant = f32[2,2,1,2]{3,2,1,0} constant(f32[2,2,1,2] { { /*i0=0*/ { /*i1=0*/ { 1, 2 } }, { /*i1=1*/ { 3, 4 } } }, { /*i0=1*/ { /*i1=0*/ { 5, 6 } }, { /*i1=1*/ { 7, 8 } } } }) + %constant = f32[2,2,1,2]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ { 1, 2 } }, { /*i1=1*/ { 3, 4 } } }, { /*i0=1*/ { /*i1=0*/ { 5, 6 } }, { /*i1=1*/ { 7, 8 } } } }) %constant.1 = f32[2]{0} constant({2, 3}) %constant.2 = f32[2]{0} constant({1, 2}) ROOT 
%batch-norm-training = (f32[2,2,1,2]{3,2,1,0}, f32[2]{0}, f32[2]{0}) batch-norm-training(f32[2,2,1,2]{3,2,1,0} %constant, f32[2]{0} %constant.1, f32[2]{0} %constant.2), epsilon=0.001, feature_index=3 @@ -728,7 +770,7 @@ R"(HloModule fusion_module } ENTRY %fusion.v3 () -> f32[3,2,1,1] { - %constant = f32[3,2,1,1]{3,2,1,0} constant(f32[3,2,1,1] { { /*i0=0*/ { /*i1=0*/ {-1} }, { /*i1=1*/ {4.1} } }, { /*i0=1*/ { /*i1=0*/ {2} }, { /*i1=1*/ {4.1} } }, { /*i0=2*/ { /*i1=0*/ {5} }, { /*i1=1*/ {4.4} } } }) + %constant = f32[3,2,1,1]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ {-1} }, { /*i1=1*/ {4.1} } }, { /*i0=1*/ { /*i1=0*/ {2} }, { /*i1=1*/ {4.1} } }, { /*i0=2*/ { /*i1=0*/ {5} }, { /*i1=1*/ {4.4} } } }) %constant.1 = f32[2]{0} constant({3.14, 4.25}) ROOT %fusion = f32[3,2,1,1]{3,2,1,0} fusion(f32[3,2,1,1]{3,2,1,0} %constant, f32[2]{0} %constant.1), kind=kLoop, calls=%fused_computation } @@ -740,7 +782,17 @@ ENTRY %fusion.v3 () -> f32[3,2,1,1] { R"(HloModule sparse_f32 ENTRY %sparse () -> f32[2,3,4] { - ROOT %foo = f32[2,3,4]sparse{10} constant(f32[2,3,4]{[0, 1, 2]: 1, [1, 2, 3]: 2, [2, 3, 4]: 3}) + ROOT %foo = f32[2,3,4]sparse{10} constant({[0, 1, 2]: 1, [1, 2, 2]: 2, [1, 2, 3]: 3}) +} + +)" +}, +{ +"SparseC128", +R"(HloModule sparse_c128 + +ENTRY %sparse () -> c128[2,3,4] { + ROOT %foo = c128[2,3,4]sparse{10} constant({[0, 1, 2]: (1, 0), [1, 2, 2]: (2, 5), [1, 2, 3]: (3, 10)}) } )" @@ -750,7 +802,7 @@ ENTRY %sparse () -> f32[2,3,4] { R"(HloModule sparse_f32_empty ENTRY %sparse_f32_empty () -> f32[2,3,4] { - ROOT %foo = f32[2,3,4]sparse{10} constant(f32[2,3,4]{}) + ROOT %foo = f32[2,3,4]sparse{10} constant({}) } )" @@ -760,7 +812,7 @@ ENTRY %sparse_f32_empty () -> f32[2,3,4] { R"(HloModule sparse_f32_r1 ENTRY %sparse_f32_r1 () -> f32[9] { - ROOT %foo = f32[9]sparse{10} constant(f32[9]{1: 2, 3: 4, 5: 6}) + ROOT %foo = f32[9]sparse{10} constant({1: 2, 3: 4, 5: 6}) } )" @@ -852,6 +904,28 @@ ENTRY %CustomCallWithLayoutConstraints (p0: (f32[2,2], f32[42,2,3]), p1: f32[123 ROOT 
%custom-call = (f32[1,2,3]{0,2,1}, f32[1,2,3]{1,2,0}) custom-call((f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", operand_layout_constraints={(f32[2,2]{1,0}, f32[42,2,3]{2,0,1}), f32[123,4]{1,0}} } +)" +}, +// Parse c64 literal +{ +"ParseC64Literal", +R"(HloModule ParseC64Literal + +ENTRY %ParseC64Literal () -> c64[2] { + ROOT %c = c64[2]{0} constant({(1, 2), (-inf, nan)}) +} + +)" +}, +// Parse c128 literal +{ +"ParseC128Literal", +R"(HloModule ParseC128Literal + +ENTRY %ParseC128Literal () -> c128[2] { + ROOT %c = c128[2]{0} constant({(1, 2), (-inf, nan)}) +} + )" }, }); @@ -931,11 +1005,11 @@ ENTRY reduce_entry { R"(HloModule outfeed_module ENTRY InfeedToOutfeed { - token = token[] after-all() - infeed = ((u32[3]{0}, pred[]), token[]) infeed(token) + token0 = token[] after-all() + infeed = ((u32[3]{0}, pred[]), token[]) infeed(token0) infeed.data = (u32[3]{0}, pred[]) get-tuple-element(infeed), index=0 - outfeed = token[] outfeed(infeed.data, token) - ROOT infeed.1 = ((u32[3]{0}, pred[]), token[]) infeed(token) + outfeed = token[] outfeed(infeed.data, token0) + ROOT infeed.1 = ((u32[3]{0}, pred[]), token[]) infeed(token0) infeed.1.data = (u32[3]{0}, pred[]) get-tuple-element(infeed.1), index=0 infeed.1.token = token[] get-tuple-element(infeed.1), index=1 outfeed.1 = token[] outfeed(infeed.1.data, infeed.1.token) @@ -1117,9 +1191,9 @@ ENTRY Gather { )" }, -// cross-replica-sum +// all-reduce { -"CrossReplicaSum", +"AllReduce", R"(HloModule CRS add { @@ -1130,14 +1204,14 @@ add { ENTRY CRS { input = f32[8]{0} parameter(0) - ROOT crs = f32[8]{0} cross-replica-sum(input), replica_groups={}, to_apply=add + ROOT crs = f32[8]{0} all-reduce(input), replica_groups={}, to_apply=add } )" }, -// cross-replica-sum with subgroups +// all-reduce with subgroups { -"CrossReplicaSumWithSubgroups", +"AllReduceWithSubgroups", R"(HloModule CRS_Subgroups add { @@ -1146,16 +1220,16 @@ add { ROOT add = f32[] add(lhs, rhs) } -ENTRY 
CrossReplicaSumWithSubgroups { +ENTRY AllReduceWithSubgroups { input = f32[128,32]{0,1} parameter(0) - ROOT cross-replica-sum = f32[128,32]{0,1} cross-replica-sum(input), replica_groups={{0,1},{2,3}}, barrier="abc", to_apply=add + ROOT all-reduce = f32[128,32]{0,1} all-reduce(input), replica_groups={{0,1},{2,3}}, barrier="abc", to_apply=add } )" }, -// cross-replica-sum with all-reduce-id +// all-reduce with all-reduce-id { -"CrossReplicaSumAllReduce", +"AllReduceAllReduce", R"(HloModule CRS add { @@ -1166,8 +1240,8 @@ add { ENTRY CRS { input = f32[8]{0} parameter(0) - crs.1 = f32[8]{0} cross-replica-sum(input), replica_groups={{0}}, all_reduce_id=1, to_apply=add - ROOT crs.0 = f32[8]{0} cross-replica-sum(input), replica_groups={{0}}, all_reduce_id=1, to_apply=add + crs.1 = f32[8]{0} all-reduce(input), replica_groups={{0}}, all_reduce_id=1, to_apply=add + ROOT crs.0 = f32[8]{0} all-reduce(input), replica_groups={{0}}, all_reduce_id=1, to_apply=add } )" @@ -1266,12 +1340,36 @@ R"(HloModule AddDependency ENTRY AddDependency { p = f32[] parameter(0) neg = f32[] negate(p) - token = token[] after-all(neg) - p_after_token = f32[] add-dependency(p, token) + token0 = token[] after-all(neg) + p_after_token = f32[] add-dependency(p, token0) exp = f32[] exponential(p_after_token) ROOT sum = f32[] add(neg, exp) } +)" +}, + +// A module containing constants equal to the min/max values of various data +// types. 
+{ +"MinMaxValues", +R"(HloModule MinMaxValues + +ENTRY MinMaxValues { + x.s8 = s8[2]{0} constant({-128, 127}) + x.s16 = s16[2]{0} constant({-32768, 32767}) + x.s32 = s32[2]{0} constant({-2147483648, 2147483647}) + x.u8 = u8[2]{0} constant({0, 255}) + x.u16 = u16[2]{0} constant({0, 65535}) + x.u32 = u32[2]{0} constant({0, 4294967295}) + x.f16 = f16[2]{0} constant({-65504, 65504}) + x.bf16 = bf16[2]{0} constant({-3.38953e+38, 3.38953e+38}) + x.f32 = f32[2]{0} constant({-3.40282e+38, 3.40282e+38}) + x.f64 = f64[2]{0} constant({-1.79769e+308, 1.79769e+308}) + x.c64 = c64[2]{0} constant({(-3.40282e+38, 3.40282e+38), (3.40282e+38, -3.40282e+38)}) + ROOT c.c128 = c128[2]{0} constant({(-1.79769e+308, 1.79769e+308), (1.79769e+308, -1.79769e+308)}) +} + )" }, }); @@ -1329,20 +1427,20 @@ TEST_P(HloParserTestLongProto, Run) { ExpectEqual(); } TEST_P(HloParserTestShort, Run) { ExpectEqual(); } TEST_P(HloParserTestShortProto, Run) { ExpectEqual(); } -INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserTestLong, - ::testing::ValuesIn(CreateTestCases()), - TestDataToString); -INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, - HloParserTestLongProto, - ::testing::ValuesIn(CreateTestCases()), - TestDataToString); -INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserTestShort, - ::testing::ValuesIn(CreateShortTestCases()), - TestDataToString); -INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, - HloParserTestShortProto, - ::testing::ValuesIn(CreateShortTestCases()), - TestDataToString); +INSTANTIATE_TEST_SUITE_P(HloParserTestSuccessInstantiation, HloParserTestLong, + ::testing::ValuesIn(CreateTestCases()), + TestDataToString); +INSTANTIATE_TEST_SUITE_P(HloParserTestSuccessInstantiation, + HloParserTestLongProto, + ::testing::ValuesIn(CreateTestCases()), + TestDataToString); +INSTANTIATE_TEST_SUITE_P(HloParserTestSuccessInstantiation, HloParserTestShort, + ::testing::ValuesIn(CreateShortTestCases()), + TestDataToString); 
+INSTANTIATE_TEST_SUITE_P(HloParserTestSuccessInstantiation, + HloParserTestShortProto, + ::testing::ValuesIn(CreateShortTestCases()), + TestDataToString); class HloParserTest : public ::testing::Test { protected: @@ -1419,7 +1517,7 @@ TEST_F(HloParserTest, MoreConstants) { ENTRY %SelectScalarS32True.v4 () -> s32[] { %constant.2 = pred[] constant(true) - %constant.1 = s32[] constant(-42), sharding={s32[5,6] devices=[2,2]1,2,3,4} + %constant.1 = s32[] constant(-42), sharding={devices=[2,2]1,2,3,4} %constant = s32[] constant(42) %select = s32[] select(pred[] %constant.2, s32[] %constant.1, s32[] %constant) } @@ -1462,7 +1560,7 @@ TEST_F(HloParserTest, LiteralDimensionsMismatch_2) { const string original = R"(HloModule some_2x3_module ENTRY %some_2x3 () -> f32[2,3] { - ROOT %constant = f32[2,3]{1,0} constant(f32[2,3] {1, 2, 3, 4, 5, 6}) + ROOT %constant = f32[2,3]{1,0} constant({1, 2, 3, 4, 5, 6}) } )"; @@ -1476,7 +1574,7 @@ TEST_F(HloParserTest, LiteralDimensionsMismatch_3) { const string original = R"(HloModule some_2x3x2_module ENTRY %some_2x3x2 () -> f32[2,3,2] { - ROOT %constant = f32[2,3,2]{2,1,0} constant(f32[2,3,2] {{{1, 2}, {3, 4}, {5, 6}, {7, 8}, {9, 10}, {11, 12}}}) + ROOT %constant = f32[2,3,2]{2,1,0} constant({{{1, 2}, {3, 4}, {5, 6}, {7, 8}, {9, 10}, {11, 12}}}) } )"; @@ -1501,6 +1599,37 @@ ENTRY %ConstantF16Overflow.v4 () -> f16[] { "is out of range for literal's primitive type F16"); } +TEST_F(HloParserTest, ConstantBf16NoOverflow) { + // 65505 is in range for bf16. + const string original = R"( + HloModule test_module + ENTRY test { + ROOT c = bf16[] constant(-65505) + })"; + EXPECT_EQ(Status::OK(), ParseHloString(original).status()); +} + +TEST_F(HloParserTest, ConstantBf16Overflow) { + // 1e100 is out of range for bf16. 
+ const string original = R"( + HloModule test_module + ENTRY test { + ROOT c = bf16[] constant(1e100) + })"; + ExpectHasSubstr(ParseHloString(original).status().error_message(), + "out of range"); +} + +TEST_F(HloParserTest, ConstantF16OverflowInSparseArray) { + const string original = R"( + HloModule test_module + ENTRY test { + ROOT c = f16[5]sparse{10} constant({[0]: 0, [1]: -65505}) + })"; + ExpectHasSubstr(ParseHloString(original).status().error_message(), + "is out of range for literal's primitive type F16"); +} + TEST_F(HloParserTest, ConstantUnsignedUnderflow) { const string original = R"( HloModule ConstantUnsignedUnderflow_module @@ -1535,6 +1664,46 @@ TEST_F(HloParserTest, ConstantUnsignedInt64Overflow) { EXPECT_NE(Status::OK(), result.status()); } +TEST_F(HloParserTest, ConstantC64Overflow) { + const string original = R"( + HloModule test_module + ENTRY test () -> c64[] { + ROOT c = c64[] constant((1e100, 0)) + })"; + auto result = ParseHloString(original); + EXPECT_NE(Status::OK(), result.status()); +} + +TEST_F(HloParserTest, ConstantC64Underflow) { + const string original = R"( + HloModule test_module + ENTRY test () -> c64[] { + ROOT c = c64[] constant((0, -1e100)) + })"; + auto result = ParseHloString(original); + EXPECT_NE(Status::OK(), result.status()); +} + +TEST_F(HloParserTest, ConstantF64Overflow) { + const string original = R"( + HloModule test_module + ENTRY test { + ROOT c = f64[] constant(1.8e308) + })"; + auto result = ParseHloString(original); + EXPECT_NE(Status::OK(), result.status()); +} + +TEST_F(HloParserTest, ConstantF64Underflow) { + const string original = R"( + HloModule test_module + ENTRY test { + ROOT c = f64[] constant(-1.8e308) + })"; + auto result = ParseHloString(original); + EXPECT_NE(Status::OK(), result.status()); +} + TEST_F(HloParserTest, ConstantWithExp) { const string original = R"(HloModule ConstantWithExp_module @@ -1550,6 +1719,19 @@ ENTRY %ConstantWithExp.v4 () -> f32[] { // printed as "300". 
} +TEST_F(HloParserTest, ShortConstant) { + const string original = R"(HloModule ShortCOnstant_module + +ENTRY %ShortConstant.v4 () -> f32[67,89] { + ROOT %constant.1 = f32[67,89]{1,0} constant({...}) +} + +)"; + auto result = ParseHloString(original); + TF_EXPECT_OK(result.status()); + EXPECT_EQ(result.ValueOrDie()->ToString(HloPrintOptions()), original); +} + TEST_F(HloParserTest, AttibutesAnyOrder) { const string original = R"(HloModule any_order_module @@ -1594,11 +1776,11 @@ TEST_F(HloParserTest, UnexpectedAttribute) { const string original = R"(HloModule unexpected_attr_module ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { - %token = token[] after-all() - %recv = (f32[], u32[], token[]) recv(token[] %token), channel_id=15 + %token0 = token[] after-all() + %recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15 %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15 ROOT %constant = f32[] constant(2.1) - %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token), channel_id=16, calls=%recv + %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0), channel_id=16, calls=%recv %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16 } @@ -1611,11 +1793,11 @@ TEST_F(HloParserTest, MissingAttribute) { const string original = R"(HloModule missing_attr_module ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { - %token = token[] after-all() - %recv = (f32[], u32[], token[]) recv(token[] %token), channel_id=15 + %token0 = token[] after-all() + %recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15 %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15 ROOT %constant = f32[] constant(-2.1) - %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token) + %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0) %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16 } @@ -1628,11 +1810,11 @@ 
TEST_F(HloParserTest, PredecessorUndefined) { const string original = R"(HloModule pre_not_found_module ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { - %token = token[] after-all() - %recv = (f32[], u32[], token[]) recv(token[] %token), channel_id=15 + %token0 = token[] after-all() + %recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15 %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15 ROOT %constant = f32[] constant(2.1) - %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token), channel_id=16, control-predecessors={%done} + %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0), channel_id=16, control-predecessors={%done} %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16 } @@ -1940,8 +2122,8 @@ TEST_F(HloParserTest, ParsePaddingConfigInteriorPaddingImplicitZeroDim) { TEST_F(HloParserTest, NontupleInfeed) { const string original = R"(HloModule nontuple_infeed: ENTRY nontuple_infeed { - token = token[] after-all() - ROOT infeed = pred[] infeed(token) + token0 = token[] after-all() + ROOT infeed = pred[] infeed(token0) })"; ExpectHasSubstr(ParseHloString(original).status().error_message(), "infeed must have a non-empty tuple shape"); @@ -2239,7 +2421,7 @@ HloModule foobar ENTRY %entrycomp (p: f32[2,2]) -> f32[2,2] { %p = f32[2,2] parameter(0) - %constant.1 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + %constant.1 = f32[2,2] constant({{1, 2}, {3, 4}}) ROOT %add.1 = f32[2,2] add(f32[2,2] %p, f32[2,5] %constant.1) } )"; @@ -2249,7 +2431,153 @@ ENTRY %entrycomp (p: f32[2,2]) -> f32[2,2] { " with the shape of the operand instruction f32[2,2]{1,0}."); } -// custom call incompatible shape. 
+TEST_F(HloParserTest, OutOfRangeSparseIndex) { + const string original = R"( + HloModule test_module + ENTRY test { + ROOT c = f16[5]sparse{10} constant({[100]: 0}) + })"; + ExpectHasSubstr(ParseHloString(original).status().error_message(), + "Invalid sparse index"); +} + +TEST_F(HloParserTest, NegativeSparseIndex) { + const string original = R"( + HloModule test_module + ENTRY test { + ROOT c = f16[5]sparse{10} constant({-1: 0}) + })"; + ExpectHasSubstr(ParseHloString(original).status().error_message(), + "Invalid sparse index"); +} + +TEST_F(HloParserTest, SparseIndexWithRankTooLarge) { + const string original = R"( + HloModule test_module + ENTRY test { + ROOT c = f16[5]sparse{10} constant({[0, 0]: 0}) + })"; + ExpectHasSubstr(ParseHloString(original).status().error_message(), + "Invalid sparse index"); +} + +TEST_F(HloParserTest, SparseIndexWithRankTooSmall) { + const string original = R"( + HloModule test_module + ENTRY test { + ROOT c = f16[5, 5]sparse{10} constant({[0]: 0}) + })"; + ExpectHasSubstr(ParseHloString(original).status().error_message(), + "Invalid sparse index"); +} + +TEST_F(HloParserTest, ParseShapeStringR2F32) { + string shape_string = "f32[123,456]"; + TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string)); + Shape expected = ShapeUtil::MakeShape(F32, {123, 456}); + ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << ShapeUtil::HumanString(actual); +} + +TEST_F(HloParserTest, ParseShapeStringTupleOfArrays) { + string shape_string = "(f32[1572864],s8[5120,1024])"; + TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string)); + Shape expected = + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {1572864}), + ShapeUtil::MakeShape(S8, {5120, 1024})}); + ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << ShapeUtil::HumanString(actual); +} + +TEST_F(HloParserTest, ParseShapeStringNestedTuple) 
{ + string shape_string = "(f32[1],(f32[2], token[]), opaque[], f32[3])"; + TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string)); + Shape expected = ShapeUtil::MakeTupleShape({ + ShapeUtil::MakeShape(F32, {1}), + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {2}), ShapeUtil::MakeTokenShape()}), + ShapeUtil::MakeOpaqueShape(), + ShapeUtil::MakeShape(F32, {3}), + }); + ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << ShapeUtil::HumanString(actual); +} + +TEST_F(HloParserTest, ParseShapeStringWithLayout) { + string shape_string = "f32[123,456]{0,1}"; + TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string)); + Shape expected = ShapeUtil::MakeShapeWithLayout(F32, {123, 456}, {0, 1}); + ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << ShapeUtil::HumanString(actual); +} + +TEST_F(HloParserTest, ParseShapeStringWithSparseLayout) { + string shape_string = "f32[123,456]sparse{10}"; + TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string)); + Shape expected = ShapeUtil::MakeShapeWithSparseLayout(F32, {123, 456}, 10); + ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << ShapeUtil::HumanString(actual); +} + +TEST_F(HloParserTest, ParseOpaqueType) { + TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape("opaque[]")); + Shape expected = ShapeUtil::MakeOpaqueShape(); + ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << ShapeUtil::HumanString(actual); +} + +TEST_F(HloParserTest, ParseTokenType) { + TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape("token[]")); + Shape expected = ShapeUtil::MakeTokenShape(); + ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << ShapeUtil::HumanString(actual); +} + 
+TEST_F(HloParserTest, ParseInvalidShapeString) { + string shape_strings[] = { + "f32[123,456]foobar{0,1}", "f32[123,456]sparse{0,1}", "f32[123,456]{foo}", + "f32[123,456]dense{foo}", "f32[123,456]sparse{foo}", + }; + for (const string& shape_string : shape_strings) { + StatusOr result = ParseShape(shape_string); + ASSERT_FALSE(result.ok()) << "shape: " << shape_string; + } +} + +TEST_F(HloParserTest, ParseDynamicArray) { + string shape_string = "f32[123,<=456]"; + TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string)); + Shape expected = ShapeUtil::MakeShape(F32, {123, 456}, {false, true}); + ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << ShapeUtil::HumanString(actual); +} + +TEST_F(HloParserTest, ParseDynamicTuple) { + string shape_string = "(f32[42], u32[<=123,<=456])"; + TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string)); + Shape expected = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {42}), + ShapeUtil::MakeShape(U32, {123, 456}, {true, true})}); + ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << ShapeUtil::HumanString(actual); +} + +TEST_F(HloParserTest, NegativeParameterNumber) { + const string hlo_string = "par0 = f32[3,5] parameter(-1)"; + auto result = ParseHloString(hlo_string); + ASSERT_FALSE(result.status().ok()); + EXPECT_THAT(result.status().error_message(), + ::testing::HasSubstr("parameter number must be >= 0")); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_pass_fix.h b/tensorflow/compiler/xla/service/hlo_pass_fix.h index 791b1a97b0b82edf19ff1588fd8d5d996ac0fef4..35dc9c0029f9871334cb500c6b71f0c86ab136d7 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_fix.h +++ b/tensorflow/compiler/xla/service/hlo_pass_fix.h @@ -19,6 +19,7 @@ limitations under the License. 
#include #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_module_group.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -39,9 +40,36 @@ class HloPassFix : public Pass { int64 iteration_count = 0; int64 limit = std::max(static_cast(1000), module->instruction_count()); + VLOG(3) << "Running HloPassFix."; while (changed_this_iteration) { TF_ASSIGN_OR_RETURN(changed_this_iteration, Pass::Run(module)); changed |= changed_this_iteration; + VLOG(3) << "changed_this_iteration: " << changed_this_iteration; + ++iteration_count; + if (iteration_count == limit) { + LOG(ERROR) + << "Unexpectedly high number of iterations in HLO passes (" + << iteration_count + << ")\nIf compilation hangs here, please file a bug with XLA."; + } + } + return changed; + } + + StatusOr RunOnModuleGroup(HloModuleGroup* module_group) override { + bool changed = false; + bool changed_this_iteration = true; + int64 iteration_count = 0; + int64 limit = 1000; + for (const HloModule* module : module_group->modules()) { + limit = std::max(limit, module->instruction_count()); + } + VLOG(3) << "Running HloPassFix."; + while (changed_this_iteration) { + TF_ASSIGN_OR_RETURN(changed_this_iteration, + Pass::RunOnModuleGroup(module_group)); + changed |= changed_this_iteration; + VLOG(3) << "changed_this_iteration: " << changed_this_iteration; ++iteration_count; if (iteration_count == limit) { LOG(ERROR) diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc index 312b5d020c398feb7738d14a9cfa0928d5178948..ae8c08cf1d16ad6738962f3be7c1b5512110b1d1 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc @@ -77,6 +77,11 @@ std::vector HloPassPipeline::GetEnabledPasses( auto repeated_field = debug_options.xla_disable_hlo_passes(); 
absl::flat_hash_set disabled_pass_names(repeated_field.begin(), repeated_field.end()); + if (debug_options.xla_disable_all_hlo_passes()) { + VLOG(1) << "*All* passes disabled by --xla_disable_all_hlo_passes."; + return {}; + } + if (!disabled_pass_names.empty()) { VLOG(1) << "Passes disabled by --xla_disable_hlo_passes: " << absl::StrJoin(disabled_pass_names, ", "); @@ -84,7 +89,7 @@ std::vector HloPassPipeline::GetEnabledPasses( std::vector enabled_passes; for (auto& pass : passes_) { - if (disabled_pass_names.count(string(pass->name())) == 0) { + if (!disabled_pass_names.contains(pass->name())) { enabled_passes.push_back(pass.get()); } } @@ -113,7 +118,7 @@ void HloPassPipeline::MaybeDumpHlo(const HloModule& module, } const string message = - StrCat("after ", after_pass_name, ", before ", before_pass_name); + absl::StrCat("after ", after_pass_name, ", before ", before_pass_name); hlo_graph_dumper::MaybeDumpHloModule(module, message); VLOG(3) << "HLO " << message << ":"; VLOG(3) << module.entry_computation_layout().ToString(); diff --git a/tensorflow/compiler/xla/service/hlo_profile_printer.cc b/tensorflow/compiler/xla/service/hlo_profile_printer.cc index 5eb707a957e49d86cdb2f72b72ce750bf29b8fd2..9cc202aa9f5fe5a20a9da05251ea811137ccaadb 100644 --- a/tensorflow/compiler/xla/service/hlo_profile_printer.cc +++ b/tensorflow/compiler/xla/service/hlo_profile_printer.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_profile_printer.h" +#include "absl/algorithm/container.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/human_readable_profile_builder.h" @@ -34,11 +35,10 @@ string PrintHloProfile(const HloProfilePrinterData& hlo_profile_printer_data, for (const HloComputationInfo& computation_info : hlo_profile_printer_data.computation_infos()) { const auto& instruction_infos = computation_info.instruction_infos(); - bool any_instruction_profiled = - std::any_of(instruction_infos.begin(), instruction_infos.end(), - [&](const HloInstructionInfo& instruction_info) { - return counters[instruction_info.profile_index()] != 0; - }); + bool any_instruction_profiled = absl::c_any_of( + instruction_infos, [&](const HloInstructionInfo& instruction_info) { + return counters[instruction_info.profile_index()] != 0; + }); if (!any_instruction_profiled) { continue; diff --git a/tensorflow/compiler/xla/service/hlo_proto_util.cc b/tensorflow/compiler/xla/service/hlo_proto_util.cc index 981d06ce101644ecce587c4bd2f7a12c8edf6548..3a9ee57e5551ae5b608f02d9f8bd0428ff16db13 100644 --- a/tensorflow/compiler/xla/service/hlo_proto_util.cc +++ b/tensorflow/compiler/xla/service/hlo_proto_util.cc @@ -39,6 +39,7 @@ HloProto MakeHloProto(const HloModule& module) { StatusOr> CreateModuleFromProto( const HloModuleProto& proto, const HloModuleConfig& module_config) { + VLOG(4) << proto.ShortDebugString(); TF_ASSIGN_OR_RETURN(std::unique_ptr module, HloModule::CreateFromProto(proto, module_config)); TF_RETURN_IF_ERROR( diff --git a/tensorflow/compiler/xla/service/hlo_reachability.cc b/tensorflow/compiler/xla/service/hlo_reachability.cc index 4aa8067752481ffab29e1a573ffa49d4aa046f1f..0fced7f15bdaf1dbe349e3b0fc6ada68393c6512 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability.cc +++ b/tensorflow/compiler/xla/service/hlo_reachability.cc @@ -49,7 +49,7 @@ void HloReachabilityMap::SetReachabilityToUnionHelper( absl::Span inputs, const 
HloInstruction* instruction, BitVector* bit_vector) { // If instruction is part of inputs, don't reset the bit_vector. - if (std::find(inputs.begin(), inputs.end(), instruction) == inputs.end()) { + if (!absl::c_linear_search(inputs, instruction)) { bit_vector->SetToZero(); } bit_vector->Set(GetIndex(instruction)); @@ -93,7 +93,7 @@ std::unique_ptr HloReachabilityMap::Build( } break; } - case HloOpcode::kCrossReplicaSum: { + case HloOpcode::kAllReduce: { auto all_reduce_id = hlo->all_reduce_id(); if (all_reduce_id) { auto it = channel_dependency_map.find(all_reduce_id.value()); diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index 48add75523f02005c70bc6baf69a6b7d5aa4f7ef..a175e4643de2ac6ce07ac00da914d7ab7acca541 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -57,13 +57,22 @@ using ::tensorflow::strings::HumanReadableNumBytes; // Returns true if the given instruction is rematerializable. bool IsRematerializable(const HloInstruction* instruction) { + if (instruction->opcode() == HloOpcode::kCopy) { + if (LayoutUtil::Equal(instruction->shape().layout(), + instruction->operand(0)->shape().layout())) { + // Don't rematerialize copies added by copy insertion (layout doesn't + // change). + return false; + } + } + // Don't rematerialize instructions with side effects or instructions which // cannot be cloned safely. 
switch (instruction->opcode()) { case HloOpcode::kCall: case HloOpcode::kConstant: case HloOpcode::kConditional: - case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllReduce: case HloOpcode::kCustomCall: case HloOpcode::kParameter: case HloOpcode::kWhile: @@ -179,7 +188,8 @@ class InstructionList { Item* CreateItem(HloInstruction* inst) { Item* item = new Item; item->instruction = inst; - CHECK(item_map_.insert({inst, item}).second) << "inserting inst twice"; + CHECK(item_map_.insert({inst, item}).second) + << "inserting inst twice " << inst->name(); return item; } @@ -235,8 +245,7 @@ class InstructionList { } // Now scan forwards until we find one of the before_instructions. - while (std::find(before_instructions.begin(), before_instructions.end(), - min_position_item) == before_instructions.end()) { + while (!absl::c_linear_search(before_instructions, min_position_item)) { min_position_item = min_position_item->next; } return InsertBefore(to_insert, min_position_item); @@ -302,7 +311,7 @@ ItemList GetUsers(const InstructionList& instruction_list, // A buffer may be used by the instruction via more than one alias. For // example, a buffer which appears in more than one element of a tuple. Item* user_item = instruction_list.GetItem(user); - if (std::find(users.begin(), users.end(), user_item) == users.end()) { + if (!absl::c_linear_search(users, user_item)) { users.push_back(user_item); } } @@ -418,11 +427,12 @@ class MemoryUsageTracker { // the given uses. 
Buffer& RematerializeBuffer(const Buffer& original_buffer, Item* remat_item, ItemList&& rematerialized_uses) { - CHECK(original_buffer.defining_instruction->placed); - CHECK(!original_buffer.has_indirect_uses); - CHECK(!original_buffer.live_out); + CHECK(original_buffer.defining_instruction->placed) + << original_buffer.defining_instruction->instruction->name(); + CHECK(!original_buffer.has_indirect_uses) << original_buffer.ToString(); + CHECK(!original_buffer.live_out) << original_buffer.ToString(); for (Item* use : rematerialized_uses) { - CHECK(!use->placed); + CHECK(!use->placed) << use->instruction->name(); } return NewBuffer(remat_item, original_buffer.size, std::move(rematerialized_uses), /*live_out=*/false, @@ -456,8 +466,7 @@ class MemoryUsageTracker { return false; } const BufferIdList& in_progress_uses = in_progress_item_->buffers_used; - return std::find(in_progress_uses.begin(), in_progress_uses.end(), - buffer_id) != in_progress_uses.end(); + return absl::c_linear_search(in_progress_uses, buffer_id); } // Returns whether the given instruction is live at the current program @@ -535,8 +544,7 @@ MemoryUsageTracker::MemoryUsageTracker( bool unused; for (Item* user_item : GetUsers(instruction_list_, logical_buffer, points_to_analysis, &unused)) { - if (std::find(buffer->users.begin(), buffer->users.end(), - user_item) == buffer->users.end()) { + if (!absl::c_linear_search(buffer->users, user_item)) { buffer->users.push_back(user_item); buffer->unfinished_user_count++; user_item->buffers_used.push_back(buffer->id); @@ -677,8 +685,8 @@ Status MemoryUsageTracker::AddRematerializedInstruction(Item* original_item, << ", remat_instruction = " << remat_item->instruction->name(); TF_RET_CHECK(in_progress_item_ != nullptr); - TF_RET_CHECK(original_item->placed); - TF_RET_CHECK(!remat_item->placed); + TF_RET_CHECK(original_item->placed) << original_item->instruction->name(); + TF_RET_CHECK(!remat_item->placed) << remat_item->instruction->name(); // Construct the 
list of buffers used and defined by the rematerialization. remat_item->buffers_used = original_item->buffers_used; @@ -707,7 +715,7 @@ Status MemoryUsageTracker::AddRematerializedInstruction(Item* original_item, ItemList unplaced_users; for (Item* user : old_buffer.users) { if (user->placed) { - CHECK(IsFinished(user)); + CHECK(IsFinished(user)) << user->instruction->name(); placed_users.push_back(user); } else { unplaced_users.push_back(user); @@ -784,8 +792,7 @@ bool MemoryUsageTracker::Check() const { for (const Buffer& buffer : buffers_) { if (buffer.defining_instruction->instruction == instruction) { - CHECK(std::find(defined_buffers.begin(), defined_buffers.end(), - buffer.id) != defined_buffers.end()) + CHECK(absl::c_linear_search(defined_buffers, buffer.id)) << "Instruction " << instruction->name() << " defined buffers is missing: " << buffer.ToString(); } @@ -808,8 +815,7 @@ bool MemoryUsageTracker::Check() const { int64 unfinished_uses = 0; for (Item* user : buffer.users) { const BufferIdList& used_buffers = user->buffers_used; - CHECK(std::find(used_buffers.begin(), used_buffers.end(), buffer.id) != - used_buffers.end()) + CHECK(absl::c_linear_search(used_buffers, buffer.id)) << "Instruction " << user->instruction->name() << " used buffers is missing " << buffer.ToString(); if (!IsFinished(user)) { @@ -836,10 +842,10 @@ int64 RematerializationCost(const HloInstruction* instruction, // If none of the users of 'instruction' have been placed in the sequence (as // tracked by memory_tracker), then rematerialization of 'instruction' is a // zero-cost move of 'instruction' in the sequence. 
- if (!std::any_of(instruction->users().begin(), instruction->users().end(), - [&memory_tracker](const HloInstruction* inst) { - return memory_tracker.IsPlaced(inst); - })) { + if (!absl::c_any_of(instruction->users(), + [&memory_tracker](const HloInstruction* inst) { + return memory_tracker.IsPlaced(inst); + })) { return 0; } @@ -1094,7 +1100,7 @@ StatusOr HloRematerialization::RematerializeComputation( Item* successor_item = instruction_list.GetItem(successor); // Assert to make sure we never remat an operation with control // successor already placed. - CHECK(!successor_item->placed); + CHECK(!successor_item->placed) << successor_item->instruction->name(); place_before.push_back(successor_item); } instruction_list.InsertBeforeInstructions(remat_item, place_before); @@ -1164,7 +1170,7 @@ StatusOr HloRematerialization::RematerializeComputation( // Verify some invariants on the memory tracker. CHECK_EQ(memory_tracker.memory_usage(), 0); for (auto* instruction : computation->instructions()) { - CHECK(memory_tracker.IsPlaced(instruction)); + CHECK(memory_tracker.IsPlaced(instruction)) << instruction->name(); } VLOG(1) << "In computation " << computation->name() << " rematerialized " diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc index 22c3c40a93a1ddcd36659483fcc79fede32dd2c3..102a360ad8116d8781baf9cb7627a920f4a687c4 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc @@ -499,6 +499,52 @@ TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) { EXPECT_THAT(add_4->operand(0), op::Broadcast(param)); } +TEST_F(HloRematerializationTest, CopyNotRematerialized) { + // Test that copies are not rematerialized. 
+ auto module = CreateNewVerifiedModule(); + + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, vec1024_shape_, "param")); + + auto copy = builder.AddInstruction( + HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kCopy, param)); + + auto negate_a_1 = builder.AddInstruction( + HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kNegate, copy)); + + auto negate_a_2 = builder.AddInstruction(HloInstruction::CreateUnary( + vec1024_shape_, HloOpcode::kNegate, negate_a_1)); + + auto negate_b_1 = builder.AddInstruction( + HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kNegate, copy)); + + auto negate_b_2 = builder.AddInstruction(HloInstruction::CreateUnary( + vec1024_shape_, HloOpcode::kNegate, negate_b_1)); + + builder.AddInstruction(HloInstruction::CreateTuple({negate_a_2, negate_b_2})); + + HloComputation* entry_computation = + module->AddEntryComputation(builder.Build()); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/1 * 1024, module.get())); + + auto count_copies = [](const HloComputation* computation) { + int64 copy_count = 0; + for (auto* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kCopy) { + copy_count++; + } + } + return copy_count; + }; + EXPECT_TRUE(changed); + + EXPECT_EQ(count_copies(entry_computation), 1); +} + class IndirectUseTest : public HloRematerializationTest, public ::testing::WithParamInterface {}; @@ -588,8 +634,8 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) { } } -INSTANTIATE_TEST_CASE_P(IndirectUseTestInstantiation, IndirectUseTest, - ::testing::Values(true, false)); +INSTANTIATE_TEST_SUITE_P(IndirectUseTestInstantiation, IndirectUseTest, + ::testing::Values(true, false)); } // namespace diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc index 
5a9b820a9d7f58695383b21c9e2126cf98970c83..d7d66ae1c4592723ca991d5ee971fa72cc1af90a 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.cc +++ b/tensorflow/compiler/xla/service/hlo_runner.cc @@ -383,9 +383,7 @@ ServiceExecutableRunOptions HloRunner::GetServiceRunOptionsForDevice( if (device_assignment != nullptr) { run_options.set_device_assignment(device_assignment); } - return ServiceExecutableRunOptions( - run_options, backend().StreamBorrower(), - /*xla_intra_op_thread_pool=*/backend().eigen_intra_op_thread_pool()); + return ServiceExecutableRunOptions(run_options, backend().StreamBorrower()); } Backend& HloRunner::backend() { diff --git a/tensorflow/compiler/xla/service/hlo_schedule.cc b/tensorflow/compiler/xla/service/hlo_schedule.cc index 8f6eb974c5179b420c8f961393ca923e0a3b3530..e75373501cffac6a736be89e9f6139b6ff2cdbc1 100644 --- a/tensorflow/compiler/xla/service/hlo_schedule.cc +++ b/tensorflow/compiler/xla/service/hlo_schedule.cc @@ -140,7 +140,7 @@ Status HloSchedule::UpdateComputationSchedule( std::queue worklist; for (HloInstruction* instruction : computation->instructions()) { - if (ids_in_schedule.count(instruction->unique_id()) == 0) { + if (!ids_in_schedule.contains(instruction->unique_id())) { // This is a newly added instruction which is not in the schedule. 
if (instruction->operands().empty()) { worklist.push(instruction); @@ -204,7 +204,7 @@ Status HloSchedule::Update() { std::vector nonfusion_computations = module_->MakeNonfusionComputations(); for (const HloComputation* computation : nonfusion_computations) { - TF_RET_CHECK(sequences_.count(computation->unique_id()) == 1) + TF_RET_CHECK(sequences_.contains(computation->unique_id())) << "Computation " << computation->name() << " not in HloSchedule."; } if (sequences_.size() > nonfusion_computations.size()) { @@ -215,7 +215,7 @@ Status HloSchedule::Update() { nonfusion_computations_ids.insert(computation->unique_id()); } for (auto it = sequences_.begin(); it != sequences_.end();) { - if (nonfusion_computations_ids.count(it->first) == 0) { + if (!nonfusion_computations_ids.contains(it->first)) { sequences_.erase(it++); } else { ++it; @@ -244,7 +244,7 @@ Status HloSchedule::Verify() const { << "Schedule has " << sequences_.size() << " sequences, but module has " << nonfusion_computations.size() << " non-fusion computations"; for (const HloComputation* computation : nonfusion_computations) { - TF_RET_CHECK(sequences_.count(computation->unique_id()) == 1) + TF_RET_CHECK(sequences_.contains(computation->unique_id())) << "Computation " << computation->name() << " missing from HLO schedule."; } @@ -268,7 +268,7 @@ Status HloSchedule::Verify() const { << instruction_position.size() << " instructions, expected " << computation->instruction_count(); for (const HloInstruction* instruction : computation->instructions()) { - TF_RET_CHECK(instruction_position.count(instruction) == 1) + TF_RET_CHECK(instruction_position.contains(instruction)) << "Instruction " << instruction->name() << " is not in schedule"; } diff --git a/tensorflow/compiler/xla/service/hlo_schedule.h b/tensorflow/compiler/xla/service/hlo_schedule.h index 486ddbf499de80c634bc497158cd79ca066cc866..a5f54ae2c33259d080631061dff9ae40b41495dc 100644 --- a/tensorflow/compiler/xla/service/hlo_schedule.h +++ 
b/tensorflow/compiler/xla/service/hlo_schedule.h @@ -110,7 +110,7 @@ class HloSchedule { // Returns true if the schedule has a sequence for the given computation. bool is_computation_scheduled(const HloComputation* computation) const { - return sequences_.count(computation->unique_id()) == 1; + return sequences_.contains(computation->unique_id()); } // Updates the schedule such that it is (again) a valid schedule for the diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc index 70a860c356ca2fb1c4c973ea3d96c50fabc2c7c2..37cc146bd7a6f2aef9373bd4afd8572ffac6473c 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_sharding.h" +#include "absl/container/flat_hash_set.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/overflow_util.h" @@ -30,7 +31,7 @@ HloSharding HloSharding::AssignDevice(int64 device_id) { } HloSharding HloSharding::Tile1D(const Shape& input_shape, int64 num_tiles) { - CHECK_EQ(1, ShapeUtil::Rank(input_shape)); + CHECK_EQ(1, input_shape.rank()); CHECK_GT(num_tiles, 1); std::vector dimensions(1, num_tiles); Array assignment(dimensions); @@ -57,7 +58,7 @@ HloSharding HloSharding::Tuple(const ShapeTree& sub_shardings) { HloSharding HloSharding::Tuple(const Shape& tuple_shape, absl::Span shardings) { - CHECK(ShapeUtil::IsTuple(tuple_shape)) << ShapeUtil::HumanString(tuple_shape); + CHECK(tuple_shape.IsTuple()) << ShapeUtil::HumanString(tuple_shape); for (auto& sharding : shardings) { CHECK(!sharding.IsTuple()) << sharding.ToString(); } @@ -70,7 +71,7 @@ HloSharding HloSharding::Tuple(const Shape& tuple_shape, HloSharding HloSharding::SingleTuple(const Shape& tuple_shape, const HloSharding& sharding) { - CHECK(ShapeUtil::IsTuple(tuple_shape)) << ShapeUtil::HumanString(tuple_shape); + 
CHECK(tuple_shape.IsTuple()) << ShapeUtil::HumanString(tuple_shape); CHECK(!sharding.IsTuple()) << sharding.ToString(); int64 leaf_count = RequiredLeaves(tuple_shape); std::vector flattened_list; @@ -80,7 +81,7 @@ HloSharding HloSharding::SingleTuple(const Shape& tuple_shape, HloSharding HloSharding::Single(const Shape& shape, const HloSharding& sharding) { - return ShapeUtil::IsTuple(shape) ? SingleTuple(shape, sharding) : sharding; + return shape.IsTuple() ? SingleTuple(shape, sharding) : sharding; } string HloSharding::ToString() const { @@ -106,13 +107,12 @@ string HloSharding::ToString() const { bool HloSharding::UsesDevice(int64 device) const { if (IsTuple()) { - return std::any_of( - tuple_elements_.begin(), tuple_elements_.end(), - [&](const HloSharding& s) { return s.UsesDevice(device); }); + return absl::c_any_of(tuple_elements_, [&](const HloSharding& s) { + return s.UsesDevice(device); + }); } const auto& devices = tile_assignment_; - return replicated_ || - std::find(devices.begin(), devices.end(), device) != devices.end(); + return replicated_ || absl::c_linear_search(devices, device); } std::map HloSharding::UsedDevices(int64* count) const { @@ -269,7 +269,7 @@ int64 HloSharding::GetUniqueDevice() const { } Status HloSharding::ValidateTuple(const Shape& shape, int64 num_devices) const { - if (!ShapeUtil::IsTuple(shape)) { + if (!shape.IsTuple()) { return tensorflow::errors::InvalidArgument( StrCat("Sharding is tuple-shaped but validation shape is not.")); } @@ -305,7 +305,7 @@ Status HloSharding::Validate(const Shape& shape, int64 num_devices) const { Status HloSharding::ValidateNonTuple(const Shape& shape, int64 num_devices) const { - if (ShapeUtil::IsTuple(shape)) { + if (shape.IsTuple()) { return tensorflow::errors::InvalidArgument( StrCat("Validation shape is a tuple but sharding is not.")); } @@ -316,7 +316,7 @@ Status HloSharding::ValidateNonTuple(const Shape& shape, // All tile assignments must be less than the number of available cores and // 
unique. Status status = Status::OK(); - std::set seen_cores; + absl::flat_hash_set seen_cores; tile_assignment_.Each( [&](absl::Span indices, int32 core) { // Don't overwrite a bad status, so we report the first error. @@ -324,7 +324,7 @@ Status HloSharding::ValidateNonTuple(const Shape& shape, if (core >= num_devices) { status = tensorflow::errors::InvalidArgument(StrCat( "core ", core, " > ", num_devices, " in tile assignment")); - } else if (seen_cores.count(core) != 0) { + } else if (seen_cores.contains(core)) { status = tensorflow::errors::InvalidArgument( StrCat("core ", core, " is not unique in tile assignment")); } @@ -340,7 +340,7 @@ Status HloSharding::ValidateNonTuple(const Shape& shape, } // The tile assignment tensor must have the same rank as the input. - if (ShapeUtil::Rank(shape) != tile_assignment_.num_dimensions()) { + if (shape.rank() != tile_assignment_.num_dimensions()) { return tensorflow::errors::InvalidArgument( "Number of tile assignment dimensions is different to the input rank. 
" "sharding=", @@ -437,8 +437,8 @@ Shape HloSharding::TileShape(const Shape& shape) const { } Shape result_shape = shape; for (int64 i = 0; i < shape.dimensions_size(); ++i) { - (*result_shape.mutable_dimensions())[i] = - CeilOfRatio(shape.dimensions(i), tile_assignment_.dim(i)); + result_shape.set_dimensions( + i, CeilOfRatio(shape.dimensions(i), tile_assignment_.dim(i))); } return result_shape; } @@ -455,7 +455,7 @@ HloSharding HloSharding::GetSubSharding(const Shape& shape, } sub_shape = &ShapeUtil::GetSubshape(*sub_shape, {idx}); } - if (ShapeUtil::IsTuple(*sub_shape)) { + if (sub_shape->IsTuple()) { auto begin_it = tuple_elements_.begin() + sharding_index; std::vector sub_shardings( begin_it, begin_it + ShapeUtil::GetLeafCount(*sub_shape)); diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h index 9775505f8608ced3e33abe376f4922cc6a972726..5789ae09988d2a85247c5b8c037a172b3699f3b7 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.h +++ b/tensorflow/compiler/xla/service/hlo_sharding.h @@ -101,8 +101,8 @@ class HloSharding { if (!IsTuple()) { return replicated_; } - return std::all_of(tuple_elements_.begin(), tuple_elements_.end(), - [](const HloSharding& s) { return s.IsReplicated(); }); + return absl::c_all_of( + tuple_elements_, [](const HloSharding& s) { return s.IsReplicated(); }); } // Returns true if the tile size is the same as the input size. @@ -110,8 +110,9 @@ class HloSharding { if (!IsTuple()) { return maximal_; } - return std::all_of(tuple_elements_.begin(), tuple_elements_.end(), - [](const HloSharding& s) { return s.IsTileMaximal(); }); + return absl::c_all_of(tuple_elements_, [](const HloSharding& s) { + return s.IsTileMaximal(); + }); } // Returns true if the sharding defines an operation on the given device. 
diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc index f5061304456e04ab40448861343ef201c9450dcf..094d98bc6e54028557f6d38cd165bf34e1fb8c46 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc @@ -99,7 +99,7 @@ std::vector LocatePassThroughDomainLinks( << "Instruction is not a kDomain: " << instruction->ToString(); for (HloInstruction* user : instruction->users()) { if (user->opcode() == HloOpcode::kDomain && - domain.exit_domains.count(user) != 0) { + domain.exit_domains.contains(user)) { pass_through.emplace_back(user, instruction); VLOG(2) << "Found passthrough domain link:"; VLOG(2) << " " << user->ToString(); @@ -234,7 +234,7 @@ StatusOr ApplyShardingFromUsers(HloInstruction* instruction, if (instruction->users().empty()) { // No sharding from users, use domain_sharding, after checking // compatibility. - TF_RET_CHECK(ShapeUtil::IsTuple(instruction->shape()) && + TF_RET_CHECK(instruction->shape().IsTuple() && ShapeUtil::GetLeafCount(instruction->shape()) == domain_sharding.tuple_elements().size()); instruction->set_sharding(domain_sharding); @@ -253,7 +253,7 @@ StatusOr ApplyShardingFromUsers(HloInstruction* instruction, instruction->shape(), HloSharding::AssignDevice(kUnassignedDevice)); for (HloInstruction* user : instruction->users()) { if (user->opcode() == HloOpcode::kDomain && - domain.exit_domains.count(user) > 0) { + domain.exit_domains.contains(user)) { // If a user is a domain and it is registered in the domain exits, then // the instruction sharding is taken directly from the domain, and no // further users need to be visited. 
@@ -266,7 +266,7 @@ StatusOr ApplyShardingFromUsers(HloInstruction* instruction, AssignmentKind sub_assigned = AssignmentKind::kUnassigned; TF_ASSIGN_OR_RETURN(ShapeTree user_sharding_tree, GetShardingTreeFromUser(*instruction, *user)); - if (ShapeUtil::IsTuple(instruction->shape())) { + if (instruction->shape().IsTuple()) { // For tuple-shaped instructions collect individual tuple subshardings // from the uses, and then combine them into the tuple sharding. // If the user is a GTE its sharding concerns only the subtree of @@ -298,7 +298,7 @@ StatusOr ApplyShardingFromUsers(HloInstruction* instruction, } if (assigned == AssignmentKind::kAssigned) { - if (ShapeUtil::IsTuple(instruction->shape())) { + if (instruction->shape().IsTuple()) { instruction->set_sharding(HloSharding::Tuple(sharding_tree)); } else { TF_RET_CHECK(sharding_tree.leaf_count() == 1); @@ -361,7 +361,7 @@ Status ApplyDomainSharding(const DomainMetadata::Domain& domain, // kUnassignedDevice. Indeed in case of doubt it is better to leave the // entire tuple unassigned, and let the device placer decide for it. 
if (instruction->sharding().UsesDevice(kUnassignedDevice)) { - TF_RET_CHECK(ShapeUtil::IsTuple(instruction->shape())) + TF_RET_CHECK(instruction->shape().IsTuple()) << "Only tuples can have kUnassignedDevice sub shardings"; instruction->clear_sharding(); } diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc index 487653344976a10e18ba667085525ba1ecbb8612..c1f69db74eafb7743e85f499f2f4828ed0375501 100644 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc +++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc @@ -61,8 +61,7 @@ void CleanNodeName(string* name) { name->erase(std::remove(name->begin(), name->end(), '%'), name->end()); const string chars_to_replace = "<>[]"; auto pred = [&](char c) { - return std::find(chars_to_replace.begin(), chars_to_replace.end(), c) != - chars_to_replace.end(); + return absl::c_linear_search(chars_to_replace, c); }; std::replace_if(name->begin(), name->end(), pred, '_'); } @@ -159,7 +158,7 @@ void HloTfGraphBuilder::SetNodeAttrs(const HloInstruction* instruction, // Set the layout. if (LayoutUtil::HasLayout(instruction->shape())) { string layout_string; - if (ShapeUtil::IsTuple(instruction->shape())) { + if (instruction->shape().IsTuple()) { // For tuples, emit the full shape because the layout of a tuple is not // represented in a single Layout field. layout_string = ShapeUtil::HumanStringWithLayout(instruction->shape()); diff --git a/tensorflow/compiler/xla/service/hlo_token.h b/tensorflow/compiler/xla/service/hlo_token.h deleted file mode 100644 index 4458c251dee4af365e39027dd4289925c8890efd..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/hlo_token.h +++ /dev/null @@ -1,78 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TOKEN_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TOKEN_H_ - -#include - -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/platform/types.h" - -namespace xla { - -// Defines different kinds of tokens in a hlo module string. -// -// You shouldn't need to use this directly unless you're using HloLexer -// directly, and you probably don't need to do that. Use hlo_parser instead. -enum class TokKind { - // Markers - kEof, - kError, - - // Tokens with no info. - kEqual, // = - kComma, // , - kColon, // : - kLsquare, - kRsquare, // [ ] - kLbrace, - kRbrace, // { } - kLparen, - kRparen, // ( ) - - kArrow, // -> - - // Keywords - kw_HloModule, - kw_ENTRY, - kw_ROOT, - kw_true, - kw_false, - kw_maximal, - kw_replicated, - kw_nan, - kw_inf, - - kNegInf, // -inf - - // Typed tokens. 
- kName, // %foo - kAttributeName, // dimensions= - kDimLabels, // [0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,} - kDxD, // [0-9]+(x[0-9]+)+ - kPad, // [0-9]+_[0-9]+(_[0-9]+)?(x[0-9]+_[0-9]+(_[0-9]+)?)* - kIdent, // other identifiers - kString, // "abcd\"\n" - kShape, // f32[2,3]{1,0} - kInt, // 42 - kDecimal, // 4.2 -}; - -string TokKindToString(TokKind kind); - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TOKEN_H_ diff --git a/tensorflow/compiler/xla/service/hlo_value.cc b/tensorflow/compiler/xla/service/hlo_value.cc index 59594ab2f0f70a206c73e998dbfa69c2c5c7ba43..218b33b2ac2b86edc30b2f014ba206c71da37682 100644 --- a/tensorflow/compiler/xla/service/hlo_value.cc +++ b/tensorflow/compiler/xla/service/hlo_value.cc @@ -46,7 +46,7 @@ const Shape& HloPosition::shape() const { string HloPosition::ToString() const { string index_str = - ShapeUtil::IsTuple(instruction->shape()) ? (" " + index.ToString()) : ""; + instruction->shape().IsTuple() ? (" " + index.ToString()) : ""; return StrCat(instruction->name(), index_str); } @@ -56,10 +56,9 @@ std::ostream& operator<<(std::ostream& out, const HloPosition& position) { } string HloUse::ToString() const { - string index_str = - ShapeUtil::IsTuple(instruction->operand(operand_number)->shape()) - ? (" " + operand_index.ToString()) - : ""; + string index_str = instruction->operand(operand_number)->shape().IsTuple() + ? (" " + operand_index.ToString()) + : ""; return StrCat(instruction->name(), ", operand ", operand_number, index_str); } @@ -88,7 +87,7 @@ bool HloValue::operator!=(const HloValue& other) const { } string HloValue::ToShortString() const { - string index_str = ShapeUtil::IsTuple(defining_instruction()->shape()) + string index_str = defining_instruction()->shape().IsTuple() ? defining_index().ToString() : ""; return StrCat(id(), " ", is_phi_ ? 
"PHI " : "", @@ -210,7 +209,7 @@ std::ostream& operator<<(std::ostream& out, const HloValue& value) { } void HloValueSet::SortAndUniquifyValues() { - std::sort(values_.begin(), values_.end(), HloValue::IdLessThan); + absl::c_sort(values_, HloValue::IdLessThan); values_.erase(std::unique(values_.begin(), values_.end(), HloValue::IdEqual), values_.end()); } diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 77db7b098a38ff4efdcc7447935fae61561c9ff4..4caaa5a32b1e213ff475591e32809f744bcb86ad 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -44,7 +44,7 @@ bool IsCallerInstruction(HloInstruction* hlo) { case HloOpcode::kCall: case HloOpcode::kConditional: case HloOpcode::kWhile: - case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllReduce: case HloOpcode::kMap: case HloOpcode::kReduce: case HloOpcode::kReduceWindow: @@ -153,8 +153,8 @@ Status ShapeVerifier::HandleConvolution(HloInstruction* convolution) { const Shape expected, ShapeInference::InferConvolveShape( convolution->operand(0)->shape(), convolution->operand(1)->shape(), - convolution->feature_group_count(), convolution->window(), - convolution->convolution_dimension_numbers())); + convolution->feature_group_count(), convolution->batch_group_count(), + convolution->window(), convolution->convolution_dimension_numbers())); return CheckShape(convolution, expected); } @@ -167,13 +167,12 @@ Status ShapeVerifier::HandleFft(HloInstruction* fft) { return CheckShape(fft, expected); } -Status ShapeVerifier::HandleCrossReplicaSum(HloInstruction* crs) { +Status ShapeVerifier::HandleAllReduce(HloInstruction* crs) { std::vector operand_shapes; for (const HloInstruction* operand : crs->operands()) { operand_shapes.push_back(&operand->shape()); } - return CheckShape(crs, - ShapeInference::InferCrossReplicaSumShape(operand_shapes)); + return CheckShape(crs, 
ShapeInference::InferAllReduceShape(operand_shapes)); } Status ShapeVerifier::HandleAllToAll(HloInstruction* hlo) { @@ -185,6 +184,10 @@ Status ShapeVerifier::HandleAllToAll(HloInstruction* hlo) { ShapeInference::InferAllToAllTupleShape(operand_shapes)); } +Status ShapeVerifier::HandleReplicaId(HloInstruction* hlo) { + return CheckShape(hlo, ShapeUtil::MakeShape(U32, {})); +} + Status ShapeVerifier::HandleCollectivePermute(HloInstruction* hlo) { TF_RETURN_IF_ERROR(CheckOperandCount(hlo, 1)); return CheckShape(hlo, ShapeInference::InferCollectivePermuteShape( @@ -350,7 +353,10 @@ Status ShapeVerifier::HandleConstant(HloInstruction* constant) { Status ShapeVerifier::HandleIota(HloInstruction* instruction) { TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 0)); auto* iota = Cast(instruction); - const int64 rank = ShapeUtil::Rank(iota->shape()); + if (!iota->shape().IsArray()) { + return InternalError("Iota does not support non-array result."); + } + const int64 rank = iota->shape().rank(); if (rank == 0) { return InternalError("Iota does not support scalars."); } @@ -388,6 +394,14 @@ Status ShapeVerifier::HandleReduce(HloInstruction* reduce) { Status ShapeVerifier::HandleBitcast(HloInstruction* bitcast) { TF_RETURN_IF_ERROR(CheckOperandCount(bitcast, 1)); + // Bitcasts are not allowed to change the element type. + if (bitcast->operand(0)->shape().element_type() != + bitcast->shape().element_type()) { + return InternalError( + "Bitcast can not change the element type from %s to %s", + PrimitiveType_Name(bitcast->operand(0)->shape().element_type()), + PrimitiveType_Name(bitcast->shape().element_type())); + } return Status::OK(); } @@ -398,13 +412,11 @@ Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) { const Shape& operand_shape = broadcast->operand(0)->shape(); // Check for mixed precision. 
TF_RET_CHECK(SameElementType(broadcast->shape(), operand_shape)); - TF_RET_CHECK(ShapeUtil::Rank(operand_shape) == - broadcast->dimensions().size()); - for (int64 operand_dimension = 0; - operand_dimension < ShapeUtil::Rank(operand_shape); + TF_RET_CHECK(operand_shape.rank() == broadcast->dimensions().size()); + for (int64 operand_dimension = 0; operand_dimension < operand_shape.rank(); ++operand_dimension) { int64 output_dimension = broadcast->dimensions()[operand_dimension]; - TF_RET_CHECK((output_dimension < ShapeUtil::Rank(broadcast->shape())) && + TF_RET_CHECK((output_dimension < broadcast->shape().rank()) && output_dimension >= 0 && (broadcast->shape().dimensions(output_dimension) == operand_shape.dimensions(operand_dimension))) @@ -481,7 +493,9 @@ Status ShapeVerifier::HandleCustomCall(HloInstruction* instruction) { const Shape& operand_shape_with_layout = custom_call->operand_shapes_with_layout()[i]; TF_RET_CHECK(ShapeUtil::Compatible(custom_call->operand(i)->shape(), - operand_shape_with_layout)); + operand_shape_with_layout)) + << custom_call->operand(i)->shape().ToString() << " operand " + << operand_shape_with_layout.ToString(); TF_RET_CHECK(LayoutUtil::HasLayout(operand_shape_with_layout)); } } @@ -497,21 +511,23 @@ Status ShapeVerifier::HandleSlice(HloInstruction* slice) { } Status ShapeVerifier::HandleDynamicSlice(HloInstruction* dynamic_slice) { - TF_RETURN_IF_ERROR(CheckOperandCount(dynamic_slice, 2)); - return CheckShape(dynamic_slice, ShapeInference::InferDynamicSliceShape( - dynamic_slice->operand(0)->shape(), - dynamic_slice->operand(1)->shape(), - dynamic_slice->dynamic_slice_sizes())); + return CheckShape( + dynamic_slice, + ShapeInference::InferDynamicSliceShape( + dynamic_slice->operand(0)->shape(), + Cast(dynamic_slice)->index_shapes(), + dynamic_slice->dynamic_slice_sizes())); } Status ShapeVerifier::HandleDynamicUpdateSlice( HloInstruction* dynamic_update_slice) { - TF_RETURN_IF_ERROR(CheckOperandCount(dynamic_update_slice, 3)); - return 
CheckShape(dynamic_update_slice, - ShapeInference::InferDynamicUpdateSliceShape( - dynamic_update_slice->operand(0)->shape(), - dynamic_update_slice->operand(1)->shape(), - dynamic_update_slice->operand(2)->shape())); + return CheckShape( + dynamic_update_slice, + ShapeInference::InferDynamicUpdateSliceShape( + dynamic_update_slice->operand(0)->shape(), + dynamic_update_slice->operand(1)->shape(), + Cast(dynamic_update_slice) + ->index_shapes())); } Status ShapeVerifier::HandleTuple(HloInstruction* tuple) { @@ -523,8 +539,7 @@ Status ShapeVerifier::HandleMap(HloInstruction* map) { int64 max_operand_rank = 0; for (const HloInstruction* operand : map->operands()) { operand_shapes.push_back(&operand->shape()); - max_operand_rank = - std::max(max_operand_rank, ShapeUtil::Rank(operand->shape())); + max_operand_rank = std::max(max_operand_rank, operand->shape().rank()); } // TODO(b/65689298) Remove code below once Map is generalized to accept // arbitrary map dimensions. @@ -683,7 +698,7 @@ Status CheckMixedPrecisionOperands(const HloInstruction* instruction) { case HloOpcode::kCall: case HloOpcode::kConditional: case HloOpcode::kConstant: - case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllReduce: case HloOpcode::kCustomCall: case HloOpcode::kDomain: case HloOpcode::kFusion: @@ -694,7 +709,6 @@ Status CheckMixedPrecisionOperands(const HloInstruction* instruction) { case HloOpcode::kRecv: case HloOpcode::kRecvDone: case HloOpcode::kReducePrecision: - case HloOpcode::kSelect: case HloOpcode::kTupleSelect: case HloOpcode::kSend: case HloOpcode::kSendDone: @@ -982,7 +996,7 @@ bool ShapeContainsToken(const Shape& shape) { bool contains_token = false; ShapeUtil::ForEachSubshape( shape, [&contains_token](const Shape& subshape, const ShapeIndex&) { - if (ShapeUtil::IsToken(subshape)) { + if (subshape.IsToken()) { contains_token = true; } }); @@ -1270,11 +1284,11 @@ class InstructionVerifier : public DfsHloVisitorWithDefault { // op. 
See https://groups.google.com/forum/#!topic/xla-dev/9LqijHmTt_I // or ComputationLowerer::Visit() TF_RET_CHECK(broadcast->dimensions().size() == - ShapeUtil::Rank(broadcast->operand(0)->shape())) + broadcast->operand(0)->shape().rank()) << "Broadcast HLO (" << broadcast->ToShortString() << ") has invalid number of dimensions: " << broadcast->dimensions().size() - << " != " << ShapeUtil::Rank(broadcast->operand(0)->shape()); + << " != " << broadcast->operand(0)->shape().rank(); return Status::OK(); } @@ -1324,7 +1338,7 @@ class InstructionVerifier : public DfsHloVisitorWithDefault { } Status HandleGetTupleElement(HloInstruction* gte) override { - TF_RET_CHECK(ShapeUtil::IsTuple(gte->operand(0)->shape())); + TF_RET_CHECK(gte->operand(0)->shape().IsTuple()); return Status::OK(); } @@ -1344,7 +1358,7 @@ class InstructionVerifier : public DfsHloVisitorWithDefault { return Status::OK(); } - Status HandleCrossReplicaSum(HloInstruction* crs) override { + Status HandleAllReduce(HloInstruction* crs) override { if (crs->all_reduce_id().has_value()) { TF_RET_CHECK(crs->all_reduce_id().value() > 0) << "All reduce id must be greater than 0 for " @@ -1375,7 +1389,7 @@ class InstructionVerifier : public DfsHloVisitorWithDefault { for (HloInstruction* operand : instruction->operands()) { const Shape& operand_shape = operand->shape(); if (LayoutUtil::IsDenseArray(operand_shape) && - ShapeUtil::Rank(operand_shape) == ShapeUtil::Rank(result_shape)) { + operand_shape.rank() == result_shape.rank()) { const Layout& operand_layout = operand_shape.layout(); TF_RET_CHECK(LayoutUtil::Equal(result_layout, operand_layout)) << "Instruction shouldn't change layouts " diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index e4d0c3d6957885f1d719fedb5a900de601e397f8..facb76a124a4166e2a29c34f01194c9ebb62498b 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -52,9 +52,10 @@ 
class ShapeVerifier : public DfsHloVisitor { Status HandleDot(HloInstruction* dot) override; Status HandleConvolution(HloInstruction* convolution) override; Status HandleFft(HloInstruction* fft) override; - Status HandleCrossReplicaSum(HloInstruction* crs) override; + Status HandleAllReduce(HloInstruction* crs) override; Status HandleAllToAll(HloInstruction* hlo) override; Status HandleCollectivePermute(HloInstruction* hlo) override; + Status HandleReplicaId(HloInstruction* hlo) override; Status HandleReducePrecision(HloInstruction* reduce_precision) override; Status HandleInfeed(HloInstruction*) override; Status HandleOutfeed(HloInstruction*) override; @@ -168,8 +169,13 @@ class ShapeVerifier : public DfsHloVisitor { // An interface used to encapsulate target-specific verification quirks. class TargetVerifierMetadata { public: + TargetVerifierMetadata(std::function shape_size_function) + : shape_size_function_(shape_size_function) {} + // Returns a target-specific shape size. - virtual int64 ShapeSize(const Shape& shape) const = 0; + int64 ShapeSize(const Shape& shape) const { + return shape_size_function_(shape); + } virtual std::unique_ptr GetVerifier() const = 0; @@ -178,20 +184,23 @@ class TargetVerifierMetadata { TargetVerifierMetadata(const TargetVerifierMetadata&) = delete; TargetVerifierMetadata& operator=(const TargetVerifierMetadata&) = delete; + + private: + // Returns a target-specific shape size. + std::function shape_size_function_; }; // The default implementation of TargetVerifierMetadata, used unless the target // needs to override it. 
class DefaultVerifierMetadata : public TargetVerifierMetadata { public: - DefaultVerifierMetadata(bool layout_sensitive, bool allow_mixed_precision) - : layout_sensitive_(layout_sensitive), + DefaultVerifierMetadata( + bool layout_sensitive, bool allow_mixed_precision, + std::function shape_size_function) + : TargetVerifierMetadata(shape_size_function), + layout_sensitive_(layout_sensitive), allow_mixed_precision_(allow_mixed_precision) {} - int64 ShapeSize(const Shape& shape) const override { - return ShapeUtil::ByteSizeOf(shape); - } - // Creates a ShapeVerifier that checks that shapes match inferred // expectations. This creates a new verifier every time because ShapeVerifier, // being a DfsHloVisitor, is stateful. We want a clean object for each run of @@ -210,11 +219,14 @@ class DefaultVerifierMetadata : public TargetVerifierMetadata { // the module. class HloVerifier : public HloModulePass { public: - explicit HloVerifier(bool layout_sensitive, bool allow_mixed_precision, - std::function - instruction_can_change_layout_func = {}) + explicit HloVerifier( + bool layout_sensitive, bool allow_mixed_precision, + std::function + instruction_can_change_layout_func = {}, + std::function shape_size_func = + [](const Shape& shape) { return ShapeUtil::ByteSizeOf(shape); }) : target_metadata_(absl::make_unique( - layout_sensitive, allow_mixed_precision)), + layout_sensitive, allow_mixed_precision, shape_size_func)), instruction_can_change_layout_func_( std::move(instruction_can_change_layout_func)) { CHECK(instruction_can_change_layout_func_ == nullptr || layout_sensitive); diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc index 4bc557e4e62e7df4e25fda86fe417e84129b464c..4f69bd155b8713041ba539098808125956e86259 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc @@ -20,6 +20,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/layout_assignment.h" @@ -27,6 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -386,6 +388,55 @@ TEST_F(HloVerifierTest, AddWithLayoutChange) { ASSERT_TRUE(status.ok()); } +TEST_F(HloVerifierTest, ScalarIndexDynamicSlice) { + const char* const kScalarIndexDynamicSlice = R"( + HloModule DynamicSlice_module + + ENTRY %DynamicSlice.v5 (original_parameter: s32[2,2,258], start_index: s32[]) -> s32[2,2,258] { + %original_parameter = s32[2,2,258] parameter(0) + %constant = s32[] constant(0) + %start_index = s32[] parameter(1) + ROOT %dynamic-slice = s32[2,2,258] dynamic-slice(s32[2,2,258] %original_parameter, s32[] %constant, s32[] %constant, s32[] %start_index), dynamic_slice_sizes={2,2,258} + } + )"; + + HloModuleConfig config; + DebugOptions debug_options = config.debug_options(); + debug_options.set_xla_allow_scalar_index_dynamic_ops(true); + config.set_debug_options(debug_options); + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseHloString(kScalarIndexDynamicSlice, config)); + auto status = verifier().Run(module.get()).status(); + ASSERT_TRUE(status.ok()); +} + +TEST_F(HloVerifierTest, ScalarIndexDynamicUpdateSlice) { + const char* const kScalarIndexDynamicSlice = R"( + HloModule DynamicUpdateSlice_module + + ENTRY %DynamicUpdateSlice.v4 (input: s32[1,1,25,1], update: s32[1,1,2,1], start_index.0: s32[], start_index.1: s32[], start_index.2: s32[], 
start_index.3: s32[]) -> s32[1,1,25,1] { + %input = s32[1,1,25,1]{3,2,1,0} parameter(0) + %update = s32[1,1,2,1]{3,2,1,0} parameter(1) + %start_index.0 = s32[] parameter(2) + %start_index.1 = s32[] parameter(3) + %start_index.2 = s32[] parameter(4) + %start_index.3 = s32[] parameter(5) + ROOT %dynamic-update-slice = s32[1,1,25,1]{3,2,1,0} dynamic-update-slice(s32[1,1,25,1]{3,2,1,0} %input, s32[1,1,2,1]{3,2,1,0} %update, s32[] %start_index.0, s32[] %start_index.1, s32[] %start_index.2, s32[] %start_index.3) + } + )"; + + HloModuleConfig config; + DebugOptions debug_options = config.debug_options(); + debug_options.set_xla_allow_scalar_index_dynamic_ops(true); + config.set_debug_options(debug_options); + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseHloString(kScalarIndexDynamicSlice, config)); + auto status = verifier().Run(module.get()).status(); + ASSERT_TRUE(status.ok()); +} + TEST_F(HloVerifierTestLayoutSensitive, AddWithLayoutChangeNotAllowed) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kAddWithLayoutChangeHlo)); auto status = verifier().Run(module.get()).status(); @@ -399,8 +450,9 @@ TEST_F(HloVerifierTestLayoutSensitive, SliceWithLayoutChangeNotAllowed) { HloModule SliceWithLayoutChange ENTRY SliceWithLayoutChange { par0 = f32[4,5]{0,1} parameter(0) - par1 = s32[2] parameter(1) - ROOT dslice0 = f32[3,4]{1,0} dynamic-slice(par0, par1), + par1 = s32[] parameter(1) + par2 = s32[] parameter(2) + ROOT dslice0 = f32[3,4]{1,0} dynamic-slice(par0, par1, par2), dynamic_slice_sizes={3,4} } )"; @@ -429,5 +481,76 @@ TEST_F(HloVerifierTestLayoutSensitive, ConcatWithLayoutChangeNotAllowed) { EXPECT_THAT(status.error_message(), HasSubstr("Instruction shouldn't change layouts")); } + +TEST_F(HloVerifierTest, BitcastCanNotChangeElementType) { + const char* const hlo_string = R"( + HloModule Module + + ENTRY BitcastCanNotChangeElementType { + constant.0 = f32[2] constant({0.0, 0.0}) + ROOT bitcast = s32[2] bitcast(constant.0) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto 
module, ParseHloString(hlo_string)); + + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + HasSubstr("Bitcast can not change the element type")); +} + +TEST_F(HloVerifierTest, SelectMixedPrecisionNotAllowed) { + const char* const hlo_string = R"( + HloModule Module + + ENTRY SelectMixedPrecisionNotAllowed { + p0 = pred[] parameter(0) + p1 = f32[32] parameter(1) + p2 = bf16[32] parameter(2) + ROOT select = f32[32] select(p0, p1, p2) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string)); + + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + HasSubstr("Seen floating point types of different precisions")); +} + +TEST_F(HloVerifierTestAllowMixedPrecision, SelectMixedPrecisionAllowed) { + const char* const hlo_string = R"( + HloModule Module + + ENTRY SelectMixedPrecisionAllowed { + p0 = pred[] parameter(0) + p1 = f32[32] parameter(1) + p2 = bf16[32] parameter(2) + ROOT select = f32[32] select(p0, p1, p2) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string)); + + auto status = verifier().Run(module.get()).status(); + ASSERT_TRUE(status.ok()); +} + +TEST_F(HloVerifierTest, IotaNonArrayResult) { + const char* const hlo_string = R"( + HloModule IotaTupleResult + + ENTRY kernelEntry { + ROOT iota = () iota(), iota_dimension=24 + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string)); + + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + HasSubstr("does not support non-array result")); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc index 90904ac00110457bcc3b8974816a7080c4ab89fc..88fc62bd1e2a7830b3f61738a8642308ef4225a7 100644 --- 
a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc +++ b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc @@ -128,9 +128,9 @@ string HumanReadableProfileBuilder::ToString() const { // Sort ops in decreasing order of cycles, and print them. std::vector sorted_ops(op_infos_); - std::sort( - sorted_ops.begin(), sorted_ops.end(), - [](const OpInfo& a, const OpInfo& b) { return a.cycles > b.cycles; }); + absl::c_sort(sorted_ops, [](const OpInfo& a, const OpInfo& b) { + return a.cycles > b.cycles; + }); for (const auto& op : sorted_ops) { print_op(op); } diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc index 1ebb3319779c00fd4afe90606bf336e16349429d..76bf48870d55e82497ba5f63e9e2e2a322cb330e 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis.cc +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc @@ -103,7 +103,7 @@ Status IndexedArrayAnalysis::TraverseAndPopulateCache( do { const HloInstruction* instr = stack.back(); - if (cache_.count(instr)) { + if (cache_.contains(instr)) { stack.pop_back(); continue; } @@ -111,9 +111,9 @@ Status IndexedArrayAnalysis::TraverseAndPopulateCache( switch (FindOrDie(dfs_state_map, instr)) { case kDiscovered: { for (const HloInstruction* operand : instr->operands()) { - if (!cache_.count(operand)) { + if (!cache_.contains(operand)) { stack.push_back(operand); - CHECK(!dfs_state_map.count(operand) || + CHECK(!dfs_state_map.contains(operand) || dfs_state_map[operand] == kDiscovered); dfs_state_map[operand] = kDiscovered; } @@ -1002,7 +1002,7 @@ bool CanFoldDotIntoIndexedArray( absl::Span contracting_dims, absl::Span batch_dims) { absl::optional non_contracting_non_batch_dim = - GetOnlyNonContractingNonBatchDim(ShapeUtil::Rank(indexed_array->shape()), + GetOnlyNonContractingNonBatchDim(indexed_array->shape().rank(), contracting_dims, batch_dims); if (!non_contracting_non_batch_dim.has_value()) { VLOG(3) 
<< tag << ": multiple or no non-contracting non-batch dimensions"; @@ -1015,7 +1015,7 @@ bool CanFoldDotIntoIndexedArray( return false; } - int64 indexed_array_rank = ShapeUtil::Rank(indexed_array->shape()); + int64 indexed_array_rank = indexed_array->shape().rank(); if (indexed_array->source_dim() < (indexed_array_rank - 2)) { // This restriction can be lifted by inserting reshape nodes. VLOG(3) << tag @@ -1043,7 +1043,7 @@ IndexedArrayAnalysis::ComputeArrayForDotWithIndexedLhs( return nullptr; } - int64 lhs_rank = ShapeUtil::Rank(lhs->shape()); + int64 lhs_rank = lhs->shape().rank(); DotDimensionNumbers new_dim_numbers = dim_numbers; new_dim_numbers.set_lhs_contracting_dimensions( 0, lhs->source_dim() == (lhs_rank - 1) ? (lhs_rank - 2) : (lhs_rank - 1)); @@ -1078,7 +1078,7 @@ IndexedArrayAnalysis::ComputeArrayForDotWithIndexedRhs( return nullptr; } - int64 rhs_rank = ShapeUtil::Rank(rhs->shape()); + int64 rhs_rank = rhs->shape().rank(); DotDimensionNumbers new_dim_numbers = dim_numbers; new_dim_numbers.set_rhs_contracting_dimensions( diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc index 98246d5403e4aebc2f4d81e52145706355ddd9a9..62107b5a88d4e37552fa5a6384700a9291a9c655 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc +++ b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc @@ -13,9 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include - #include "tensorflow/compiler/xla/service/indexed_array_analysis.h" +#include "absl/strings/ascii.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_utils.h" @@ -43,7 +42,7 @@ class IndexedArrayAnalysisTest : public HloTestBase { string result; for (char c : text) { - if (!isspace(c)) { + if (!absl::ascii_isspace(c)) { result.push_back(c); } else if (!result.empty() && result.back() != ' ') { result.push_back(' '); @@ -99,7 +98,7 @@ TEST_F(IndexedArrayAnalysisTest, SimpleOneToOneConstantGather) { HloModule SimpleGather ENTRY main { - operand = s32[3,3] constant(s32[3,3]{{1,2,3},{1,2,3},{1,2,3}}) + operand = s32[3,3] constant({{1,2,3},{1,2,3},{1,2,3}}) indices = s32[5] parameter(0) ROOT gather = s32[5,3] gather(operand, indices), offset_dims={1}, @@ -119,7 +118,7 @@ TEST_F(IndexedArrayAnalysisTest, GatherIsNotScalarIndexed0) { HloModule SimpleGather ENTRY main { - operand = s32[3,3] constant(s32[3,3]{{1,2,3},{1,2,3},{1,2,3}}) + operand = s32[3,3] constant({{1,2,3},{1,2,3},{1,2,3}}) indices = s32[5,2] parameter(0) ROOT gather = s32[5] gather(operand, indices), offset_dims={}, @@ -195,7 +194,7 @@ TEST_F(IndexedArrayAnalysisTest, GatherOfGather_OneToOne) { HloModule SimpleGather ENTRY main { - operand = s32[3,3] constant(s32[3,3]{{1,2,3},{1,2,3},{1,2,3}}) + operand = s32[3,3] constant({{1,2,3},{1,2,3},{1,2,3}}) indices_a = s32[5] parameter(0) indices_b = s32[2] parameter(1) gather_a = s32[5,3] gather(operand, indices_a), @@ -309,7 +308,7 @@ TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather0) { HloModule ReshapeOfGather ENTRY main { - operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,2,3,4},{1,2,3,4}}) + operand = s32[3,4] constant({{1,2,3,4},{1,2,3,4},{1,2,3,4}}) indices = s32[5] parameter(0) gather = s32[5,4] gather(operand, indices), offset_dims={1}, @@ -330,7 +329,7 @@ TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather1) { 
HloModule ReshapeOfGather ENTRY main { - operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,2,3,4},{1,2,3,4}}) + operand = s32[3,4] constant({{1,2,3,4},{1,2,3,4},{1,2,3,4}}) indices = s32[5,7] parameter(0) gather = s32[5,4,7] gather(operand, indices), offset_dims={1}, @@ -352,7 +351,7 @@ TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather2) { HloModule ReshapeOfGather ENTRY main { - operand = s32[3,2,6] constant(s32[3,2,6]{ + operand = s32[3,2,6] constant({ {{1,2,3,4,5,6},{1,2,3,4,5,6}}, {{1,2,3,4,5,6},{1,2,3,4,5,6}}, {{1,2,3,4,5,6},{1,2,3,4,5,6}}}) @@ -377,7 +376,7 @@ TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather3) { HloModule ReshapeOfGather ENTRY main { - operand = s32[2,6] constant(s32[2,6]{ + operand = s32[2,6] constant({ {1,2,3,4,5,6},{1,2,3,4,5,6}}) indices = s32[1] parameter(0) gather = s32[1,6] gather(operand, indices), @@ -405,7 +404,7 @@ TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather4) { HloModule ReshapeOfGather ENTRY main { - operand = s32[2,3]{1,0} constant(s32[2,3] { { 1, 2, 3 }, { 1, 2, 3 } }) + operand = s32[2,3]{1,0} constant({ { 1, 2, 3 }, { 1, 2, 3 } }) i.0 = s64[1,3]{1,0} parameter(0) g.0 = s32[1,3,3]{2,1,0} gather(operand, i.0), offset_dims={2}, @@ -438,7 +437,7 @@ TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather5) { HloModule ReshapeOfGather ENTRY main { - operand = s32[1,6] constant(s32[1,6]{{1,2,3,4,5,6}}) + operand = s32[1,6] constant({{1,2,3,4,5,6}}) indices = s32[1] parameter(0) gather = s32[1,6] gather(operand, indices), offset_dims={1}, @@ -465,7 +464,7 @@ TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather6) { HloModule ReshapeOfGather ENTRY main { - operand = s32[1,2,6] constant(s32[1,2,6]{{ + operand = s32[1,2,6] constant({{ {1,2,3,4,5,6},{1,2,3,4,5,6}}}) indices = s32[1] parameter(0) gather = s32[1,1,6] gather(operand, indices), @@ -496,7 +495,7 @@ TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather7) { HloModule ReshapeOfGather ENTRY main { - operand = s32[2,6] constant(s32[2,6]{ + operand = s32[2,6] constant({ 
{1,2,3,4,5,6},{1,2,3,4,5,6}}) indices = s32[1,5] parameter(0) gather = s32[1,5,6] gather(operand, indices), @@ -527,7 +526,7 @@ TEST_F(IndexedArrayAnalysisTest, ReshapeOfGatherNoFold0) { HloModule ReshapeOfGather ENTRY main { - operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,2,3,4},{1,2,3,4}}) + operand = s32[3,4] constant({{1,2,3,4},{1,2,3,4},{1,2,3,4}}) indices = s32[5,6] parameter(0) gather = s32[5,4,6] gather(operand, indices), offset_dims={1}, @@ -556,7 +555,7 @@ TEST_F(IndexedArrayAnalysisTest, ReshapeOfGatherNoFold1) { HloModule ReshapeOfGather ENTRY main { - operand = s32[3,5,2] constant(s32[3,5,2]{ + operand = s32[3,5,2] constant({ {{1,2},{3,4},{5,6},{7,8},{9,10}}, {{1,2},{3,4},{5,6},{7,8},{9,10}}, {{1,2},{3,4},{5,6},{7,8},{9,10}}}) @@ -588,7 +587,7 @@ TEST_F(IndexedArrayAnalysisTest, ReshapeOfGatherNoFold2) { HloModule ReshapeOfGather ENTRY main { - operand = s32[3,4,1] constant(s32[3,4,1]{ + operand = s32[3,4,1] constant({ {{1},{2},{3},{4}}, {{1},{2},{3},{4}}, {{1},{2},{3},{4}}}) @@ -620,7 +619,7 @@ TEST_F(IndexedArrayAnalysisTest, UnaryOpOfGather) { HloModule UnaryOpOfGather ENTRY main { - operand = f32[3,4] constant(f32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}}) + operand = f32[3,4] constant({{1,2,3,4},{1,3,2,4},{4,3,2,1}}) indices = s32[5] parameter(0) gather = f32[5,4] gather(operand, indices), offset_dims={1}, @@ -645,7 +644,7 @@ TEST_F(IndexedArrayAnalysisTest, AddBroadcastedScalarWithGather) { HloModule AddBroadcastedScalarWithGather ENTRY main { - gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}}) + gather_operand = s32[3,4] constant({{1,2,3,4},{1,3,2,4},{4,3,2,1}}) constant = s32[] constant(5) constant_broadcasted = s32[5,4] broadcast(constant), dimensions={} indices = s32[5] parameter(0) @@ -673,7 +672,7 @@ TEST_F(IndexedArrayAnalysisTest, HloModule SubtractBroadcastedScalarWithGather ENTRY main { - gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}}) + gather_operand = s32[3,4] 
constant({{1,2,3,4},{1,3,2,4},{4,3,2,1}}) constant = s32[] constant(5) constant_broadcasted = s32[5,4] broadcast(constant), dimensions={} indices = s32[5] parameter(0) @@ -701,7 +700,7 @@ TEST_F(IndexedArrayAnalysisTest, HloModule SubtractBroadcastedScalarWithGather ENTRY main { - gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}}) + gather_operand = s32[3,4] constant({{1,2,3,4},{1,3,2,4},{4,3,2,1}}) constant = s32[] constant(5) constant_broadcasted = s32[5,4] broadcast(constant), dimensions={} indices = s32[5] parameter(0) @@ -728,7 +727,7 @@ TEST_F(IndexedArrayAnalysisTest, AddBroadcastedVectorWithGather) { HloModule AddBroadcastedVectorWithGather ENTRY main { - gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}}) + gather_operand = s32[3,4] constant({{1,2,3,4},{1,3,2,4},{4,3,2,1}}) constant_vect = s32[4] constant({10,11,12,13}) constant_broadcasted = s32[5,4] broadcast(constant_vect), dimensions={1} indices = s32[5] parameter(0) @@ -755,7 +754,7 @@ TEST_F(IndexedArrayAnalysisTest, AddBroadcastedVectorWithGather_Negative) { HloModule AddBroadcastedVectorWithGather ENTRY main { - gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}}) + gather_operand = s32[3,4] constant({{1,2,3,4},{1,3,2,4},{4,3,2,1}}) constant_vect = s32[5] constant({10,11,12,13,14}) constant_broadcasted = s32[5,4] broadcast(constant_vect), dimensions={0} indices = s32[5] parameter(0) @@ -804,8 +803,8 @@ TEST_F(IndexedArrayAnalysisTest, DotOpBasic_0) { HloModule DotOp ENTRY main { - gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{5,6,7,8},{9,10,11,12}}) - dot_rhs_constant = s32[4,3] constant(s32[4,3]{{1,2,3},{4,5,6},{7,8,9},{10,11,12}}) + gather_operand = s32[3,4] constant({{1,2,3,4},{5,6,7,8},{9,10,11,12}}) + dot_rhs_constant = s32[4,3] constant({{1,2,3},{4,5,6},{7,8,9},{10,11,12}}) indices = s32[5] parameter(0) dot_lhs = s32[5,4] gather(gather_operand, indices), offset_dims={1}, @@ -831,8 +830,8 @@ 
TEST_F(IndexedArrayAnalysisTest, DotOpBasic_1) { HloModule DotOp ENTRY main { - gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{5,6,7,8},{9,10,11,12}}) - dot_rhs_constant = s32[3,3] constant(s32[3,3]{{1,2,3},{4,5,6},{7,8,9}}) + gather_operand = s32[3,4] constant({{1,2,3,4},{5,6,7,8},{9,10,11,12}}) + dot_rhs_constant = s32[3,3] constant({{1,2,3},{4,5,6},{7,8,9}}) indices = s32[5] parameter(0) dot_lhs = s32[3,5] gather(gather_operand, indices), offset_dims={0}, @@ -859,8 +858,8 @@ TEST_F(IndexedArrayAnalysisTest, DotOpBasic_2) { HloModule DotOp ENTRY main { - gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{5,6,7,8},{9,10,11,12}}) - dot_lhs_constant = s32[4,3] constant(s32[4,3]{{1,2,3},{4,5,6},{7,8,9},{10,11,12}}) + gather_operand = s32[3,4] constant({{1,2,3,4},{5,6,7,8},{9,10,11,12}}) + dot_lhs_constant = s32[4,3] constant({{1,2,3},{4,5,6},{7,8,9},{10,11,12}}) indices = s32[5] parameter(0) dot_rhs = s32[3,5] gather(gather_operand, indices), offset_dims={0}, @@ -888,8 +887,8 @@ TEST_F(IndexedArrayAnalysisTest, DotOpBasic_3) { HloModule DotOp ENTRY main { - gather_operand = s32[4,3] constant(s32[4,3]{{1,2,3},{4,5,6},{7,8,9},{10,11,12}}) - dot_lhs_constant = s32[4,3] constant(s32[4,3]{{1,2,3},{4,5,6},{7,8,9},{10,11,12}}) + gather_operand = s32[4,3] constant({{1,2,3},{4,5,6},{7,8,9},{10,11,12}}) + dot_lhs_constant = s32[4,3] constant({{1,2,3},{4,5,6},{7,8,9},{10,11,12}}) indices = s32[5] parameter(0) dot_rhs = s32[5,3] gather(gather_operand, indices), offset_dims={1}, @@ -917,8 +916,8 @@ TEST_F(IndexedArrayAnalysisTest, DotOpWithBatch) { HloModule DotOp ENTRY main { - gather_operand = s32[2,3,2] constant(s32[2,3,2]{{{1,2},{3,4},{5,6}},{{7,8},{9,10},{11,12}}}) - dot_lhs_constant = s32[2,2,3] constant(s32[2,2,3]{{{1,2,3},{4,5,6}},{{7,8,9},{10,11,12}}}) + gather_operand = s32[2,3,2] constant({{{1,2},{3,4},{5,6}},{{7,8},{9,10},{11,12}}}) + dot_lhs_constant = s32[2,2,3] constant({{{1,2,3},{4,5,6}},{{7,8,9},{10,11,12}}}) indices = s32[4] parameter(0) dot_rhs = 
s32[2,3,4] gather(gather_operand, indices), offset_dims={0,1}, @@ -948,8 +947,8 @@ TEST_F(IndexedArrayAnalysisTest, DotOpNegative) { HloModule DotOp ENTRY main { - gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{5,6,7,8},{9,10,11,12}}) - dot_rhs_constant = s32[2,3] constant(s32[2,3]{{1,2,3},{4,5,6}}) + gather_operand = s32[3,4] constant({{1,2,3,4},{5,6,7,8},{9,10,11,12}}) + dot_rhs_constant = s32[2,3] constant({{1,2,3},{4,5,6}}) indices = s32[2] parameter(0) dot_lhs = s32[3,2] gather(gather_operand, indices), offset_dims={0}, diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index 2297edcbe1d167f0752423f76b795b3592e85c47..d4794acb2f463c4cf8ce5e969f221d52e3742453 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/fusion_queue.h" @@ -94,6 +95,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kPad: case HloOpcode::kReal: case HloOpcode::kReducePrecision: + case HloOpcode::kReplicaId: case HloOpcode::kReshape: case HloOpcode::kReverse: case HloOpcode::kRoundNearestAfz: @@ -126,7 +128,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kCall: case HloOpcode::kConditional: case HloOpcode::kConvolution: - case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllReduce: case HloOpcode::kAllToAll: case HloOpcode::kCollectivePermute: case HloOpcode::kCustomCall: @@ -173,23 +175,22 @@ bool InstructionFusion::EffectivelyAtMostUnary(HloInstruction* hlo) { ShapeUtil::ForEachSubshape( hlo->shape(), [&output_rank](const Shape& subshape, const ShapeIndex& shape_index) { - if 
(ShapeUtil::IsArray(subshape)) { + if (subshape.IsArray()) { output_rank = std::max(output_rank, ShapeUtil::TrueRank(subshape)); } }); - return std::count_if(hlo->operands().begin(), hlo->operands().end(), - [output_rank](HloInstruction* operand) { - if (operand->opcode() == HloOpcode::kBroadcast || - operand->opcode() == HloOpcode::kIota) { - return false; - } - if (operand->opcode() == HloOpcode::kConstant && - ShapeUtil::IsEffectiveScalar(operand->shape())) { - return false; - } - return ShapeUtil::TrueRank(operand->shape()) >= - output_rank; - }) <= 1; + return absl::c_count_if( + hlo->operands(), [output_rank](HloInstruction* operand) { + if (operand->opcode() == HloOpcode::kBroadcast || + operand->opcode() == HloOpcode::kIota) { + return false; + } + if (operand->opcode() == HloOpcode::kConstant && + ShapeUtil::IsEffectiveScalar(operand->shape())) { + return false; + } + return ShapeUtil::TrueRank(operand->shape()) >= output_rank; + }) <= 1; } bool InstructionFusion::CanFuseOnAllPaths( @@ -273,7 +274,7 @@ InstructionFusion::ComputeGloballyUnfusible( ShapeUtil::ForEachSubshape( shape, [&size](const Shape& subshape, const ShapeIndex& shape_index) { - if (ShapeUtil::IsArray(subshape)) { + if (subshape.IsArray()) { size += ShapeUtil::ElementsIn(subshape); } }); @@ -408,9 +409,8 @@ class ReversePostOrderFusionQueue : public FusionQueue { } sorted_operand_numbers.push_back(i); } - std::sort( - sorted_operand_numbers.begin(), sorted_operand_numbers.end(), - [&](int64 i, int64 j) { + absl::c_sort( + sorted_operand_numbers, [&](int64 i, int64 j) { // Instructions with higher priority in the queue come first. 
return ( FindOrDie(post_order_index_, instruction->mutable_operand(i)) > @@ -457,8 +457,13 @@ StatusOr InstructionFusion::Run(HloModule* module) { computation_ = computation; reachability_ = HloReachabilityMap::Build(computation_); - HloInstructionSet do_not_duplicate = - ComputeGloballyUnfusible(computation_->MakeInstructionPostOrder()); + HloInstructionSet do_not_duplicate; + // If we allow duplications, we need to compute which instructions we do not + // want to duplicate based on a global analysis of the graph. + if (may_duplicate_) { + do_not_duplicate = + ComputeGloballyUnfusible(computation_->MakeInstructionPostOrder()); + } auto fusion_queue = GetFusionQueue(computation_); // Instruction fusion effectively fuses edges in the computation graph @@ -565,19 +570,42 @@ HloInstruction* InstructionFusion::FuseIntoMultiOutput( bool InstructionFusion::MultiOutputFusionCreatesCycle( HloInstruction* producer, HloInstruction* consumer) { - auto is_reachable = [&](const HloInstruction* a, const HloInstruction* b) { - // A consumer operand may have been multii-output fused into a parallel - // consumer and thus be missing from the oridinal reachability map. - if (!reachability_->IsPresent(a) || !reachability_->IsPresent(b)) { - reachability_ = HloReachabilityMap::Build(consumer->parent()); + absl::flat_hash_set operands; + for (const HloInstruction* operand : consumer->operands()) { + if (operand == producer) { + continue; + } + + // If the reachability map already contains the producer and the operand of + // the consumer, and the producer can reach the operand, then we know for + // sure MultiOutputFusion would create a cycle. If not, we need to do a DFS + // traversal of the computation to verify that this multioutput fusion would + // not create a cycle. 
+ if (reachability_->IsPresent(producer) && + reachability_->IsPresent(operand) && + reachability_->IsReachable(producer, operand)) { + return true; + } + operands.insert(operand->unique_id()); + } + + // Do a DFS on the producer to see if any of the other consumer operands are + // reachable in the current state of the graph. + std::vector worklist = producer->users(); + absl::flat_hash_set visits; + while (!worklist.empty()) { + const HloInstruction* user = worklist.back(); + worklist.pop_back(); + if (operands.count(user->unique_id()) != 0) { + return true; } - return reachability_->IsReachable(a, b); - }; - return absl::c_any_of(consumer->operands(), - [&](const HloInstruction* consumer_operand) { - return consumer_operand != producer && - is_reachable(producer, consumer_operand); - }); + if (visits.count(user->unique_id()) == 0) { + visits.insert(user->unique_id()); + worklist.insert(worklist.end(), user->users().begin(), + user->users().end()); + } + } + return false; } bool InstructionFusion::ShouldFuse(HloInstruction* consumer, diff --git a/tensorflow/compiler/xla/service/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/instruction_fusion_test.cc index 6b483126499fe1e635a7d13cf597ec5d089c5b24..611cfd404d7622f561f0acc86fc9b05e16eea22e 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion_test.cc @@ -259,8 +259,8 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusibleRecursively) { add = f32[4,3]{1,0} add(p0, p0) abs1 = f32[4,3]{1,0} abs(add) log = f32[4,3]{1,0} log(abs1) - token = token[] after-all() - send = f32[4,3]{1,0} send(log, token), channel_id=0 + token0 = token[] after-all() + send = f32[4,3]{1,0} send(log, token0), channel_id=0 abs2 = f32[4,3]{1,0} abs(log) ROOT root = f32[4,3]{1,0} subtract(abs2, add) })") @@ -290,8 +290,8 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusibleRecursively) { p0 = f32[4,3]{1,0} parameter(0) add1 = f32[4,3]{1,0} 
add(p0, p0) log = f32[4,3]{1,0} log(p0) - token = token[] after-all() - send = f32[4,3]{1,0} send(log, token), channel_id=0 + token0 = token[] after-all() + send = f32[4,3]{1,0} send(log, token0), channel_id=0 add2 = f32[4,3]{1,0} add(log, add1) ROOT root = f32[4,3]{1,0} subtract(add1, add2) })") @@ -324,8 +324,8 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusibleRecursively) { add1 = f32[4,3]{1,0} add(p0, p0) add2 = f32[4,3]{1,0} add(add1, add1) log = f32[4,3]{1,0} log(add2) - token = token[] after-all() - send = f32[4,3]{1,0} send(log, token), channel_id=0 + token0 = token[] after-all() + send = f32[4,3]{1,0} send(log, token0), channel_id=0 sub1 = f32[4,3]{1,0} subtract(log, add2) sub2 = f32[4,3]{1,0} subtract(add2, add1) ROOT root = (f32[4,3]{1,0}, f32[4,3]{1,0}) tuple(sub1, sub2) @@ -394,6 +394,56 @@ TEST_F(InstructionFusionTest, AllowEffectiveUnaryDuplication) { .ValueOrDie()); } +TEST_F(InstructionFusionTest, FuseDiamondGraphsNoDuplication) { + auto module = ParseHloString(R"( + HloModule test_module + ENTRY Test { + p0 = f32[100] parameter(0) + p1 = f32[100] parameter(1) + add = f32[100] add(p0, p1) + slice1 = f32[99] slice(add), slice={[0:99:1]} + slice2 = f32[99] slice(add), slice={[1:100:1]} + ROOT add2 = f32[99] add(slice1, slice2) + })") + .ValueOrDie(); + EXPECT_TRUE( + InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/false) + .Run(module.get()) + .ValueOrDie()) + << module->ToString(); + + HloInstruction* root = module->entry_computation()->root_instruction(); + // 'add' would originally need to be duplicated if fused. However after its + // two users 'slice1' and 'slice2' are fused into 'add2', 'add' has only one + // user and can now be also fused. 
+ EXPECT_THAT(root, op::Fusion(op::Parameter(), op::Parameter())); +} + +TEST_F(InstructionFusionTest, FuseDiamondGraphsAllowDuplication) { + auto module = ParseHloString(R"( + HloModule test_module + ENTRY Test { + p0 = f32[100] parameter(0) + p1 = f32[100] parameter(1) + add = f32[100] add(p0, p1) + slice1 = f32[99] slice(add), slice={[0:99:1]} + slice2 = f32[99] slice(add), slice={[1:100:1]} + ROOT add2 = f32[99] add(slice1, slice2) + })") + .ValueOrDie(); + EXPECT_TRUE( + InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()) + << module->ToString(); + + HloInstruction* root = module->entry_computation()->root_instruction(); + // 'add' would originally need to be duplicated if fused. However after its + // two users 'slice1' and 'slice2' are fused into 'add2', 'add' has only one + // user and can now be also fused. + EXPECT_THAT(root, op::Fusion(op::Parameter(), op::Parameter())); +} + TEST_F(InstructionFusionTest, WideningConvertsAreAlwaysDuplicableIntoConsumers) { auto module = ParseHloString(R"( diff --git a/tensorflow/compiler/xla/service/interpreter/BUILD b/tensorflow/compiler/xla/service/interpreter/BUILD index a981d94a999e3d322986bc2bfd56a5b0b5d175fc..a305c6e8005045f7dbca3b8099a3b8ddebb092af 100644 --- a/tensorflow/compiler/xla/service/interpreter/BUILD +++ b/tensorflow/compiler/xla/service/interpreter/BUILD @@ -1,12 +1,12 @@ -licenses(["notice"]) # Apache 2.0 - -package(default_visibility = ["//visibility:public"]) - load( "//tensorflow/core:platform/default/build_config_root.bzl", "if_static", ) +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//visibility:public"]) + cc_library( name = "interpreter_transfer_manager", srcs = ["interpreter_transfer_manager.cc"], @@ -34,6 +34,7 @@ cc_library( "//tensorflow/compiler/xla/service:algebraic_simplifier", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:computation_placer", + 
"//tensorflow/compiler/xla/service:dynamic_index_splitter", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:flatten_call_graph", "//tensorflow/compiler/xla/service:hlo", @@ -47,8 +48,10 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_subcomputation_unification", "//tensorflow/compiler/xla/service:layout_assignment", "//tensorflow/compiler/xla/service:map_inliner", + "//tensorflow/compiler/xla/service:reduce_precision_insertion", "//tensorflow/compiler/xla/service:reshape_mover", "//tensorflow/compiler/xla/service:while_loop_simplifier", + "//tensorflow/compiler/xla/service/cpu:custom_call_target_registry", "//tensorflow/core:lib", "//tensorflow/stream_executor", "@com_google_absl//absl/memory", @@ -115,6 +118,8 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_headers_lib", + "//tensorflow/stream_executor/host:host_stream", + "//tensorflow/stream_executor/host:host_timer", "@com_google_absl//absl/types:span", ], ) diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.cc b/tensorflow/compiler/xla/service/interpreter/compiler.cc index 3a5177c418e3af8253df228a51f2fc0901d10041..0827b1daf89bebb68c045784ef2b9da677792880 100644 --- a/tensorflow/compiler/xla/service/interpreter/compiler.cc +++ b/tensorflow/compiler/xla/service/interpreter/compiler.cc @@ -21,6 +21,8 @@ limitations under the License. #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" #include "tensorflow/compiler/xla/service/computation_placer.h" +#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h" +#include "tensorflow/compiler/xla/service/dynamic_index_splitter.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/hlo_constant_folding.h" #include "tensorflow/compiler/xla/service/hlo_cse.h" @@ -31,6 +33,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/interpreter/executable.h" #include "tensorflow/compiler/xla/service/layout_assignment.h" #include "tensorflow/compiler/xla/service/map_inliner.h" +#include "tensorflow/compiler/xla/service/reduce_precision_insertion.h" #include "tensorflow/compiler/xla/service/reshape_mover.h" #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -40,12 +43,50 @@ limitations under the License. namespace xla { namespace interpreter { +namespace { + +// Handles custom_call ops during evaluation by routing them through the global +// CPU registry used by other CPU-based backends. +StatusOr HandleEvaluatorCustomCall( + HloInstruction* custom_call, absl::Span operands) { + // Find the target C function in the global registry. + auto* registry = xla::cpu::CustomCallTargetRegistry::Global(); + void* target_fn = registry->Lookup(custom_call->custom_call_target()); + if (!target_fn) { + return NotFound("Custom call target '%s' was not registered", + custom_call->custom_call_target()); + } + + // Populate pointers to operand and output literal data. + std::vector operand_data; + operand_data.reserve(operands.size()); + for (const auto* literal : operands) { + operand_data.push_back(literal->untyped_data()); + } + auto output = Literal::CreateFromShape(custom_call->shape()); + void* output_data = output.untyped_data(); + + // Call the target function matching the C ABI used by the CPU backends. 
+ auto* typed_fn = reinterpret_cast(target_fn); + (*typed_fn)(output_data, operand_data.data()); + + return std::move(output); +} + +} // namespace + Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) { HloPassPipeline pipeline("Interpreter"); + pipeline.AddPass(); pipeline.AddPass( hlo_module->mutable_entry_computation_layout(), LayoutAssignment::InstructionCanChangeLayout); + + ReducePrecisionInsertion::AddPasses( + &pipeline, hlo_module->config().debug_options(), + ReducePrecisionInsertion::PassTiming::BEFORE_OPTIMIZATION); + return pipeline.Run(hlo_module).status(); } @@ -75,10 +116,15 @@ StatusOr> InterpreterCompiler::RunBackend( // In this case we are using an HloEvaluator at execution time, so we don't // need to compile anything + auto evaluator = absl::make_unique(); + evaluator->set_use_fast_path( + hlo_module->config().debug_options().xla_hlo_evaluator_use_fast_path()); + evaluator->set_custom_call_handler(HandleEvaluatorCustomCall); + // Create executable from only the Hlo module. std::unique_ptr executable = - absl::make_unique( - std::move(hlo_module), absl::make_unique()); + absl::make_unique(std::move(hlo_module), + std::move(evaluator)); return std::move(executable); } diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc index 7635fbfed6f6a51fc9d203251d9bebf43cc63fd9..7a6ebdef708bcc3a92fbd8618db0c42c35e6ce8b 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.cc +++ b/tensorflow/compiler/xla/service/interpreter/executable.cc @@ -68,6 +68,18 @@ StatusOr InterpreterExecutable::ExecuteOnStream( "Mismatch between argument count and graph parameter count."); } + // Check that the args have the right shape. 
+ for (int64 i = 0; i < computation->num_parameters(); ++i) { + const auto& expected_shape = computation->parameter_instruction(i)->shape(); + const auto& actual_shape = arguments[i]->on_device_shape(); + if (!ShapeUtil::Equal(expected_shape, actual_shape)) { + return InvalidArgument( + "Shape mismatch on parameter %d. Expected %s, but was %s.", i, + ShapeUtil::HumanString(expected_shape), + ShapeUtil::HumanString(actual_shape)); + } + } + TF_ASSIGN_OR_RETURN(TransferManager * transfer_manager, TransferManager::GetForPlatform(platform)); @@ -85,8 +97,9 @@ StatusOr InterpreterExecutable::ExecuteOnStream( Literal result_literal; { tensorflow::mutex_lock lock(evaluator_lock_); - TF_ASSIGN_OR_RETURN(result_literal, evaluator_->Evaluate( - *computation, arg_literals)); + evaluator_->ResetVisitStates(); + TF_ASSIGN_OR_RETURN(result_literal, + evaluator_->Evaluate(*computation, arg_literals)); } // Transform the result literal back into a ShapedBuffer. @@ -116,7 +129,7 @@ StatusOr InterpreterExecutable::ExecuteAsyncOnStream( } /*static*/ int64 InterpreterExecutable::ShapeSizeBytes(const Shape& shape) { - if (ShapeUtil::IsOpaque(shape)) { + if (shape.IsOpaque()) { return sizeof(void*); } return ShapeUtil::ByteSizeOf(shape, sizeof(void*)); diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index eddef850cf5250b85b564c1e6c92d1cc8ecd1a43..9376a3c8f8963551a89dcedd77068a39ffd05301 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -147,12 +147,9 @@ bool LayoutConstraints::OperandBufferForwarded( PointsToSet::BufferSet* output_buffers = GetBufferSet(instruction); PointsToSet::BufferSet* operand_buffers = GetBufferSet(instruction->operand(operand_no)); - for (const LogicalBuffer* output_buffer : *output_buffers) { - if (operand_buffers->count(output_buffer) > 0) { - return true; - } - } - return false; + return 
absl::c_any_of(*output_buffers, [&](const LogicalBuffer* b) { + return operand_buffers->count(b) > 0; + }); } Status LayoutConstraints::SetBufferLayout(const Layout& layout, @@ -256,7 +253,7 @@ Status LayoutConstraints::SetArrayOperandLayout( const Layout& layout, const HloInstruction* instruction, int64 operand_no, bool mandatory, bool dfs) { const HloInstruction* operand = instruction->operand(operand_no); - TF_RET_CHECK(ShapeUtil::IsArray(operand->shape())); + TF_RET_CHECK(operand->shape().IsArray()); Shape shape(operand->shape()); *shape.mutable_layout() = layout; TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutInShape(shape)); @@ -314,7 +311,7 @@ Status LayoutConstraints::SetInstructionLayout( CHECK_EQ(1, buffers.size()); CHECK_EQ(buffers[0]->instruction(), instruction); - if (ShapeUtil::IsArray(subshape)) { + if (subshape.IsArray()) { return SetBufferLayout(subshape.layout(), *buffers[0], mandatory); } else { return Status::OK(); @@ -406,7 +403,7 @@ Status LayoutAssignment::BuildHostChannelConstraints( instruction->opcode() == HloOpcode::kRecv) { const Shape& data_shape = ShapeUtil::GetTupleElementShape(send_recv_instr->shape(), 0); - TF_RET_CHECK(ShapeUtil::IsArray(data_shape)); + TF_RET_CHECK(data_shape.IsArray()); TF_RET_CHECK(LayoutUtil::HasLayout(data_shape)); const Layout* prev_layout = host_channel_constraints_.ConstrainChannel( send_recv_instr->channel_id(), data_shape.layout()); @@ -489,7 +486,7 @@ Status LayoutAssignment::AddMandatoryConstraints( if (instruction->opcode() == HloOpcode::kSend) { // TODO(b/68493863): Change to use SetOperandLayout(). 
const Shape send_buffer_shape = instruction->operand(0)->shape(); - TF_RET_CHECK(ShapeUtil::IsArray(send_buffer_shape)); + TF_RET_CHECK(send_buffer_shape.IsArray()); Shape new_buffer_shape = get_channel_constraints(instruction) ->LayoutShapeForChannel(send_buffer_shape, @@ -499,7 +496,7 @@ Status LayoutAssignment::AddMandatoryConstraints( } else { const Shape recv_buffer_shape = ShapeUtil::GetTupleElementShape(instruction->shape(), 0); - TF_RET_CHECK(ShapeUtil::IsArray(recv_buffer_shape)); + TF_RET_CHECK(recv_buffer_shape.IsArray()); TF_ASSIGN_OR_RETURN( const LogicalBuffer* buffer, constraints->points_to_analysis().GetBufferDefinedAt(instruction, @@ -520,7 +517,7 @@ Status LayoutAssignment::AddMandatoryConstraints( } // TODO(b/68493863): Change to use SetOperandLayout(). const Shape& buffer_shape = instruction->operand(0)->shape(); - TF_RET_CHECK(ShapeUtil::IsArray(buffer_shape)); + TF_RET_CHECK(buffer_shape.IsArray()); Shape new_buffer_shape = get_channel_constraints(instruction) ->LayoutShapeForChannel(buffer_shape, all_reduce_id); @@ -780,7 +777,7 @@ StatusOr LayoutAssignment::CreateCopyWithNewLayout( << ShapeUtil::HumanString(instruction->shape()) << " instruction: " << instruction->ToString(); - if (ShapeUtil::IsTuple(instruction->shape())) { + if (instruction->shape().IsTuple()) { // Copy tuple elements which have differing layouts. 
std::vector element_copies; for (int64 i = 0; i < ShapeUtil::TupleElementCount(instruction->shape()); @@ -811,7 +808,7 @@ StatusOr LayoutAssignment::CreateCopyWithNewLayout( TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( shape_with_layout, tuple_copy->mutable_shape())); return tuple_copy; - } else if (ShapeUtil::IsArray(instruction->shape())) { + } else if (instruction->shape().IsArray()) { HloInstruction* copy = instruction->parent()->AddInstruction(HloInstruction::CreateUnary( instruction->shape(), HloOpcode::kCopy, instruction)); @@ -988,11 +985,10 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( const Layout& output_layout, const HloInstruction* instruction, int64 operand_no) { const HloInstruction* operand = instruction->operand(operand_no); - CHECK(ShapeUtil::IsArray(instruction->shape())); - CHECK(ShapeUtil::IsArray(operand->shape())); + CHECK(instruction->shape().IsArray()); + CHECK(operand->shape().IsArray()); if (!ShapeUtil::IsScalar(operand->shape()) && - ShapeUtil::Rank(operand->shape()) == - ShapeUtil::Rank(instruction->shape()) && + operand->shape().rank() == instruction->shape().rank() && !instruction_can_change_layout_func_(instruction)) { // Propagate the result layout to the operand layout if the instruction // requires the same layout out for the result and the operand. @@ -1012,7 +1008,7 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( // operations. For similar reasons, if the operand and output have the same // rank, try to match the operand's layout to the output. if (ShapeUtil::TrueRank(operand->shape()) == 1 && - ShapeUtil::Rank(instruction->shape()) == 1) { + instruction->shape().rank() == 1) { // Don't assign a layout in case of R1 -> effective R1 reshape. 
return nullptr; } @@ -1026,7 +1022,7 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( if (ShapeUtil::ReshapeIsBitcast(operand_shape, output_shape_with_layout)) { return absl::make_unique(operand_shape.layout()); } - if (ShapeUtil::Rank(operand_shape) == ShapeUtil::Rank(output_shape)) { + if (operand_shape.rank() == output_shape.rank()) { *operand_shape.mutable_layout() = output_layout; if (ShapeUtil::ReshapeIsBitcast(operand_shape, output_shape_with_layout)) { @@ -1045,7 +1041,7 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( if (instruction->opcode() == HloOpcode::kTranspose) { // Pick the operand layout that makes the transpose a bitcast. - int64 rank = ShapeUtil::Rank(instruction->shape()); + int64 rank = instruction->shape().rank(); std::vector new_minor_to_major(rank); for (int64 i = 0; i < rank; ++i) { int64 output_dim = LayoutUtil::Minor(output_layout, i); @@ -1066,11 +1062,10 @@ std::unique_ptr LayoutAssignment::ChooseOutputLayoutFromOperandLayout( int64 operand_no) { const HloInstruction* operand = user->operand(operand_no); - CHECK(ShapeUtil::IsArray(user->shape()) && - ShapeUtil::IsArray(operand->shape())); + CHECK(user->shape().IsArray() && operand->shape().IsArray()); if (!ShapeUtil::IsScalar(operand->shape()) && - ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(user->shape()) && + operand->shape().rank() == user->shape().rank() && !instruction_can_change_layout_func_(user)) { // Assign users the same layout as the operand. return absl::make_unique(operand_layout); @@ -1083,7 +1078,7 @@ std::unique_ptr LayoutAssignment::ChooseOutputLayoutFromOperandLayout( // reshape is a bitcast when using the same layout. This may avoid copy // operations. For similar reasons, if the operand and output have the same // rank, try to match the outputs's layout to the operand. 
- if (ShapeUtil::Rank(operand->shape()) == 1 && + if (operand->shape().rank() == 1 && ShapeUtil::TrueRank(user->shape()) == 1) { // Don't assign a layout in case of R1 -> effective R1 reshape. return nullptr; @@ -1098,7 +1093,7 @@ std::unique_ptr LayoutAssignment::ChooseOutputLayoutFromOperandLayout( if (ShapeUtil::ReshapeIsBitcast(output_shape, operand_shape_with_layout)) { return absl::make_unique(output_shape.layout()); } - if (ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(output_shape)) { + if (operand->shape().rank() == output_shape.rank()) { *output_shape.mutable_layout() = operand_layout; if (ShapeUtil::ReshapeIsBitcast(output_shape, operand_shape_with_layout)) { @@ -1117,7 +1112,7 @@ std::unique_ptr LayoutAssignment::ChooseOutputLayoutFromOperandLayout( if (user->opcode() == HloOpcode::kTranspose) { // Pick the user layout that makes the transpose a bitcast. - int64 rank = ShapeUtil::Rank(user->shape()); + int64 rank = user->shape().rank(); std::vector new_minor_to_major(rank); auto inverse_dimensions = InversePermutation(user->dimensions()); for (int64 i = 0; i < rank; ++i) { @@ -1193,7 +1188,7 @@ std::vector> GetArrayUsesOfBuffer( CHECK(buffer.IsArray()); std::vector> uses; for (const auto& buffer_alias : points_to_analysis.GetBufferAliases(buffer)) { - if (!ShapeUtil::IsArray(buffer_alias.instruction()->shape())) { + if (!buffer_alias.instruction()->shape().IsArray()) { continue; } // This alias must be the top-level (index == {}) of the instruction's @@ -1227,7 +1222,7 @@ Status LayoutAssignment::PropagateUseConstraintToDefs( if (ShapeUtil::IsLeafIndex(shape_layout.shape(), index)) { for (const LogicalBuffer* buffer : buffers) { if (constraints->BufferLayout(*buffer) == nullptr && - ShapeUtil::IsArray(buffer->shape())) { + buffer->shape().IsArray()) { TF_RETURN_IF_ERROR(constraints->SetBufferLayout( ShapeUtil::GetSubshape(shape_layout.shape(), index).layout(), *buffer, /*mandatory=*/true)); @@ -1238,6 +1233,23 @@ Status 
LayoutAssignment::PropagateUseConstraintToDefs( }); } +namespace { +// A transpose or a reshape that only changes trivial dimensions have meaningful +// layouts that are valuable to propagate in a depthfirst manner to avoid +// unassigned layouts in the graph. +bool InstructionShouldPropagateDepthFirst(const HloInstruction& hlo) { + switch (hlo.opcode()) { + case HloOpcode::kReshape: + return std::get<0>(hlo.ReshapeMerelyInsertsOrDeletes1SizedDimensions()); + case HloOpcode::kTranspose: + return true; + default: + return false; + } +} + +} // namespace + Status LayoutAssignment::PropagateOperandConstraint( const OperandLayoutConstraint& operand_constraint, LayoutConstraints* constraints) { @@ -1258,11 +1270,10 @@ Status LayoutAssignment::PropagateOperandConstraint( // layout for the operands with the same ranks. const HloInstruction* operand = operand_constraint.operand(); const HloInstruction* user = operand_constraint.instruction(); - if (!ShapeUtil::IsArray(operand->shape())) { + if (!operand->shape().IsArray()) { return Status::OK(); } - if (instruction_can_change_layout_func_(user) && - !ShapeUtil::IsArray(user->shape())) { + if (instruction_can_change_layout_func_(user) && !user->shape().IsArray()) { return Status::OK(); } @@ -1273,7 +1284,7 @@ Status LayoutAssignment::PropagateOperandConstraint( return Status::OK(); } - int64 operand_rank = ShapeUtil::Rank(operand->shape()); + int64 operand_rank = operand->shape().rank(); if (operand_rank <= 1) { return Status::OK(); } @@ -1288,7 +1299,7 @@ Status LayoutAssignment::PropagateOperandConstraint( continue; } const HloInstruction* sibling = user->operand(operand_no); - const int64 sibling_rank = ShapeUtil::Rank(sibling->shape()); + const int64 sibling_rank = sibling->shape().rank(); if (sibling_rank <= 1) { continue; } @@ -1317,16 +1328,16 @@ Status LayoutAssignment::PropagateOperandConstraint( TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( user->shape(), [&](const Shape& subshape, const ShapeIndex& 
shape_index) { - if (ShapeUtil::IsTuple(subshape)) { + if (subshape.IsTuple()) { return Status::OK(); } - if (ShapeUtil::Rank(subshape) <= 1) { + if (subshape.rank() <= 1) { return Status::OK(); } // Assign the right layout to input fusion of higher rank reduce // operations. - if (ShapeUtil::Rank(subshape) != ShapeUtil::Rank(operand->shape())) { + if (subshape.rank() != operand->shape().rank()) { return Status::OK(); } // TODO(b/67641796): Are there cases except fusion that use this code @@ -1354,10 +1365,10 @@ Status LayoutAssignment::PropagateOperandConstraint( } TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( user->shape(), [&](const Shape& subshape, const ShapeIndex& shape_index) { - if (ShapeUtil::IsTuple(subshape)) { + if (subshape.IsTuple()) { return Status::OK(); } - if (ShapeUtil::Rank(subshape) <= 1) { + if (subshape.rank() <= 1) { return Status::OK(); } TF_ASSIGN_OR_RETURN( @@ -1373,7 +1384,7 @@ Status LayoutAssignment::PropagateOperandConstraint( TF_RETURN_IF_ERROR(constraints->SetBufferLayout( *layout, *buffer, /*mandatory=*/user->opcode() == HloOpcode::kReduce, - /*dfs=*/false)); + /*dfs=*/InstructionShouldPropagateDepthFirst(*user))); } } return Status::OK(); @@ -1401,8 +1412,8 @@ Status LayoutAssignment::PropagateBufferConstraintToOperands( } if (!instruction_can_change_layout_func_(instruction)) { // Copy the layout to the operand. 
- if (buffer.IsArray() && ShapeUtil::IsArray(operand->shape()) && - ShapeUtil::Rank(operand->shape()) == + if (buffer.IsArray() && operand->shape().IsArray() && + operand->shape().rank() == LayoutUtil::MinorToMajor(buffer_constraint.layout()).size()) { TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout( buffer_constraint.layout(), instruction, operand_no, @@ -1410,7 +1421,7 @@ Status LayoutAssignment::PropagateBufferConstraintToOperands( } } else { if (!buffer.IsTopLevel() || - !ShapeUtil::IsArray(instruction->operand(operand_no)->shape())) { + !instruction->operand(operand_no)->shape().IsArray()) { continue; // Don't touch buffers that are internal to a tuple. } VLOG(6) << "Propagating constraint to operand " << operand_no << " of " @@ -1423,11 +1434,9 @@ Status LayoutAssignment::PropagateBufferConstraintToOperands( ChooseOperandLayoutFromOutputLayout(buffer_constraint.layout(), instruction, operand_no); if (operand_layout != nullptr) { - // Do not propagate operand constraints of transposes and reshapes, it - // tends to create really bad layouts. TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout( *operand_layout, instruction, operand_no, /*mandatory=*/false, - /*dfs=*/false)); + /*dfs=*/InstructionShouldPropagateDepthFirst(*instruction))); } } else { VLOG(6) << "Operand already has a constraint " @@ -1497,7 +1506,7 @@ StatusOr InferArrayLayout( // This function should only be called for array shapes which don't yet have // layouts. const Shape& subshape = ShapeUtil::GetSubshape(instruction->shape(), index); - TF_RET_CHECK(ShapeUtil::IsArray(subshape)); + TF_RET_CHECK(subshape.IsArray()); TF_RET_CHECK(!subshape.has_layout()); // The instruction should not define the buffer at this index. @@ -1576,8 +1585,9 @@ Status SetFusionLayouts(HloInstruction* fusion) { fused_instruction->mutable_shape())); } else if (fused_instruction->opcode() == HloOpcode::kInfeed) { // Nop; leave the infeed layout alone. 
- } else { + } else if (fusion->fusion_kind() != HloInstruction::FusionKind::kCustom) { // Other instructions don't have layouts inside of fusion nodes. + // But do not clear layouts for other instructions in custom fusion nodes. LayoutUtil::ClearLayout(fused_instruction->mutable_shape()); } } @@ -1615,7 +1625,7 @@ Status LayoutAssignment::AssignLayouts(const LayoutConstraints& constraints, for (const LogicalBuffer* buffer : constraints.points_to_analysis().GetBuffersDefinedByInstruction( instruction)) { - if (!ShapeUtil::IsArray(buffer->shape())) { + if (!buffer->shape().IsArray()) { continue; } @@ -1639,7 +1649,7 @@ Status LayoutAssignment::AssignLayouts(const LayoutConstraints& constraints, TF_RETURN_IF_ERROR(ShapeUtil::ForEachMutableSubshapeWithStatus( instruction->mutable_shape(), [instruction, &constraints](Shape* subshape, const ShapeIndex& index) { - if (subshape->has_layout() || !ShapeUtil::IsArray(*subshape)) { + if (subshape->has_layout() || !subshape->IsArray()) { return Status::OK(); } // Set Layout of subshape to match layout of LogicalBuffer which @@ -2012,7 +2022,7 @@ bool LayoutAssignment::InstructionCanChangeLayout( case HloOpcode::kConditional: case HloOpcode::kConvert: case HloOpcode::kCos: - case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllReduce: case HloOpcode::kAllToAll: case HloOpcode::kCollectivePermute: case HloOpcode::kDivide: @@ -2085,6 +2095,7 @@ bool LayoutAssignment::InstructionCanChangeLayout( case HloOpcode::kRecv: case HloOpcode::kRecvDone: case HloOpcode::kReduce: + case HloOpcode::kReplicaId: case HloOpcode::kReshape: case HloOpcode::kRng: case HloOpcode::kSend: @@ -2100,8 +2111,8 @@ bool LayoutAssignment::InstructionCanChangeLayout( /* static */ bool LayoutAssignment::IsAtMostRank1(const Shape& shape) { - if (ShapeUtil::IsArray(shape)) { - return ShapeUtil::Rank(shape) <= 1; + if (shape.IsArray()) { + return shape.rank() <= 1; } return absl::c_all_of(shape.tuple_shapes(), [](const Shape& subshape) { return 
IsAtMostRank1(subshape); @@ -2123,7 +2134,7 @@ Status LayoutAssignment::ClearPreviousPassSideEffects(HloModule* module) { for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) { if (instruction->opcode() == HloOpcode::kCopy && - added_copies_.count(instruction) > 0) { + added_copies_.contains(instruction)) { VLOG(5) << "Removing added copy: " << instruction->ToString(); TF_RETURN_IF_ERROR( instruction->ReplaceAllUsesWith(instruction->mutable_operand(0))); diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h index 3b081de3c7826c3c11a7d87d542835d0ecce1b7e..5701cb5b025e563247d46d0d24f81a5f886fc23b 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.h +++ b/tensorflow/compiler/xla/service/layout_assignment.h @@ -243,7 +243,7 @@ class ChannelLayoutConstraints { // Returns true if channel_id has a layout constraint. bool IsChannelConstrained(int64 channel_id) const { - return constraints_.count(channel_id) > 0; + return constraints_.contains(channel_id); } // Given `shape`, apply the layout for `channel_id`. `channel_id` must already @@ -276,7 +276,7 @@ class ChannelLayoutConstraints { } private: - std::unordered_map constraints_; + absl::flat_hash_map constraints_; }; // HLO pass which assigns layouts to all instructions in the HLO module while diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index 311bd7890545b5b2cbec920d2d12ddd482d0d53c..c8cf3c47d380012fdb0206c0d20d67e6a13017ae 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -27,7 +27,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" @@ -529,8 +528,7 @@ class OperandsMustBeTheSameLayoutAssignment : public LayoutAssignment { for (int64 operand_no = 0; operand_no < instruction->operand_count(); ++operand_no) { const HloInstruction* operand = instruction->operand(operand_no); - if (ShapeUtil::Rank(instruction->shape()) != - ShapeUtil::Rank(operand->shape())) { + if (instruction->shape().rank() != operand->shape().rank()) { continue; } TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout( @@ -848,12 +846,12 @@ TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) { ENTRY entry_computation { param = (f32[2,2]) parameter(0) gte = f32[2,2] get-tuple-element(param), index=0 - token = token[] after-all() - recv = (f32[2,2], u32[], token[]) recv(token), channel_id=1, sharding={maximal device=1} + token0 = token[] after-all() + recv = (f32[2,2], u32[], token[]) recv(token0), channel_id=1, sharding={maximal device=1} recv-done = (f32[2,2], token[]) recv-done(recv), channel_id=1, sharding={maximal device=1} ROOT root = f32[2,2] get-tuple-element(recv-done), index=0 - send = (f32[2,2], u32[], token[]) send(gte, token), channel_id=1, + send = (f32[2,2], u32[], token[]) send(gte, token0), channel_id=1, sharding={maximal device=0} send-done = token[] send-done(send), channel_id=1, sharding={maximal device=0} } @@ -895,11 +893,11 @@ TEST_F(LayoutAssignmentTest, AllReduceLayoutMissmatch) { ENTRY entry_computation { param = (f32[2,2]) parameter(0) gte = f32[2,2] get-tuple-element(param), index=0 - ar.0 = f32[2,2] cross-replica-sum(gte), + ar.0 = f32[2,2] all-reduce(gte), all_reduce_id=1, replica_groups={{0}}, 
to_apply=add, sharding={maximal device=0} - const = f32[2,2] constant(f32[2,2]{{0,1},{2,3}}) - ROOT ar.1 = f32[2,2] cross-replica-sum(const), + const = f32[2,2] constant({{0,1},{2,3}}) + ROOT ar.1 = f32[2,2] all-reduce(const), all_reduce_id=1, replica_groups={{0}}, to_apply=add, sharding={maximal device=1} })"; @@ -962,8 +960,9 @@ TEST_F(LayoutAssignmentTest, CopyDSliceOperandToAvoidImplicitLayoutChange) { ENTRY CopyDSliceOperandToAvoidImplicitLayoutChange { par0 = f32[3,4]{1,0} parameter(0) par1 = f32[4,5]{0,1} parameter(1) - par2 = s32[2] parameter(2) - dslice0 = f32[3,4] dynamic-slice(par1, par2), dynamic_slice_sizes={3,4} + par2 = s32[] parameter(2) + par3 = s32[] parameter(3) + dslice0 = f32[3,4] dynamic-slice(par1, par2, par3), dynamic_slice_sizes={3,4} ROOT add0 = f32[3,4]{1,0} add(par0,dslice0) } )"; @@ -984,7 +983,7 @@ TEST_F(LayoutAssignmentTest, CopyDSliceOperandToAvoidImplicitLayoutChange) { m::Parameter(), m::DynamicSlice( m::Copy(m::Parameter(1)).WithShapeEqualTo(&shape_copy), - m::Parameter(2))))); + m::Parameter(2), m::Parameter(3))))); } TEST_F(LayoutAssignmentTest, CopyConcatOperandToAvoidImplicitLayoutChange) { diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD index 728a66b388f0f9af480ff88b5e96990a26e36af5..c5d59fb28e02ce229967fb3856012d608fb83c5d 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/BUILD +++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD @@ -39,7 +39,6 @@ cc_library( "//tensorflow/compiler/xla/service:logical_buffer", "//tensorflow/core:lib", "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@llvm//:core", ], @@ -169,6 +168,7 @@ cc_library( "//tensorflow/compiler/xla/service:elemental_ir_emitter", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/types:optional", 
"@com_google_absl//absl/types:span", "@llvm//:core", diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc index 643ecd0fbaa546c551097b29e74ccd49418e1466..ce3d922ca7a9bdea3a520959a8b8d284bc3e0d64 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc @@ -81,9 +81,7 @@ void AliasAnalysis::AddAliasingInformationToIrArray(const HloInstruction& hlo, if (hlo.opcode() == HloOpcode::kParameter) { const std::vector& parameter_instructions = module_.entry_computation()->parameter_instructions(); - if (std::find(parameter_instructions.begin(), - parameter_instructions.end(), - &hlo) != parameter_instructions.end()) { + if (absl::c_linear_search(parameter_instructions, &hlo)) { array->MarkInvariantOverWholeProgram(context_); } } diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h index 2b46b3c3964b15548dbacc8b0ada0047a0fa85b6..12e2f449e23ac2511aac576fed893f5a9ef510c0 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h +++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h @@ -76,15 +76,12 @@ class AliasAnalysis { // A map from a buffer slice to metadata corresponding to its alias.scope // metadata. The index kParameterAliasSet is used to hold aliasing // information for parameters. - absl::flat_hash_map + absl::flat_hash_map alias_scope_metadata_; // A map from a buffer slice to metadata corresponding to its noalias // metadata. 
- absl::flat_hash_map - noalias_metadata_; + absl::flat_hash_map noalias_metadata_; }; } // namespace llvm_ir diff --git a/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc b/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc index bdce4a171b8a58f617f1d56e6cf6db5354846703..1ea5a42b0b398818b0946eaa9e214100007bada4 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -41,14 +41,26 @@ static const HloInstruction& InstrForConstantBufferAllocation( return *const_instr; } -string ConstantBufferAllocationToGlobalName( - const BufferAllocation& allocation) { - string instr_name = InstrForConstantBufferAllocation(allocation).name(); +string SanitizeConstantName(const HloInstruction& instr) { + CHECK_EQ(instr.opcode(), HloOpcode::kConstant); + string instr_name = instr.name(); for (char& c : instr_name) { - if (c == '.') { + // Having a hyphen or a dot in a global variable name can crash the LLVM PTX + // backend. + if (c == '.' || c == '-') { c = '_'; } } + return instr_name; +} + +string ConstantBufferAllocationToGlobalName( + const BufferAllocation& allocation) { + const HloInstruction& instr = InstrForConstantBufferAllocation(allocation); + string instr_name = instr.name(); + // Check that names are sanitized and stored in the HLO instructions + // before constant buffer allocation. 
+ DCHECK_EQ(instr_name, SanitizeConstantName(instr)); return absl::StrCat("buffer_for_", instr_name); } diff --git a/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h b/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h index bfb6eecb87f6a1b756b3a8da3377f608dd7f0be7..03e98a66900095889292cbff9d9924a9abe83ab0 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h +++ b/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h @@ -20,6 +20,10 @@ limitations under the License. namespace xla { namespace llvm_ir { +// Sanitizes the HLO constant instruction name so that it can be used for the +// name of the corresponding constant buffer. In particular, it replaces . and +// - with _. +string SanitizeConstantName(const HloInstruction& instr); // In XLA:GPU we map constant buffer allocations to globals in the generated // LLVM IR. This function gives us the name of the global variable a constant // buffer is mapped to. Not used on XLA:CPU. diff --git a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc index 4d7f36d9f8b565a819edf0631efc5c7a58c4f87f..c66eaec8fb0e4c03f6967fec0cf0ae9661cdf470 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc @@ -36,19 +36,20 @@ bool CanUpdateDynamicSliceInPlace(HloInstruction* dynamic_update_slice, // EmitFusedDynamicUpdateSliceInPlace. // // Emits a sequential loop if launch_dimensions is null. 
+using IndexGenerator = std::function(int64)>; + static Status EmitDynamicUpdateSliceInPlaceImpl( - const Shape& update_shape, const ElementGenerator& start_indices_generator, + const Shape& update_shape, const IndexGenerator& start_indices_generator, bool is_signed, ElementGenerator update_array_generator, const IrArray& output_array, const gpu::LaunchDimensions* launch_dimensions, absl::string_view name, llvm::IRBuilder<>* b) { const Shape& output_shape = output_array.GetShape(); // Read start indices from start_indices_generator. - const int64 rank = ShapeUtil::Rank(output_shape); + const int64 rank = output_shape.rank(); IrArray::Index start_index(b->getInt64Ty(), rank); for (int64 i = 0; i < rank; ++i) { - IrArray::Index dim_index({b->getInt64(i)}); - TF_ASSIGN_OR_RETURN(start_index[i], start_indices_generator(dim_index)); + TF_ASSIGN_OR_RETURN(start_index[i], start_indices_generator(i)); llvm::Value* output_dim_size = llvm::ConstantInt::get( start_index[i]->getType(), output_shape.dimensions(i)); llvm::Value* update_dim_size = llvm::ConstantInt::get( @@ -112,9 +113,20 @@ Status EmitDynamicUpdateSliceInPlace(absl::Span operand_arrays, Shape output_shape = output_array.GetShape(); Shape update_shape = update_array.GetShape(); - ElementGenerator start_indices_generator = [&](const IrArray::Index& index) { - return start_indices_array.EmitReadArrayElement(index, b); - }; + IndexGenerator start_indices_generator; + // TODO(b/118437727): Remove the R1 path, and rename the variables. 
+ if (start_indices_array.GetShape().rank() == 1) { + start_indices_generator = [&](int64 index) { + return start_indices_array.EmitReadArrayElement( + IrArray::Index({b->getInt64(index)}), b); + }; + } else { + start_indices_generator = [&](int64 index) { + return operand_arrays[2 + index].EmitReadArrayElement( + IrArray::Index(b->getInt64Ty()), b); + }; + } + ElementGenerator update_array_generator = [&](const IrArray::Index& index) { return update_array.EmitReadArrayElement(index, b); }; @@ -165,8 +177,21 @@ static Status EmitFusedDynamicUpdateSliceInPlaceImpl( elemental_emitter); TF_RETURN_IF_ERROR(dynamic_update_slice->Accept(&fused_emitter)); ElementGenerator update_array_generator = fused_emitter.GetGenerator(update); - ElementGenerator start_indices_generator = - fused_emitter.GetGenerator(start_indices); + + // TODO(b/118437727): Remove the R1 path, and rename the variables. + IndexGenerator start_indices_generator; + if (start_indices->shape().rank() == 1) { + start_indices_generator = [&](int64 index) { + return fused_emitter.GetGenerator(start_indices)( + IrArray::Index({b->getInt64(index)})); + }; + } else { + start_indices_generator = [&](int64 index) { + ElementGenerator element_generator = + fused_emitter.GetGenerator(dynamic_update_slice->operand(2 + index)); + return element_generator(IrArray::Index(b->getInt64Ty())); + }; + } bool is_signed = ShapeUtil::ElementIsSigned(start_indices->shape()); return EmitDynamicUpdateSliceInPlaceImpl( diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc index 38f2b5da23a7b92e4547dceaba011ce654977da3..e440f05e2b2f0d4a2a4c7b326b4881183de4d235 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc @@ -35,7 +35,7 @@ using llvm_ir::IrArray; Status FusedIrEmitter::DefaultAction(HloInstruction* hlo) { indexed_generators_[hlo] = [=](const IrArray::Index& 
index) -> StatusOr { - if (generated_value_cache_[hlo].count(index.multidim()) > 0) { + if (generated_value_cache_[hlo].contains(index.multidim())) { llvm::Value* generated_value = generated_value_cache_[hlo][index.multidim()]; llvm::BasicBlock* generated_value_bb = nullptr; @@ -115,7 +115,7 @@ Status FusedIrEmitter::HandleGetTupleElement( /*alignment=*/1, tuple_ptr, b_, module_); }; - if (!ShapeUtil::IsTuple(get_tuple_element->shape())) { + if (!get_tuple_element->shape().IsTuple()) { indexed_generators_[get_tuple_element] = [=](const IrArray::Index& index) -> StatusOr { // TODO(b/34080002) Add aliasing information to tuple element IrArray. diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h index 1b9c61f6700e2a1309b21e499f4a9e2439ed3702..e6d52a580c04a920d3f0e8ed6f39c1cae587cf1b 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h @@ -19,6 +19,7 @@ limitations under the License. 
#include #include +#include "absl/container/flat_hash_map.h" #include "absl/types/optional.h" #include "absl/types/span.h" #include "llvm/IR/IRBuilder.h" @@ -134,8 +135,9 @@ class FusedIrEmitter : public DfsHloVisitorWithDefault { // Cache of generated values, lest we regenerate an element of a node with // multiple outgoing edges - std::unordered_map, llvm::Value*>> + absl::flat_hash_map< + const HloInstruction*, + absl::flat_hash_map, llvm::Value*>> generated_value_cache_; }; diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc index 67f7423121177e2ca1e3384341dad2644c8f5e34..8ee07ae8331e986f9d271be5e39065f0d87853b1 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc @@ -61,7 +61,7 @@ void IrArray::Index::Delinearize(std::vector* multidim, IrArray::Index::Index(llvm::Value* linear, const Shape& shape, llvm::IRBuilder<>* b) - : multidim_(ShapeUtil::Rank(shape)), + : multidim_(shape.rank()), linear_(linear), layout_(shape.layout()), dims_(shape.dimensions().begin(), shape.dimensions().end()) { @@ -104,8 +104,8 @@ IrArray::Index::Index(absl::Span multidim, CHECK(LayoutUtil::HasLayout(shape)); } -IrArray::IrArray(llvm::Value* base_ptr, const Shape& shape) - : base_ptr_(base_ptr), shape_(&shape) { +IrArray::IrArray(llvm::Value* base_ptr, Shape shape) + : base_ptr_(base_ptr), shape_(std::move(shape)) { TF_CHECK_OK(ShapeUtil::ValidateShape(shape)); CHECK(base_ptr_->getType()->isPointerTy()); int depth = 0; @@ -117,10 +117,10 @@ IrArray::IrArray(llvm::Value* base_ptr, const Shape& shape) ++depth; } - if (!ShapeUtil::IsArray(*shape_) || ShapeUtil::IsScalar(*shape_)) { + if (!shape_->IsArray() || ShapeUtil::IsScalar(*shape_)) { DCHECK(depth == 1 || depth == 0) << depth; } else { - DCHECK_EQ(depth, ShapeUtil::Rank(*shape_)) << shape.ShortDebugString(); + DCHECK_EQ(depth, shape_->rank()) << shape.ShortDebugString(); } } @@ -137,12 +137,12 @@ 
IrArray::Index IrArray::Index::SourceIndexOfReshape( const Shape& output_shape, const Shape& input_shape, llvm::IRBuilder<>* builder) const { const auto& target_index = *this; - CHECK_EQ(target_index.size(), ShapeUtil::Rank(output_shape)); + CHECK_EQ(target_index.size(), output_shape.rank()); std::vector> common_factors = CommonFactors(AsInt64Slice(input_shape.dimensions()), AsInt64Slice(output_shape.dimensions())); std::vector source_multidim_index( - ShapeUtil::Rank(input_shape), llvm::UndefValue::get(index_type_)); + input_shape.rank(), llvm::UndefValue::get(index_type_)); // We compute the source indices in each common factor from only the target // indices in the same common factor. for (ssize_t k = common_factors.size() - 2; k >= 0; --k) { @@ -257,7 +257,7 @@ IrArray::Index IrArray::Index::SourceIndexOfBroadcast( const Shape& shape, const Shape& operand_shape, absl::Span dimension_mapping, llvm::IRBuilder<>* builder) const { - int64 rank = ShapeUtil::Rank(operand_shape); + int64 rank = operand_shape.rank(); std::vector source_index(rank); for (int64 i = 0; i < rank; ++i) { source_index[i] = multidim_[dimension_mapping[i]]; @@ -271,7 +271,7 @@ IrArray::Index IrArray::Index::SourceIndexOfBroadcast( // The other dimensions can be masked out with a div and a mod operation. std::vector logical_to_physical = LayoutUtil::MakeLogicalToPhysical(shape.layout()); - int64 output_rank = ShapeUtil::Rank(shape); + int64 output_rank = shape.rank(); // The minimum physical dimension that is broadcasted. int64 min_broadcasted_dimension = output_rank; // The maximum physical dimension that is broadcasted. @@ -348,7 +348,7 @@ llvm::Value* IrArray::EmitArrayElementAddress(const IrArray::Index& index, // over higher-rank arrays. 
return base_ptr_; } - CHECK_EQ(index.size(), ShapeUtil::Rank(*shape_)); + CHECK_EQ(index.size(), shape_->rank()); if (index.LinearValidOnShape(*shape_)) { llvm::Module* module = b->GetInsertBlock()->getParent()->getParent(); diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h index d6d84994ee147f4b8c1a333b0eaccdf6e0a2219b..b706ebd311cbb706e7e4698b93319e37e664d10a 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h @@ -130,6 +130,11 @@ class IrArray { CHECK_LE(index, size()); mutable_multidim().insert(mutable_multidim().begin() + index, value); } + void InsertAt(int64 index, int64 count, llvm::Value* value) { + CHECK_LE(index, size()); + mutable_multidim().insert(mutable_multidim().begin() + index, count, + value); + } using iterator = std::vector::iterator; using const_iterator = std::vector::const_iterator; @@ -189,6 +194,8 @@ class IrArray { return llvm::ConstantInt::get(index_type_, c); } + void ClearLinearIndex() { linear_ = nullptr; } + private: // Changing the multi-dimensional index invalidates the linear index. std::vector& mutable_multidim() { @@ -220,11 +227,11 @@ class IrArray { }; // Default constructor. Constructs an IrArray in a null status. - IrArray() : base_ptr_(nullptr), shape_(nullptr) {} + IrArray() : base_ptr_(nullptr) {} // Construct an IrArray with the given base pointer and shape. base_ptr is a // pointer type pointing to the first element(lowest address) of the array. - IrArray(llvm::Value* base_ptr, const Shape& shape); + IrArray(llvm::Value* base_ptr, Shape shape); // Default implementations of copying and moving. 
IrArray(IrArray&& other) = default; @@ -236,7 +243,6 @@ class IrArray { llvm::Type* GetElementLlvmType() const { return element_type_; } const Shape& GetShape() const { - CHECK(shape_ != nullptr); return *shape_; } @@ -331,7 +337,7 @@ class IrArray { llvm::Type* element_type_; // Shape of the XLA array. - const Shape* shape_; + absl::optional shape_; // The list of key/value pairs used when attaching metadata to emitted // loads/stores for this array. They keys are the metadata kinds and the diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h b/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h index abc06fb7b4245294df2dc20d25a22ac4fdaeb4cf..cf5083e8c13b9485035923895cec1ad05049c644 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h @@ -254,6 +254,11 @@ class IrBuilderMixin { return mixin_builder()->CreateFCmpOLT(std::forward(args)...); } + template + llvm::Value* FCmpOLE(Args&&... args) { + return mixin_builder()->CreateFCmpOLE(std::forward(args)...); + } + template llvm::Value* FCmpONE(Args&&... args) { return mixin_builder()->CreateFCmpONE(std::forward(args)...); diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc index bd0139f85b6a5c5dc23dad962263038451921e65..5eeb29c478a371dae83251771f2dc4844672d3e9 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc @@ -18,28 +18,29 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" namespace xla { -Status KernelSupportLibrary::For( +Status KernelSupportLibrary::ForWithStatus( absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, const std::function& for_body_generator) { - return If(b_->CreateICmpSLT(start, end), [&]() -> Status { + return IfWithStatus(b_->CreateICmpSLT(start, end), [&]() -> Status { TF_RETURN_IF_ERROR(for_body_generator(start, /*is_first_iteration=*/true)); - return For(name, b_->CreateAdd(start, step), end, step, - [&](llvm::Value* iv) { return for_body_generator(iv, false); }); + return ForWithStatus( + name, b_->CreateAdd(start, step), end, step, + [&](llvm::Value* iv) { return for_body_generator(iv, false); }); }); } -Status KernelSupportLibrary::For( +Status KernelSupportLibrary::ForWithStatus( absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, bool peel_first_iteration, const std::function& for_body_generator) { if (peel_first_iteration) { - return For(name, start, end, step, true, - [&](llvm::Value* indvar, bool is_first_iteration) -> Status { - return for_body_generator(indvar, - b_->getInt1(is_first_iteration)); - }); + return ForWithStatus( + name, start, end, step, true, + [&](llvm::Value* indvar, bool is_first_iteration) -> Status { + return for_body_generator(indvar, b_->getInt1(is_first_iteration)); + }); } else { std::unique_ptr loop = llvm_ir::ForLoop::EmitForLoop( name, start, end, step, b_, @@ -55,7 +56,7 @@ Status KernelSupportLibrary::For( } } -Status KernelSupportLibrary::If( +Status KernelSupportLibrary::IfWithStatus( absl::string_view name, llvm::Value* condition, const std::function& true_block_generator, const std::function& false_block_generator) { diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h index 43fec311f150d6054f6ad24f99db332f90ff94a3..612b839cfa15711061e1ae53358a72d5220e1801 
100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h @@ -48,41 +48,42 @@ class KernelSupportLibrary { // for (i64 i = `start` + `step`; i s< `end`; i += `step`) // `for_body_generator(/*ind_var=*/,i, /*is_first_iteration=*/false)`; // } - Status For( + Status ForWithStatus( absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, const std::function& for_body_generator); - void ForReturnVoid( + void For( absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, const std::function& for_body_generator) { CHECK_EQ(Status::OK(), - For(name, start, end, step, + ForWithStatus( + name, start, end, step, [&](llvm::Value* ind_var, bool is_first_iteration) -> Status { for_body_generator(ind_var, is_first_iteration); return Status::OK(); })); } - Status For(absl::string_view name, int64 start, int64 end, int64 step, - const std::function& - for_body_generator) { - return For(name, /*start=*/b_->getInt64(start), - /*end=*/b_->getInt64(end), - /*step=*/b_->getInt64(step), for_body_generator); + Status ForWithStatus( + absl::string_view name, int64 start, int64 end, int64 step, + const std::function& for_body_generator) { + return ForWithStatus(name, /*start=*/b_->getInt64(start), + /*end=*/b_->getInt64(end), + /*step=*/b_->getInt64(step), for_body_generator); } - void ForReturnVoid( + void For( absl::string_view name, int64 start, int64 end, int64 step, const std::function& for_body_generator) { - ForReturnVoid(name, /*start=*/b_->getInt64(start), - /*end=*/b_->getInt64(end), - /*step=*/b_->getInt64(step), for_body_generator); + For(name, /*start=*/b_->getInt64(start), + /*end=*/b_->getInt64(end), + /*step=*/b_->getInt64(step), for_body_generator); } // Generates the following control flow structure if `peel_first_iteration` is @@ -99,19 +100,19 @@ class KernelSupportLibrary { // for (i64 i = `start`; i s< `end`; i += `step`) // 
`for_body_generator(/*ind_var=*/,i, // /*is_first_iteration=*/,(i != `start`))`; - Status For(absl::string_view name, llvm::Value* start, llvm::Value* end, - llvm::Value* step, bool peel_first_iteration, - const std::function& - for_body_generator); + Status ForWithStatus( + absl::string_view name, llvm::Value* start, llvm::Value* end, + llvm::Value* step, bool peel_first_iteration, + const std::function& + for_body_generator); - void ForReturnVoid(absl::string_view name, llvm::Value* start, - llvm::Value* end, llvm::Value* step, - bool peel_first_iteration, - const std::function& - for_body_generator) { - TF_CHECK_OK(For( + void For(absl::string_view name, llvm::Value* start, llvm::Value* end, + llvm::Value* step, bool peel_first_iteration, + const std::function& + for_body_generator) { + TF_CHECK_OK(ForWithStatus( name, start, end, step, peel_first_iteration, [&](llvm::Value* ind_var, llvm::Value* is_first_iteration) -> Status { for_body_generator(ind_var, is_first_iteration); @@ -119,80 +120,81 @@ class KernelSupportLibrary { })); } - Status For(absl::string_view name, llvm::Value* start, llvm::Value* end, - int64 step, bool peel_first_iteration, - const std::function& - for_body_generator) { - return For(name, /*start=*/start, /*end=*/end, - /*step=*/llvm::ConstantInt::get(start->getType(), step), - peel_first_iteration, for_body_generator); + Status ForWithStatus( + absl::string_view name, llvm::Value* start, llvm::Value* end, int64 step, + bool peel_first_iteration, + const std::function& + for_body_generator) { + return ForWithStatus( + name, /*start=*/start, /*end=*/end, + /*step=*/llvm::ConstantInt::get(start->getType(), step), + peel_first_iteration, for_body_generator); } - void ForReturnVoid(absl::string_view name, llvm::Value* start, - llvm::Value* end, int64 step, bool peel_first_iteration, - const std::function& - for_body_generator) { - ForReturnVoid(name, /*start=*/start, /*end=*/end, - /*step=*/llvm::ConstantInt::get(start->getType(), step), - 
peel_first_iteration, for_body_generator); + void For(absl::string_view name, llvm::Value* start, llvm::Value* end, + int64 step, bool peel_first_iteration, + const std::function& + for_body_generator) { + For(name, /*start=*/start, /*end=*/end, + /*step=*/llvm::ConstantInt::get(start->getType(), step), + peel_first_iteration, for_body_generator); } - Status For( + Status ForWithStatus( absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, const std::function& for_body_generator) { - return For(name, start, end, step, - /*peel_first_iteration=*/false, - [&](llvm::Value* indvar, llvm::Value*) -> Status { - return for_body_generator(indvar); - }); + return ForWithStatus(name, start, end, step, + /*peel_first_iteration=*/false, + [&](llvm::Value* indvar, llvm::Value*) -> Status { + return for_body_generator(indvar); + }); } - void ForReturnVoid( + void For( absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, const std::function& for_body_generator) { - ForReturnVoid(name, start, end, step, - /*peel_first_iteration=*/false, - [&](llvm::Value* indvar, llvm::Value*) { - return for_body_generator(indvar); - }); + For(name, start, end, step, + /*peel_first_iteration=*/false, [&](llvm::Value* indvar, llvm::Value*) { + return for_body_generator(indvar); + }); } - Status For( + Status ForWithStatus( absl::string_view name, llvm::Value* start, llvm::Value* end, int64 step, const std::function& for_body_generator) { - return For(name, start, end, llvm::ConstantInt::get(start->getType(), step), - /*peel_first_iteration=*/false, - [&](llvm::Value* indvar, llvm::Value*) -> Status { - return for_body_generator(indvar); - }); + return ForWithStatus(name, start, end, + llvm::ConstantInt::get(start->getType(), step), + /*peel_first_iteration=*/false, + [&](llvm::Value* indvar, llvm::Value*) -> Status { + return for_body_generator(indvar); + }); } - void ForReturnVoid( + void For( absl::string_view name, llvm::Value* start, 
llvm::Value* end, int64 step, const std::function& for_body_generator) { - ForReturnVoid(name, start, end, - llvm::ConstantInt::get(start->getType(), step), - for_body_generator); + For(name, start, end, llvm::ConstantInt::get(start->getType(), step), + for_body_generator); } - Status For( + Status ForWithStatus( absl::string_view name, int64 start, int64 end, int64 step, const std::function& for_body_generator) { - return For(name, /*start=*/b_->getInt64(start), - /*end=*/b_->getInt64(end), - /*step=*/b_->getInt64(step), for_body_generator); + return ForWithStatus(name, /*start=*/b_->getInt64(start), + /*end=*/b_->getInt64(end), + /*step=*/b_->getInt64(step), for_body_generator); } - void ForReturnVoid( + void For( absl::string_view name, int64 start, int64 end, int64 step, const std::function& for_body_generator) { - ForReturnVoid(name, /*start=*/b_->getInt64(start), - /*end=*/b_->getInt64(end), - /*step=*/b_->getInt64(step), for_body_generator); + For(name, /*start=*/b_->getInt64(start), + /*end=*/b_->getInt64(end), + /*step=*/b_->getInt64(step), for_body_generator); } // Generates the following control flow structure: @@ -201,38 +203,43 @@ class KernelSupportLibrary { // `true_block_generator()`; // else // `false_block_generator()`; - Status If(absl::string_view name, llvm::Value* condition, - const std::function& true_block_generator, - const std::function& false_block_generator = - []() -> Status { return Status::OK(); }); + Status IfWithStatus( + absl::string_view name, llvm::Value* condition, + const std::function& true_block_generator, + const std::function& false_block_generator = []() -> Status { + return Status::OK(); + }); - Status If(llvm::Value* condition, - const std::function& true_block_generator, - const std::function& false_block_generator = - []() -> Status { return Status::OK(); }) { - return If("", condition, true_block_generator, false_block_generator); + Status IfWithStatus( + llvm::Value* condition, + const std::function& 
true_block_generator, + const std::function& false_block_generator = []() -> Status { + return Status::OK(); + }) { + return IfWithStatus("", condition, true_block_generator, + false_block_generator); } - void IfReturnVoid(llvm::Value* condition, - const std::function& true_block_generator, - const std::function& false_block_generator = []() { - }) { - IfReturnVoid("", condition, true_block_generator, false_block_generator); + void If( + llvm::Value* condition, const std::function& true_block_generator, + const std::function& false_block_generator = []() {}) { + If("", condition, true_block_generator, false_block_generator); } - void IfReturnVoid(absl::string_view name, llvm::Value* condition, - const std::function& true_block_generator, - const std::function& false_block_generator = []() { - }) { - TF_CHECK_OK(If(name, condition, - [&]() { - true_block_generator(); - return Status::OK(); - }, - [&]() { - false_block_generator(); - return Status::OK(); - })); + void If( + absl::string_view name, llvm::Value* condition, + const std::function& true_block_generator, + const std::function& false_block_generator = []() {}) { + TF_CHECK_OK(IfWithStatus( + name, condition, + [&]() { + true_block_generator(); + return Status::OK(); + }, + [&]() { + false_block_generator(); + return Status::OK(); + })); } using ArgumentVector = absl::Span; diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc index 1aa85eb8d2d206bf0537deb659e779b24fffbb0a..cd8dd72cd775d5e0b52f96a2326367da0775e7eb 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc @@ -123,7 +123,8 @@ KernelMappingScheme::KernelMappingScheme( dims_in_elems_(dims_in_elems.begin(), dims_in_elems.end()), tile_sizes_{1, tile_size_y, tile_size_x}, num_threads_x_(num_threads_x), - num_threads_y_(num_threads_y) { + num_threads_y_(num_threads_y), + dilated_x_(true) { 
DCHECK_EQ(dims_in_elems_.size(), 3); DCHECK_EQ(req_block_sizes.size(), 3); @@ -170,14 +171,16 @@ IrArray::Index KernelMappingScheme::EmitBlockIndex(llvm::Type* index_ty) { IrArray::Index KernelMappingScheme::GetTileIndexForBlockOrigin( const IrArray::Index& block_index) { - IrArray::Index tile_index = block_index; + DCHECK_EQ(block_index.size(), block_sizes_.size()); + std::vector multidim; + multidim.reserve(block_sizes_.size()); for (int i = 0; i < block_sizes_.size(); ++i) { - tile_index[i] = b_->CreateMul( + multidim.push_back(b_->CreateMul( block_index[i], llvm::ConstantInt::get(block_index[i]->getType(), block_sizes_[i]), - "block_origin." + std::to_string(i)); + "block_origin." + std::to_string(i))); } - return tile_index; + return IrArray::Index(multidim, block_index[0]->getType()); } IrArray::Index KernelMappingScheme::GetElementIndexForTileOrigin( @@ -217,14 +220,14 @@ KernelMappingScheme::EmitThreadYXCoordinate(llvm::Type* index_ty) { // defined by (num_thread_y, num_thread_x) from thread_id. 
llvm::CallInst* thread_id_raw = llvm_ir::EmitCallToIntrinsic( llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, b_); - llvm_ir::AddRangeMetadata(0, GetThreadsPerTile(), thread_id_raw); + llvm_ir::AddRangeMetadata(0, GetThreadsPerBlock(), thread_id_raw); llvm::Value* thread_id_int = b_->CreateIntCast(thread_id_raw, index_ty, /*isSigned=*/true, "thread.id.x"); llvm::Value* num_thread_x = llvm::ConstantInt::get(index_ty, GetNumberOfThreadsForDimensionX()); - llvm::Value* x = b_->CreateURem(thread_id_int, num_thread_x); - llvm::Value* y = b_->CreateUDiv(thread_id_int, num_thread_x); + llvm::Value* x = b_->CreateURem(thread_id_int, num_thread_x, "thread.x"); + llvm::Value* y = b_->CreateUDiv(thread_id_int, num_thread_x, "thread.y"); return std::make_tuple(y, x); } diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h index 7277aeac8ad2086a2f6419b1fdb60c4872841adc..f802cc27d519e621262f328903697373aa8c284c 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h @@ -117,7 +117,10 @@ class KernelMappingScheme { int64 GetNumberOfTilesInOneBlock() const { return absl::c_accumulate(block_sizes_, 1, std::multiplies()); } - + int64 GetNumberOfTilesInOneBlockForDimension(int d) const { + DCHECK(d >= DimZ && d <= DimX); + return block_sizes_[d]; + } int64 GetNumberOfBlocks() const { return absl::c_accumulate(dims_in_blocks_, 1, std::multiplies()); } @@ -142,11 +145,21 @@ class KernelMappingScheme { int64 GetNumberOfThreadsForDimensionX() const { return num_threads_x_; } int64 GetNumberOfThreadsForDimensionY() const { return num_threads_y_; } - int64 GetThreadsPerTile() const { + int64 GetThreadsPerBlock() const { return GetNumberOfThreadsForDimensionX() * GetNumberOfThreadsForDimensionY(); } + bool DilatedX() const { return dilated_x_; } + void SetDilatedX(bool v) { + dilated_x_ = v; + if (!dilated_x_) { + // dilated_x_=false is for 
the purpose of vectorization, which requires + // GetTileSizeForDimension(DimX) to be a multiplier of num_threads_x_. + CHECK_EQ(GetTileSizeForDimension(DimX) % num_threads_x_, 0); + } + } + IrArray::Index EmitBlockIndex(llvm::Type* index_ty); // Returns the index for the first tile in the block with the given block // index. @@ -186,6 +199,13 @@ class KernelMappingScheme { int64 num_threads_x_; // Number of threads used to process elements in the Y direction of a tile. int64 num_threads_y_; + + // When num_threads_x threads process a total of tile_size_x elements in the + // X dimension of a tile, each threads process n=tile_size_x/num_threads_x + // elements. When dilated_x=false, the n elements processed by a thread are + // contiguous. On the other hand, when dilated_x=true the n elements are + // dilated by a factor of num_threads_x. + bool dilated_x_; }; // A class to represent information for tiled parameters to support IR emission diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc index 219a9f221fbd116cdfbaf17985e21d82aefd079d..fe320bbe727111fbc986cc1fbc217feed74d30f1 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc @@ -235,7 +235,7 @@ std::unique_ptr ForLoopNest::AddLoop(int64 start_index, IrArray::Index ForLoopNest::AddLoopsForShape(const Shape& shape, absl::string_view suffix) { - std::vector dimensions(ShapeUtil::Rank(shape)); + std::vector dimensions(shape.rank()); std::iota(dimensions.begin(), dimensions.end(), 0); return AddLoopsForShapeOnDimensions(shape, dimensions, suffix); } diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc index ceea24685af566e02340664f0a40c398c62b5ab0..807296329c07b8e4ac630486a1e1f59e4fdfa009 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc 
@@ -188,7 +188,16 @@ llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type, } return cplx_t; } - // A Tuple contains an array of pointers. Use i8*. + case C128: { + auto cplx_t = module->getTypeByName("complex128"); + if (cplx_t == nullptr) { + return llvm::StructType::create( + {llvm::Type::getDoubleTy(module->getContext()), + llvm::Type::getDoubleTy(module->getContext())}, + "complex128", /*isPacked=*/true); + } + return cplx_t; + } // A Tuple contains an array of pointers. Use i8*. case TUPLE: // An Opaque is like a void*, use i8*. case OPAQUE: @@ -219,10 +228,10 @@ int GetSizeInBits(llvm::Type* type) { llvm::Type* ShapeToIrType(const Shape& shape, llvm::Module* module) { llvm::Type* result_type = PrimitiveTypeToIrType(shape.element_type(), module); - if (ShapeUtil::IsTuple(shape)) { + if (shape.IsTuple()) { // A tuple buffer is an array of pointers. result_type = llvm::ArrayType::get(result_type, shape.tuple_shapes_size()); - } else if (ShapeUtil::IsArray(shape)) { + } else if (shape.IsArray()) { for (int64 dimension : LayoutUtil::MinorToMajor(shape)) { result_type = llvm::ArrayType::get(result_type, shape.dimensions(dimension)); @@ -621,6 +630,10 @@ llvm::Function* CreateFunction(llvm::FunctionType* function_type, function->setCallingConv(llvm::CallingConv::C); function->addFnAttr("no-frame-pointer-elim", "false"); + // Generate unwind information so that GDB can crawl through the stack frames + // created by the JIT compiled code. 
+ function->setHasUWTable(); + if (enable_fast_math) { function->addFnAttr("unsafe-fp-math", "true"); function->addFnAttr("no-infs-fp-math", "true"); diff --git a/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc index e22c2173c271fc9571be1ddb0759d2b31562dc98..89b6a36f96beedbcb7322e6164ac59221650d3d8 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc @@ -108,7 +108,7 @@ void EmitCompareLoopBody( // if (is_smaller_index && index_is_inbounds) KernelSupportLibrary ksl(b); - ksl.IfReturnVoid("smaller_comparison_index", do_comparison, [&]() { + ksl.If("smaller_comparison_index", do_comparison, [&]() { auto key1 = read_element(0, current_keys_index); auto key2 = read_element(0, compare_keys_index); auto compare_key1 = key1; @@ -155,7 +155,7 @@ void EmitCompareLoopBody( is_smaller_than = b->CreateOr( is_smaller_than, b->CreateAnd(keys_equal, index_is_smaller_than)); } - ksl.IfReturnVoid("is_smaller_than", is_smaller_than, [&]() { + ksl.If("is_smaller_than", is_smaller_than, [&]() { // Swap key1 with key2. write_element(0, current_keys_index, key2); write_element(0, compare_keys_index, key1); @@ -192,7 +192,7 @@ void EmitTiledCompareLoop( b->CreateShl(tiled_keys_index[dimension_to_sort], value_one); // We want to copy two adjacent elements. We first check whether the // first index position is within bounds. - ksl.IfReturnVoid( + ksl.If( "smaller_keys_index", b->CreateICmpSLT(current_keys_index, tiled_keys_index.GetConstantWithIndexType( @@ -203,15 +203,14 @@ void EmitTiledCompareLoop( // Increment to go the next index position. current_keys_index = b->CreateAdd(current_keys_index, value_one); // Here we check whether the next index position is within bounds. 
- ksl.IfReturnVoid( - "inner_smaller_keys_index", - b->CreateICmpSLT(current_keys_index, - tiled_keys_index.GetConstantWithIndexType( - dimension_to_sort_bound)), - [&]() { - cache_index = b->CreateAdd(cache_index, value_one); - read_or_write(cache_index, current_keys_index); - }); + ksl.If("inner_smaller_keys_index", + b->CreateICmpSLT(current_keys_index, + tiled_keys_index.GetConstantWithIndexType( + dimension_to_sort_bound)), + [&]() { + cache_index = b->CreateAdd(cache_index, value_one); + read_or_write(cache_index, current_keys_index); + }); }); }; @@ -253,7 +252,7 @@ void EmitTiledCompareLoop( if (dimension_to_sort_bound % tile_size) { // Otherwise we need a bounds check for the last tile. The last tile has // size 'dimension_to_sort_bound' % 'tile_size'. - ksl.IfReturnVoid( + ksl.If( "is_last_tile", b->CreateICmpUGE( b->CreateMul(tiled_keys_index[dimension_to_sort], @@ -323,7 +322,7 @@ Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array, // comparisons). const Shape& keys_shape = keys_array.GetShape(); - int64 rank = ShapeUtil::Rank(keys_shape); + int64 rank = keys_shape.rank(); int64 dimension_to_sort_bound = keys_shape.dimensions(dimension_to_sort); std::vector dimensions_in_iteration_order(rank); std::vector iteration_order_to_logical_order(rank); diff --git a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc index a60643bc754f896d096b3ca4e1216e77d7e384c6..d8d2700e1934fd202d44a1dc60e71a99913d4537 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc @@ -93,7 +93,7 @@ llvm::Value* EmitGetTupleElement(const Shape& target_shape, int64 index, llvm::LoadInst* src_buffer = b->CreateLoad(element_ptr); // Mark the loaded pointer as dereferenceable if we know its shape. 
- if (!ShapeUtil::IsOpaque(target_shape)) { + if (!target_shape.IsOpaque()) { SetDereferenceableMetadataForLoad( src_buffer, ByteSizeOf(target_shape, src_buffer->getModule()->getDataLayout())); diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index 6c89700983363fec46c41b5430c6eab6b366a1b6..3470fe5b2c34bf832207ed546fad176319446f31 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -52,8 +52,10 @@ namespace xla { } BackendOptions backend_options; - backend_options.set_platform(platform).set_intra_op_parallelism_threads( - options.intra_op_parallelism_threads()); + backend_options.set_platform(platform) + .set_intra_op_parallelism_threads(options.intra_op_parallelism_threads()) + .set_allowed_devices(options.allowed_devices()); + TF_ASSIGN_OR_RETURN(std::unique_ptr backend, Backend::CreateBackend(backend_options)); @@ -108,6 +110,7 @@ ExecutionOptions CreateExecutionOptions( *execution_options.mutable_shape_with_output_layout() = result_shape.ToProto(); } + execution_options.set_num_replicas(build_options.num_replicas()); return execution_options; } diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.cc b/tensorflow/compiler/xla/service/multi_output_fusion.cc index 9ccdd7d8d818b9fa3aa77cdd10d37ca18928b448..53d52d9a3d918fa6dee093668923fcfff963d084 100644 --- a/tensorflow/compiler/xla/service/multi_output_fusion.cc +++ b/tensorflow/compiler/xla/service/multi_output_fusion.cc @@ -198,7 +198,7 @@ void MultiOutputFusion::Update(HloInstruction* instr1, HloInstruction* instr2) { if (instr == fusion || is_fused(instr) || is_connected(fusion, instr)) { continue; } - if (in_list.count(instr) > 0) { + if (in_list.contains(instr)) { continue; } int64 profit = GetProfit(instr, fusion); diff --git a/tensorflow/compiler/xla/service/name_uniquer.cc b/tensorflow/compiler/xla/service/name_uniquer.cc index 
ac2f79674feceff436c0e9c65338967f498e4473..e55b83d17e90bc2ca0053a0421cf80ef6edd5bca 100644 --- a/tensorflow/compiler/xla/service/name_uniquer.cc +++ b/tensorflow/compiler/xla/service/name_uniquer.cc @@ -15,8 +15,10 @@ limitations under the License. #include "tensorflow/compiler/xla/service/name_uniquer.h" +#include "absl/strings/ascii.h" #include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" +#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -27,13 +29,13 @@ namespace { bool IsAllowed(char character) { auto c = static_cast(character); - return (isalnum(c) != 0) || c == '_' || c == '.' || c == '-'; + return (absl::ascii_isalnum(c) != 0) || c == '_' || c == '.' || c == '-'; } } // namespace NameUniquer::NameUniquer(const string& separator) { - CHECK(std::all_of(separator.begin(), separator.end(), IsAllowed)) + CHECK(absl::c_all_of(separator, IsAllowed)) << "separator should comprises allowed characters only"; separator_ = separator; } @@ -42,9 +44,10 @@ NameUniquer::NameUniquer(const string& separator) { if (name.empty()) { return ""; } + string result = name; char c = static_cast(result[0]); - if (!isalpha(c) && c != '_') { + if (!absl::ascii_isalpha(c) && c != '_') { result[0] = '_'; } for (int i = 1; i < result.length(); i++) { @@ -52,6 +55,13 @@ NameUniquer::NameUniquer(const string& separator) { result[i] = '_'; } } + + // HLO primitive type names (with the exception of 'tuple') are keywords in + // the HLO text representation and cannot be names, so append an underscore if + // the name is a primitive type. 
+ if (primitive_util::IsPrimitiveTypeName(result) && result != "tuple") { + result += "_"; + } return result; } diff --git a/tensorflow/compiler/xla/service/name_uniquer_test.cc b/tensorflow/compiler/xla/service/name_uniquer_test.cc index 3e2592c6ac626143f1421e545a31d9be91e376bc..d0d04147e0c29c66cba447550c0a9c703f35573a 100644 --- a/tensorflow/compiler/xla/service/name_uniquer_test.cc +++ b/tensorflow/compiler/xla/service/name_uniquer_test.cc @@ -104,5 +104,21 @@ TEST_F(NameUniquerTest, KeepNamesInRandomOrder) { EXPECT_EQ("foo.3", uniquer.GetUniqueName("foo.3")); } +TEST_F(NameUniquerTest, AvoidKeywords) { + NameUniquer uniquer("."); + + EXPECT_EQ("f32_", uniquer.GetUniqueName("f32")); + EXPECT_EQ("s64_", uniquer.GetUniqueName("s64")); + EXPECT_EQ("pred_", uniquer.GetUniqueName("pred")); + + // Though a primitive type, "tuple" is not a keyword. + EXPECT_EQ("tuple", uniquer.GetUniqueName("tuple")); + + // Keywords are not capitalized. + EXPECT_EQ("F32", uniquer.GetUniqueName("F32")); + EXPECT_EQ("S32", uniquer.GetUniqueName("S32")); + EXPECT_EQ("Pred", uniquer.GetUniqueName("Pred")); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/optimize_input_output_buffer_alias.cc b/tensorflow/compiler/xla/service/optimize_input_output_buffer_alias.cc new file mode 100644 index 0000000000000000000000000000000000000000..701c629add52a217f16877a085b9ef2d096623d9 --- /dev/null +++ b/tensorflow/compiler/xla/service/optimize_input_output_buffer_alias.cc @@ -0,0 +1,106 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/xla/service/optimize_input_output_buffer_alias.h" + +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_format.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +// Returns true if the given shape is a non-nested tuple. +bool IsNonNestedTuple(const Shape& shape) { + return shape.IsTuple() && !ShapeUtil::IsNestedTuple(shape); +} + +} // namespace + +StatusOr OptimizeInputOutputBufferAlias::Build( + const Shape& input_shape, const Shape& output_shape, + HloInputOutputAliasConfig* alias_config) { + bool changed = false; + TF_RET_CHECK(LayoutUtil::HasLayout(input_shape)); + TF_RET_CHECK(LayoutUtil::HasLayout(output_shape)); + VLOG(1) << "input_shape:" << input_shape.ToString(); + VLOG(1) << "output_shape:" << output_shape.ToString(); + + // For all buffers defined by the parameter, build a map from the byte + // size to the list of the buffers of that size. 
+ absl::flat_hash_map> size_to_input_index; + ShapeUtil::ForEachSubshape( + input_shape, [&](const Shape& subshape, const ShapeIndex& index) { + if (subshape.IsTuple()) { + return; + } + int64 bytes = size_func_(subshape); + size_to_input_index[bytes].push(index); + }); + + // For each result buffer shape index, take the first unused parameter + // buffer that matches the size. + TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + output_shape, [&](const Shape& subshape, const ShapeIndex& index) { + if (subshape.IsTuple()) { + return Status::OK(); + } + int64 bytes = size_func_(subshape); + + auto it = size_to_input_index.find(bytes); + if (it != size_to_input_index.end() && !it->second.empty()) { + changed = true; + const ShapeIndex& input_index = it->second.front(); + const ShapeIndex& output_index = index; + if (!alias_config->ParameterHasAlias(0, input_index) && + !alias_config->OutputHasAlias(output_index)) { + TF_RETURN_IF_ERROR(alias_config->SetUpAlias( + output_index, 0, input_index, + HloInputOutputAliasConfig::AliasKind::kSystemAlias)); + } + VLOG(3) << "Set up alias from with param index " + << it->second.front().ToString() << ", shape size " << bytes + << " and result subshape " + << ShapeUtil::HumanStringWithLayout(subshape) << " at index " + << index.ToString(); + it->second.pop(); + } + return Status::OK(); + })); + return changed; +} + +StatusOr OptimizeInputOutputBufferAlias::Run(HloModule* module) { + // User buffer alias only work for modules with 1 parameter. 
+ if (module->entry_computation()->num_parameters() != 1) { + return false; + } + + HloInputOutputAliasConfig* alias_config = + &module->input_output_alias_config(); + + return Build(module->entry_computation()->parameter_instruction(0)->shape(), + module->entry_computation()->root_instruction()->shape(), + alias_config); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/optimize_input_output_buffer_alias.h b/tensorflow/compiler/xla/service/optimize_input_output_buffer_alias.h new file mode 100644 index 0000000000000000000000000000000000000000..79ce468e975300ed703ae0fd780f4b9d5328a4b3 --- /dev/null +++ b/tensorflow/compiler/xla/service/optimize_input_output_buffer_alias.h @@ -0,0 +1,71 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_OPTIMIZE_INPUT_OUTPUT_BUFFER_ALIAS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_OPTIMIZE_INPUT_OUTPUT_BUFFER_ALIAS_H_ + +#include + +#include "absl/container/flat_hash_map.h" +#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/shape_tree.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +// This pass opportunistically finds input and output buffers that can be +// aliased, and writes the alias config into the HloModule. +// +// The input and the output buffers can be in any shape, and each output buffer +// can alias with an input buffer with the same size. Each input buffer may only +// alias with a single output buffer. For example, for the following parameter +// and the output buffers, +// +// Parameters : { P1(2MiB), P2(4MiB), P3(8MiB), P4(4MiB), P5(4MiB), ... } +// Outputs : { O1(4MiB), O2(2MiB), O3(4MiB), O4(6MiB), O5(4MiB), ... } +// +// one potential aliasing would be (O1, P2), (O2, P1), (O3, P4), (O5, P5), .. 
+class OptimizeInputOutputBufferAlias : public HloModulePass { + using ShapeSizeFunction = std::function; + + public: + OptimizeInputOutputBufferAlias(ShapeSizeFunction size_func) + : size_func_(size_func) {} + ~OptimizeInputOutputBufferAlias() override = default; + + absl::string_view name() const override { + return "optimize_input_output_buffer_alias.h"; + } + + StatusOr Run(HloModule* module) override; + + private: + friend class OptimizeInputOutputBufferAliasTest; + + StatusOr Build(const Shape& input_shape, const Shape& output_shape, + HloInputOutputAliasConfig* alias_config); + ShapeSizeFunction size_func_ = nullptr; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_OPTIMIZE_INPUT_OUTPUT_BUFFER_ALIAS_H_ diff --git a/tensorflow/compiler/xla/service/optimize_input_output_buffer_alias_test.cc b/tensorflow/compiler/xla/service/optimize_input_output_buffer_alias_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..41e90f9b6931619fd9824e2eda25e12e4c7197b0 --- /dev/null +++ b/tensorflow/compiler/xla/service/optimize_input_output_buffer_alias_test.cc @@ -0,0 +1,145 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/optimize_input_output_buffer_alias.h" + +#include + +#include "absl/memory/memory.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { + +// Tests that UserBufferAlias properly maps input and output buffer indices of +// various shapes for aliasing. +class OptimizeInputOutputBufferAliasTest : public HloTestBase { + protected: + OptimizeInputOutputBufferAliasTest() { + r1f32_ = ShapeUtil::MakeShape(F32, {4}); + r2f32_ = ShapeUtil::MakeShape(F32, {4, 5}); + r3f32_ = ShapeUtil::MakeShape(F32, {4, 5, 6}); + r4f32_ = ShapeUtil::MakeShape(F32, {4, 5, 6, 7}); + + auto size_func = [](const Shape& shape) { + return ShapeUtil::ByteSizeOf(shape); + }; + + optimize_pass_ = + absl::make_unique(size_func); + } + + // Returns the number of output indices that aliases with the input. + int64 AliasCount() { + int64 count = 0; + + config_.ForEachAlias( + [&](const ShapeIndex&, const HloInputOutputAliasConfig::Alias&) { + count++; + }); + return count; + } + + bool BuildAliasConfig(const Shape& input_shape, const Shape& output_shape) { + config_ = HloInputOutputAliasConfig(output_shape); + auto changed = optimize_pass_->Build(input_shape, output_shape, &config_); + TF_CHECK_OK(changed.status()); + + return changed.ValueOrDie(); + } + + std::unique_ptr optimize_pass_; + + HloInputOutputAliasConfig config_; + + Shape r1f32_; + Shape r2f32_; + Shape r3f32_; + Shape r4f32_; +}; + +// All shapes are different, so no aliasing is available. 
+TEST_F(OptimizeInputOutputBufferAliasTest, AllDifferentBufferSizes) { + Shape input = ShapeUtil::MakeTupleShape({r1f32_, r2f32_}); + Shape output = ShapeUtil::MakeTupleShape({r3f32_, r4f32_}); + bool changed = BuildAliasConfig(input, output); + EXPECT_FALSE(changed); + EXPECT_EQ(AliasCount(), 0); +} + +// Input and output shapes are equal, so buffers can alias at the same index. +TEST_F(OptimizeInputOutputBufferAliasTest, OrderedNonNestedTuple) { + Shape input = ShapeUtil::MakeTupleShape({r1f32_, r2f32_, r3f32_, r4f32_}); + Shape output = ShapeUtil::MakeTupleShape({r1f32_, r2f32_, r3f32_, r4f32_}); + bool changed = BuildAliasConfig(input, output); + EXPECT_TRUE(changed); + EXPECT_EQ(AliasCount(), 4); + + EXPECT_EQ(config_.GetAliasedOutput(0, {0}), ShapeIndex{0}); + EXPECT_EQ(config_.GetAliasedOutput(0, {1}), ShapeIndex{1}); + EXPECT_EQ(config_.GetAliasedOutput(0, {2}), ShapeIndex{2}); + EXPECT_EQ(config_.GetAliasedOutput(0, {3}), ShapeIndex{3}); +} + +// Only a subset of the tuple element shapes match between the input and the +// output. +TEST_F(OptimizeInputOutputBufferAliasTest, PartialReuseNonNestedTuple) { + Shape input = ShapeUtil::MakeTupleShape({r1f32_, r1f32_, r2f32_, r2f32_}); + Shape output = ShapeUtil::MakeTupleShape({r1f32_, r2f32_, r3f32_, r4f32_}); + bool changed = BuildAliasConfig(input, output); + EXPECT_TRUE(changed); + + EXPECT_EQ(AliasCount(), 2); + + EXPECT_EQ(config_.GetAliasedOutput(0, {0}), ShapeIndex{0}); + EXPECT_EQ(config_.GetAliasedOutput(0, {2}), ShapeIndex{1}); +} + +// The output shape is reverse of the input shape, but we can still reuse all +// the buffers. 
+TEST_F(OptimizeInputOutputBufferAliasTest, UnorderedNonNestedTuple) { + Shape input = ShapeUtil::MakeTupleShape({r1f32_, r2f32_, r3f32_, r4f32_}); + Shape output = ShapeUtil::MakeTupleShape({r4f32_, r3f32_, r2f32_, r1f32_}); + bool changed = BuildAliasConfig(input, output); + EXPECT_TRUE(changed); + + EXPECT_EQ(AliasCount(), 4); + + EXPECT_EQ(config_.GetAliasedOutput(0, {0}), ShapeIndex{3}); + EXPECT_EQ(config_.GetAliasedOutput(0, {1}), ShapeIndex{2}); + EXPECT_EQ(config_.GetAliasedOutput(0, {2}), ShapeIndex{1}); + EXPECT_EQ(config_.GetAliasedOutput(0, {3}), ShapeIndex{0}); +} + +TEST_F(OptimizeInputOutputBufferAliasTest, UnorderedNestedTuple) { + Shape input = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeTupleShape({r1f32_}), r2f32_, r3f32_, r4f32_}); + Shape output = ShapeUtil::MakeTupleShape( + {r1f32_, ShapeUtil::MakeTupleShape({r3f32_, r2f32_}), r2f32_}); + bool changed = BuildAliasConfig(input, output); + EXPECT_TRUE(changed); + + EXPECT_EQ(AliasCount(), 3); + + EXPECT_EQ(config_.GetAliasedOutput(0, {0, 0}), ShapeIndex{0}); + EXPECT_EQ(config_.GetAliasedOutput(0, {1}), ShapeIndex({1, 1})); + EXPECT_EQ(config_.GetAliasedOutput(0, {2}), ShapeIndex({1, 0})); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/pattern_matcher.h b/tensorflow/compiler/xla/service/pattern_matcher.h index fb1645d9b2ebeae77190a950ebd023979c567016..9e3d1060210790f60243195a1c1dff13f1fc7fc5 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher.h +++ b/tensorflow/compiler/xla/service/pattern_matcher.h @@ -64,6 +64,9 @@ namespace xla { // e.g. IsConstantScalar() or IsConstantScalar(42). // - WithFusionKind // - WithTupleIndex: get-tuple-element operations with the given tuple index +// - WithOneUse: Instruction is used as an operand exactly once. +// - WithOneUser: Instruction is used by exactly one other instruction, but +// is possibly used more than once as an operand (e.g. multiply(x,x)). 
// // Shape(): // - EqualTo @@ -772,7 +775,7 @@ class ShapePatternIsArrayImpl { explicit constexpr ShapePatternIsArrayImpl() {} bool Match(const ::xla::Shape* shape, MatchOption option) const { - if (!ShapeUtil::IsArray(*shape)) { + if (!shape->IsArray()) { EXPLAIN << "Shape is not an array"; return false; } @@ -790,7 +793,7 @@ class ShapePatternIsTupleImpl { explicit constexpr ShapePatternIsTupleImpl() {} bool Match(const ::xla::Shape* shape, MatchOption option) const { - if (!ShapeUtil::IsTuple(*shape)) { + if (!shape->IsTuple()) { EXPLAIN << "Shape is not a tuple"; return false; } @@ -828,7 +831,7 @@ class ShapePatternRankImpl { explicit constexpr ShapePatternRankImpl(int64 rank) : rank_(rank) {} bool Match(const ::xla::Shape* shape, MatchOption option) const { - if (ShapeUtil::Rank(*shape) != rank_) { + if (shape->rank() != rank_) { if (rank_ == 0) { EXPLAIN << "Shape is not a scalar"; } else { @@ -1133,6 +1136,13 @@ inline const HloInstruction* HloOperand(const HloInstruction* instr, return instr->operand(idx); } +// Pretty-printer for HloInstruction. Sort of like ToShortString, but with +// fewer %s and more shapes. 
+inline string InstToString(const HloInstruction* inst) { + return inst->ToString( + HloPrintOptions().set_print_metadata(false).set_print_percent(false)); +} + template class HloInstructionPattern; @@ -1187,14 +1197,14 @@ class HloInstructionIsImpl { bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { if (inst != inst_) { EXPLAIN << "HloInstruction " << inst << " is not " << inst_ << " (" - << inst_->ToShortString() << ")"; + << InstToString(inst_) << ")"; return false; } return true; } void DescribeTo(std::ostream* os, int64 indent = 0) const { - *os << "which is " << inst_ << " (" << inst_->ToShortString() << ")"; + *os << "which is " << inst_ << " (" << InstToString(inst_) << ")"; } private: @@ -1603,6 +1613,64 @@ class HloInstructionPatternParameterNumImpl { int64 parameter_num_; }; +// Superclass that contains common code used by Op::WithOneUse() and +// Op::WithOneUser(). +class HloInstructionPatternOneUseOrUserImpl { + protected: + bool MatchOneUser(const HloInstruction* inst, MatchOption option) const { + if (inst->user_count() != 1) { + EXPLAIN << "HloInstruction has " << inst->user_count() + << " users, but expected exactly one."; + if (inst->user_count() > 1) { + EXPLAIN << "\nAll users:"; + for (const HloInstruction* user : inst->users()) { + EXPLAIN << "\n - " << InstToString(user); + } + } + return false; + } + return true; + } +}; + +class HloInstructionPatternOneUseImpl + : public HloInstructionPatternOneUseOrUserImpl { + public: + bool Match(const HloInstruction* inst, MatchOption option) const { + if (!MatchOneUser(inst, option)) { + return false; + } + + int64 use_count = absl::c_count_if( + inst->users()[0]->operands(), + [&](const HloInstruction* operand) { return operand == inst; }); + if (use_count != 1) { + EXPLAIN << "HloInstruction is used " << use_count + << " times by its user, but is expected to be used just once: " + << InstToString(inst->users()[0]); + return false; + } + return true; + } + + void 
DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "which has exactly one use"; + } +}; + +class HloInstructionPatternOneUserImpl + : public HloInstructionPatternOneUseOrUserImpl { + public: + bool Match(const HloInstruction* inst, MatchOption option) const { + return MatchOneUser(inst, option); + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "which has exactly one user (but possibly is used multiple times by " + "that instruction)"; + } +}; + // Matches a constant scalar or effective scalar, optionally with a given value. template class HloConstantScalarImpl { @@ -1669,7 +1737,8 @@ class HloConstantScalarImpl { literal_r0_as_val_ty_or.ValueOrDie() == val_literal && literal_r0 == val_as_literal_ty; if (!rv) { - EXPLAIN << "HloInstruction's constant value " << literal_r0.ToString() + EXPLAIN << "HloInstruction's constant value " + << literal_r0.ToStringWithoutShape() << " did not match expected value " << *val_; } return rv; @@ -1706,10 +1775,7 @@ class HloInstructionPattern { return true; } if (inst != nullptr) { - EXPLAIN << "\nin " - << inst->ToString(HloPrintOptions() - .set_print_metadata(false) - .set_print_percent(false)); + EXPLAIN << "\nin " << InstToString(inst); } return false; } @@ -1722,10 +1788,7 @@ class HloInstructionPattern { } return true; } - EXPLAIN << "\nin " - << inst->ToString(HloPrintOptions() - .set_print_metadata(false) - .set_print_percent(false)); + EXPLAIN << "\nin " << InstToString(inst); return false; } @@ -1815,7 +1878,7 @@ class HloInstructionPattern { // Make this a templated function to work around gcc 4.9.4 template infinite // recursion bug. 
template - constexpr auto WithShapeEqualTo(const ::xla::Shape* shape) + constexpr auto WithShapeEqualTo(const ::xla::Shape* shape) const -> decltype(this->WithShape(Shape().EqualTo(shape))) { return WithShape(Shape().EqualTo(shape)); } @@ -1823,7 +1886,7 @@ class HloInstructionPattern { // Make this a templated function to work around gcc 4.9.4 template infinite // recursion bug. template - constexpr auto WithShapeCompatibleTo(const ::xla::Shape* shape) + constexpr auto WithShapeCompatibleTo(const ::xla::Shape* shape) const -> decltype(this->WithShape(Shape().CompatibleTo(shape))) { return WithShape(Shape().CompatibleTo(shape)); } @@ -1877,6 +1940,22 @@ class HloInstructionPattern { return AppendImpl(HloInstructionPatternParameterNumImpl(parameter_num)); } + // Modifies the pattern to match if the instruction is used exactly once. + // Does not match if the instruction is used twice by the same user (e.g. + // multiply(x,x)). + constexpr auto WithOneUse() const + -> decltype(this->AppendImpl(HloInstructionPatternOneUseImpl())) { + return AppendImpl(HloInstructionPatternOneUseImpl()); + } + + // Modifies the pattern to match if the instruction is used by exactly one + // other instruction. Will match if the instruction is used twice, so long as + // it's by the same user (e.g. multiply(x,x)). + constexpr auto WithOneUser() const + -> decltype(this->AppendImpl(HloInstructionPatternOneUserImpl())) { + return AppendImpl(HloInstructionPatternOneUserImpl()); + } + void DescribeTo(std::ostream* os, int64 indent = 0) const { impl_.DescribeTo(os, indent); } @@ -1922,6 +2001,7 @@ Op(::xla::HloInstruction** matched_inst) { XLA_NULLOP_PATTERN(Constant) XLA_NULLOP_PATTERN(Parameter) XLA_NULLOP_PATTERN(Iota) +XLA_NULLOP_PATTERN(Rng) #undef XLA_NULLOP_PATTERN // Helpers for unary instructions. 
@@ -1956,7 +2036,7 @@ XLA_UNOP_PATTERN(Ceil) XLA_UNOP_PATTERN(Convert) XLA_UNOP_PATTERN(Copy) XLA_UNOP_PATTERN(Cos) -XLA_UNOP_PATTERN(CrossReplicaSum) +XLA_UNOP_PATTERN(AllReduce) XLA_UNOP_PATTERN(Exp) XLA_UNOP_PATTERN(Fft) XLA_UNOP_PATTERN(Floor) @@ -1977,7 +2057,6 @@ XLA_UNOP_PATTERN(SendDone) XLA_UNOP_PATTERN(Sign) XLA_UNOP_PATTERN(Sin) XLA_UNOP_PATTERN(Slice) -XLA_UNOP_PATTERN(Sort) XLA_UNOP_PATTERN(Tanh) XLA_UNOP_PATTERN(Transpose) #undef XLA_UNOP_PATTERN @@ -2028,10 +2107,10 @@ XLA_UNOP_PATTERN(Transpose) } \ template \ inline auto NAME##AnyOrder(Lhs&& lhs, Rhs&& rhs) \ - ->decltype(NAME##AnyOrder( \ + ->decltype(NAME##AnyOrder( \ nullptr, std::forward(lhs), std::forward(rhs))) { \ - return NAME##AnyOrder(nullptr, std::forward(lhs), \ - std::forward(rhs)); \ + return NAME##AnyOrder( \ + nullptr, std::forward(lhs), std::forward(rhs)); \ } XLA_COMMUTATIVE_BINOP_PATTERN(Add) XLA_BINOP_PATTERN(Atan2) @@ -2039,7 +2118,6 @@ XLA_BINOP_PATTERN(Divide) XLA_BINOP_PATTERN(Complex) XLA_BINOP_PATTERN(Convolution) XLA_BINOP_PATTERN(Dot) -XLA_BINOP_PATTERN(DynamicSlice) XLA_COMMUTATIVE_BINOP_PATTERN(Eq) XLA_BINOP_PATTERN(Gather) XLA_BINOP_PATTERN(Ge) @@ -2053,6 +2131,7 @@ XLA_COMMUTATIVE_BINOP_PATTERN(Ne) XLA_BINOP_PATTERN(Outfeed) XLA_BINOP_PATTERN(Pad) XLA_BINOP_PATTERN(Power) +XLA_BINOP_PATTERN(ReduceWindow) XLA_BINOP_PATTERN(Remainder) XLA_BINOP_PATTERN(Send) XLA_BINOP_PATTERN(Subtract) @@ -2099,6 +2178,7 @@ XLA_BINOP_PATTERN(ShiftRightLogical) .WithOperand(2, std::forward(arg2)); \ } XLA_TERNOP_PATTERN(Clamp); +XLA_TERNOP_PATTERN(Scatter); XLA_TERNOP_PATTERN(Select); #undef XLA_TERNOP_PATTERN @@ -2151,9 +2231,13 @@ inline auto WithOperands(Matcher&& m, int64 operand_num, FirstArg&& first_arg, // We could implement all ops as "variadic" ops, but it would make the // already-bad compile errors even worse. 
+XLA_VARIADIC_OP_PATTERN(AfterAll); XLA_VARIADIC_OP_PATTERN(Concatenate); XLA_VARIADIC_OP_PATTERN(CustomCall); +XLA_VARIADIC_OP_PATTERN(DynamicSlice) +XLA_VARIADIC_OP_PATTERN(Map) XLA_VARIADIC_OP_PATTERN(Reduce); +XLA_VARIADIC_OP_PATTERN(Sort); XLA_VARIADIC_OP_PATTERN(Tuple); // Helpers for matching non-constant instructions. diff --git a/tensorflow/compiler/xla/service/pattern_matcher_gmock_test.cc b/tensorflow/compiler/xla/service/pattern_matcher_gmock_test.cc index 9ca2fb05c1f7ef093c58237cf21fbc7c813a592a..f51a18b13894d75300c46835fabd82a4ce0699af 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher_gmock_test.cc +++ b/tensorflow/compiler/xla/service/pattern_matcher_gmock_test.cc @@ -23,7 +23,6 @@ namespace xla { namespace { namespace m = ::xla::match; -using ::testing::Eq; using ::testing::Not; template diff --git a/tensorflow/compiler/xla/service/pattern_matcher_test.cc b/tensorflow/compiler/xla/service/pattern_matcher_test.cc index 13886fa6f5b7b55283e6e420734a22312987d8a6..5c3c009a68bffbda8642fceedfb724879fbf1530 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher_test.cc +++ b/tensorflow/compiler/xla/service/pattern_matcher_test.cc @@ -242,8 +242,8 @@ TEST(PatternMatcherTest, ConstantScalar) { HloModule test_module ENTRY test { a = s32[] constant(1) - b = s32[1,1] constant(s32[1,1]{{2}}) - c = s32[1,2] constant(s32[1,2]{{2,2}}) + b = s32[1,1] constant({{2}}) + c = s32[1,2] constant({{2,2}}) d = f32[] constant(1) e = f32[] constant(1.25) ROOT tuple = (s32[], s32[1,1], s32[1,2], f32[], f32[]) tuple(a,b,c,d,e) @@ -767,10 +767,11 @@ TEST(PatternMatcherTest, HloInstructionDescribeToAndExplain) { "in c = f64[] constant(2.25)"); EXPECT_DESC_AND_EXPLANATION( constant, m::Op().Is(iota.get()), - absl::StrCat("an HloInstruction which is 0x", absl::Hex(iota.get()), " (", - iota->ToShortString(), ")"), + absl::StrCat("an HloInstruction which is 0x", absl::Hex(iota.get()), + " (i = s32[42]{0} iota(), iota_dimension=0)"), absl::StrCat("HloInstruction 0x", 
absl::Hex(constant.get()), " is not 0x", - absl::Hex(iota.get()), " (", iota->ToShortString(), ")\n", + absl::Hex(iota.get()), + " (i = s32[42]{0} iota(), iota_dimension=0)\n" "in c = s32[] constant(0)")); } @@ -875,5 +876,60 @@ TEST(PatternMatcherTest, Parameter) { "in p0 = f32[] parameter(0)"); } +TEST(PatternMatcherTest, OneUseAndOneUser) { + auto param = + HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0"); + + EXPECT_FALSE(Match(param.get(), m::Op().WithOneUse())); + EXPECT_DESC_AND_EXPLANATION( + param, m::Op().WithOneUse(), + "an HloInstruction which has exactly one use", + "HloInstruction has 0 users, but expected exactly one.\n" + "in p0 = f32[] parameter(0)"); + + EXPECT_FALSE(Match(param.get(), m::Op().WithOneUser())); + EXPECT_DESC_AND_EXPLANATION( + param, m::Op().WithOneUser(), + "an HloInstruction which has exactly one user (but possibly is used " + "multiple times by that instruction)", + "HloInstruction has 0 users, but expected exactly one.\n" + "in p0 = f32[] parameter(0)"); + + { + auto reshape = + SetName("r", HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {1}), param.get())); + EXPECT_TRUE(Match(param.get(), m::Op().WithOneUse())); + EXPECT_TRUE(Match(param.get(), m::Op().WithOneUser())); + + auto reshape1 = + SetName("r1", HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {1}), param.get())); + EXPECT_FALSE(Match(param.get(), m::Op().WithOneUse())); + EXPECT_FALSE(Match(param.get(), m::Op().WithOneUser())); + + const char* kMultipleUserExplanation = + "HloInstruction has 2 users, but expected exactly one.\n" + "All users:\n" + " - r = f32[1]{0} reshape(f32[] p0)\n" + " - r1 = f32[1]{0} reshape(f32[] p0)\n" + "in p0 = f32[] parameter(0)"; + EXPECT_EQ(Explanation(param.get(), m::Op().WithOneUse()), + kMultipleUserExplanation); + EXPECT_EQ(Explanation(param.get(), m::Op().WithOneUser()), + kMultipleUserExplanation); + } + + auto add = SetName("add", HloInstruction::CreateBinary( + 
ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd, + param.get(), param.get())); + EXPECT_TRUE(Match(param.get(), m::Op().WithOneUser())); + EXPECT_FALSE(Match(param.get(), m::Op().WithOneUse())); + EXPECT_EQ(Explanation(param.get(), m::Op().WithOneUse()), + "HloInstruction is used 2 times by its user, but is expected to be " + "used just once: add = f32[] add(f32[] p0, f32[] p0)\n" + "in p0 = f32[] parameter(0)"); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/platform_util.cc b/tensorflow/compiler/xla/service/platform_util.cc index c227106511c2c17b44569d3b696cd7d764226e81..886a0545624927fa77528141f61d8ecb6bec180a 100644 --- a/tensorflow/compiler/xla/service/platform_util.cc +++ b/tensorflow/compiler/xla/service/platform_util.cc @@ -70,6 +70,9 @@ PlatformUtil::GetSupportedPlatforms() { for (se::Platform* platform : all_platforms) { auto compiler_status = Compiler::GetForPlatform(platform); if (compiler_status.ok()) { + if (!platform->Initialized()) { + TF_RETURN_IF_ERROR(platform->Initialize({})); + } platforms.push_back(platform); } else { LOG(INFO) << "platform " << platform->Name() << " present but no " @@ -205,7 +208,9 @@ static bool IsDeviceSupported(se::StreamExecutor* executor) { } /* static */ StatusOr> -PlatformUtil::GetStreamExecutors(se::Platform* platform) { +PlatformUtil::GetStreamExecutors( + se::Platform* platform, + const absl::optional>& allowed_devices) { int device_count = platform->VisibleDeviceCount(); if (device_count <= 0) { return NotFound("no %s devices found", platform->Name()); @@ -226,6 +231,17 @@ PlatformUtil::GetStreamExecutors(se::Platform* platform) { tensorflow::thread::ThreadPool thread_pool( tensorflow::Env::Default(), "device_initialization", device_count); for (int i = 0; i < device_count; ++i) { + // Once a stream executor is instantiated it will cause allocations on + // the device, for example for GPUs cuda context, cudnn handles etc. will + // be constructed. 
By constructing stream executors only on the + // allowed_devices, we don't make any allocations on other devices. + // This helps in multi-process executions on the same host like horovod or + // shared hosts. + if (allowed_devices && allowed_devices->count(i) == 0) { + VLOG(1) << "Not initializing StreamExecutor for device " << i + << " since it is not in the visible device list"; + continue; + } thread_pool.Schedule([platform, i, &stream_executors]() { VLOG(1) << "Started device init " << i; se::StreamExecutorConfig config; @@ -247,8 +263,8 @@ PlatformUtil::GetStreamExecutors(se::Platform* platform) { // Block here in thread_pool destructor until all devices are initialized. } VLOG(1) << "Device initialization complete"; - if (std::all_of(stream_executors.begin(), stream_executors.end(), - [](se::StreamExecutor* s) { return s == nullptr; })) { + if (absl::c_all_of(stream_executors, + [](se::StreamExecutor* s) { return s == nullptr; })) { return InternalError("no supported devices found for platform %s", platform->Name()); } diff --git a/tensorflow/compiler/xla/service/platform_util.h b/tensorflow/compiler/xla/service/platform_util.h index 571451ba43a81d19b70e4954e45d3447f15dcedc..592b20282f334e12e0d7a7f683c9a6ab59d21fea 100644 --- a/tensorflow/compiler/xla/service/platform_util.h +++ b/tensorflow/compiler/xla/service/platform_util.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_PLATFORM_UTIL_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_PLATFORM_UTIL_H_ +#include #include #include @@ -60,10 +61,14 @@ class PlatformUtil { // Returns a vector of StreamExecutors for the given platform. The vector is // indexed by device ordinal (device numbering used by StreamExecutor). If an // element is nullptr, then the device is present by not supported by XLA. + // If populated, only the devices in allowed_devices will have + // their StreamExecutors initialized, otherwise all StreamExecutors will be + // initialized and returned. 
// // If the platform has no visible devices, a not-found error is returned. static StatusOr> GetStreamExecutors( - se::Platform* platform); + se::Platform* platform, + const absl::optional>& allowed_devices = absl::nullopt); private: TF_DISALLOW_COPY_AND_ASSIGN(PlatformUtil); diff --git a/tensorflow/compiler/xla/service/reshape_mover.cc b/tensorflow/compiler/xla/service/reshape_mover.cc index 4df746fca9f8320eed72911726f33bb01f06fed5..a62118df157edf67114ff41befbdce3da129fe93 100644 --- a/tensorflow/compiler/xla/service/reshape_mover.cc +++ b/tensorflow/compiler/xla/service/reshape_mover.cc @@ -226,7 +226,10 @@ StatusOr PerformSinkReshapeOrTranspose( // changes, so all the fused instructions have the same dimensions. for (const auto& fused_instruction : instruction->fused_instructions()) { Shape* shape = fused_instruction->mutable_shape(); - *shape->mutable_dimensions() = new_operand_shape.dimensions(); + shape->clear_dimensions(); + for (int64 i : new_operand_shape.dimensions()) { + shape->add_dimensions(i); + } *shape->mutable_layout() = new_operand_shape.layout(); } } diff --git a/tensorflow/compiler/xla/service/scatter_expander.cc b/tensorflow/compiler/xla/service/scatter_expander.cc index 11c2f8392d285095816dd5d61f7029c1bfd158d4..acad871c4d427b174ffce3a462a0a3918a1e0c33 100644 --- a/tensorflow/compiler/xla/service/scatter_expander.cc +++ b/tensorflow/compiler/xla/service/scatter_expander.cc @@ -26,7 +26,6 @@ limitations under the License. namespace xla { - // Transposes the given scatter_indices such that the index_vector_dim becomes // the most-minor dimension. 
static StatusOr TransposeIndexVectorDimToLast( @@ -60,6 +59,13 @@ static StatusOr CanonicalizeScatterIndices( TF_ASSIGN_OR_RETURN( HloInstruction * transposed_scatter_indices, TransposeIndexVectorDimToLast(scatter_indices, index_vector_dim)); + if (scatter_indices->shape().rank() == index_vector_dim + 1 && + scatter_indices->shape().dimensions(index_vector_dim) == 1) { + auto new_shape = + ShapeUtil::DeleteDimension(index_vector_dim, scatter_indices->shape()); + TF_ASSIGN_OR_RETURN(scatter_indices, + MakeReshapeHlo(new_shape, scatter_indices)); + } bool indices_are_scalar = index_vector_dim == scatter_indices->shape().dimensions_size(); @@ -88,7 +94,7 @@ static StatusOr CanonicalizeScatterIndices( static StatusOr PermuteScatterAndWindowDims( HloInstruction* updates, absl::Span update_window_dims) { std::vector permutation; - const int64 updates_rank = ShapeUtil::Rank(updates->shape()); + const int64 updates_rank = updates->shape().rank(); permutation.reserve(updates_rank); for (int64 i = 0; i < updates_rank; ++i) { @@ -165,10 +171,9 @@ static StatusOr CheckIndexValidity( // Valid range for the index: [0, operand_dims - window_sizes] // Check if the index has any negative values. - TF_ASSIGN_OR_RETURN( - HloInstruction * zero_index, + HloInstruction* zero_index = BroadcastZeros(computation, index->shape().element_type(), - AsInt64Slice(index->shape().dimensions()))); + AsInt64Slice(index->shape().dimensions())); TF_ASSIGN_OR_RETURN(HloInstruction * negative_index_check, MakeBinaryHlo(HloOpcode::kLe, zero_index, index)); @@ -214,15 +219,11 @@ static StatusOr> ScatterLoopBody( HloInstruction* updates = loop_state[2]; bool has_scalar_indices = scatter_indices->shape().dimensions_size() == 1; - CHECK_EQ(has_scalar_indices, - dim_numbers.index_vector_dim() == - scatter->operand(1)->shape().dimensions_size()); // Build a vector form of the induction variable of the while loop. 
- TF_ASSIGN_OR_RETURN( - HloInstruction * induction_var_as_vector, + HloInstruction* induction_var_as_vector = MakeBroadcastHlo(induction_var, /*broadcast_dimensions=*/{}, - /*result_shape_bounds=*/{1})); + /*result_shape_bounds=*/{1}); // Pick the index to scatter from scatter_indices based on the induction_var // and transform that to an index into the `operand` space. diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index 3b336d5c9db80ff2ca8d0e45396dca66a29a0494..9bda6fba3aabfed78ae724545387e86bad36c886 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" @@ -113,6 +114,16 @@ int ServiceOptions::intra_op_parallelism_threads() const { return intra_op_parallelism_threads_; } +ServiceOptions& ServiceOptions::set_allowed_devices( + const absl::optional>& allowed_devices) { + allowed_devices_ = allowed_devices; + return *this; +} + +const absl::optional>& ServiceOptions::allowed_devices() const { + return allowed_devices_; +} + /* static */ StatusOr> Service::NewService( se::Platform* platform) { ServiceOptions default_options; @@ -129,6 +140,7 @@ int ServiceOptions::intra_op_parallelism_threads() const { } BackendOptions backend_options; backend_options.set_platform(platform); + backend_options.set_allowed_devices(options.allowed_devices()); TF_ASSIGN_OR_RETURN(execute_backend, Backend::CreateBackend(backend_options)); std::unique_ptr service( @@ -150,17 +162,13 @@ 
Service::Service(const ServiceOptions& options, LOG(INFO) << StrFormat( "XLA service %p executing computations on platform %s. Devices:", this, execute_backend_->platform()->Name()); + auto stream_executors = execute_backend_->stream_executors(); for (int i = 0; i < execute_backend_->device_count(); ++i) { - if (execute_backend_->device_ordinal_supported(i)) { - se::StreamExecutor* executor = - execute_backend_->stream_executor(i).ValueOrDie(); - const auto& description = executor->GetDeviceDescription(); - LOG(INFO) << StrFormat(" StreamExecutor device (%d): %s, %s", i, - description.name(), - description.platform_version()); - } else { - LOG(INFO) << StrFormat(" StreamExecutor device (%d) not supported", i); - } + se::StreamExecutor* executor = stream_executors.at(i); + const auto& description = executor->GetDeviceDescription(); + LOG(INFO) << StrFormat(" StreamExecutor device (%d): %s, %s", i, + description.name(), + description.platform_version()); } } else { VLOG(1) << "XLA compile-only service constructed"; @@ -288,11 +296,16 @@ StatusOr> Service::CreateModuleConfig( computation_layout->mutable_result_layout()->SetToDefaultLayout(); } - config->set_replica_count(options_.number_of_replicas()); if (execution_options != nullptr) { + if (execution_options->num_replicas() > 0) { + config->set_replica_count(execution_options->num_replicas()); + } else { + config->set_replica_count(options_.number_of_replicas()); + } config->set_seed(execution_options->seed()); config->set_debug_options(execution_options->debug_options()); } else { + config->set_replica_count(options_.number_of_replicas()); config->set_debug_options(GetDebugOptionsFromFlags()); } @@ -355,6 +368,7 @@ StatusOr>> Service::BuildExecutables( const HloModuleProto* proto = module_protos[i]; const HloModuleConfig& config = *module_configs[i]; TF_ASSIGN_OR_RETURN(auto module, CreateModuleFromProto(*proto, config)); + TF_RETURN_IF_ERROR(MaybeDumpUnoptimizedHloModule(*module)); 
module_group->push_back(std::move(module)); } @@ -516,13 +530,13 @@ Service::ExecuteParallelAndRegisterResult( StatusOr Service::ExecuteAndRegisterResult( Executable* executable, - const absl::Span> arguments, - Backend* backend, const string& result_tag, ExecutionProfile* profile) { + absl::Span> arguments, + Backend* backend, const DeviceHandle& device_handle, + const string& result_tag, ExecutionProfile* profile) { // Set up streams. std::vector streams; - TF_ASSIGN_OR_RETURN(auto replicas, - Replicas(*backend, SingleComputationDeviceHandle())); + TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*backend, device_handle)); TF_RET_CHECK(!replicas.empty()); for (se::StreamExecutor* executor : replicas) { TF_ASSIGN_OR_RETURN(StreamPool::Ptr stream, @@ -530,10 +544,11 @@ StatusOr Service::ExecuteAndRegisterResult( streams.push_back(std::move(stream)); } - TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment, - backend->computation_placer()->AssignDevices( - options_.number_of_replicas(), - /*computation_count=*/1)); + DeviceAssignment device_assignment(options_.number_of_replicas(), + /*computation_count=*/1); + for (int64 replica = 0; replica < replicas.size(); ++replica) { + device_assignment(replica, 0) = replicas[replica]->device_ordinal(); + } // Set up run options. std::vector run_options; @@ -545,9 +560,7 @@ StatusOr Service::ExecuteAndRegisterResult( options.set_intra_op_thread_pool( backend->eigen_intra_op_thread_pool_device()); options.set_device_assignment(&device_assignment); - run_options.emplace_back( - options, backend->StreamBorrower(), - /*xla_intra_op_thread_pool=*/backend->eigen_intra_op_thread_pool()); + run_options.emplace_back(options, backend->StreamBorrower()); } if (options_.number_of_replicas() == 1) { @@ -704,14 +717,33 @@ Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg, } } - // Execute the generated executables in parallel and return the device - // handles for each computation's output. 
+ // If we have multiple executables to run, execute them all in parallel. But + // if we only have one executable, execute it using the vanilla, non-parallel + // call. + // + // We do this because the Client API uses ExecuteGraphParallel when it wants + // to compile and run one computation without caching the executable, but not + // all backends support the async StreamExecutor API required by + // ExecuteParallelAndRegisterResult. + // + // TODO(b/122731460): Consolidate Execute{,Parallel}AndRegisterResult; they do + // basically the same thing. ExecutionProfile profile; - TF_ASSIGN_OR_RETURN( - std::vector outputs, - ExecuteParallelAndRegisterResult(executable_ptrs, all_arguments, - execute_backend_.get(), device_handles, - computation_names, &profile)); + std::vector outputs; + if (executable_ptrs.size() == 1) { + TF_ASSIGN_OR_RETURN( + auto output, + ExecuteAndRegisterResult(executable_ptrs[0], all_arguments[0], + execute_backend_.get(), device_handles[0], + computation_names[0], &profile)); + outputs.push_back(std::move(output)); + } else { + TF_ASSIGN_OR_RETURN( + outputs, ExecuteParallelAndRegisterResult( + executable_ptrs, all_arguments, execute_backend_.get(), + device_handles, computation_names, &profile)); + } + for (const GlobalDataHandle& output : outputs) { ExecuteResponse response; *response.mutable_output() = output; @@ -746,9 +778,9 @@ Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg, } if (available_device_count < arg->device_count() * replica_count) { return ResourceExhausted( - "Requested device count (%d) exceeds the number of available devices " - "on the target (%d)", - arg->device_count(), available_device_count); + "Requested logical device count (%d) with replica count (%d) exceeds " + "the number of available physical devices on the target (%d)", + arg->device_count(), replica_count, available_device_count); } for (int64 i = 0; i < arg->device_count(); ++i) { @@ -897,6 +929,7 @@ Status Service::Execute(const 
ExecuteRequest* arg, ExecuteResponse* result) { *result->mutable_output(), ExecuteAndRegisterResult(executable.get(), replicated_arguments, execute_backend_.get(), + SingleComputationDeviceHandle(), "result of " + executable->module().name(), result->mutable_profile())); @@ -1078,9 +1111,11 @@ Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg, ProgramShape program_shape(arg->computation().host_program_shape()); TF_DCHECK_OK(ShapeUtil::ValidateShape(program_shape.result())); + absl::optional output_layout; if (arg->has_output_layout()) { + output_layout = Layout::CreateFromProto(arg->output_layout()); TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutForShape( - arg->output_layout(), program_shape.result())); + *output_layout, program_shape.result())); } HloModuleConfig config(program_shape); @@ -1088,16 +1123,19 @@ Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg, TF_ASSIGN_OR_RETURN(std::unique_ptr module, CreateModuleFromProto(arg->computation(), config)); + TF_ASSIGN_OR_RETURN(DynamicDimensionInference dynamic_dimension_inference, + DynamicDimensionInference::Run(module.get())); + HloEvaluator evaluator; - TF_ASSIGN_OR_RETURN(auto result_literal, evaluator.Evaluate( - *module, /*arg_literals=*/{})); + evaluator.set_dynamic_dimension_inference(&dynamic_dimension_inference); + TF_ASSIGN_OR_RETURN(auto result_literal, evaluator.Evaluate(*module, {})); // Since the result layout is non-effective to the Evaluator results, explicit // relayout here. // // TODO(b/77824332): Make HloEvaluator take care of the re-layout. 
- if (arg->has_output_layout()) { - result_literal = result_literal.Relayout(arg->output_layout()); + if (output_layout.has_value()) { + result_literal = result_literal.Relayout(*output_layout); } *result->mutable_literal() = result_literal.ToProto(); diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h index 11e1a79552fbd944ab28da129b08cfe676fb08e9..fd907d07daef9e8337aeed198ef4fd23d069df21 100644 --- a/tensorflow/compiler/xla/service/service.h +++ b/tensorflow/compiler/xla/service/service.h @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include #include #include @@ -52,7 +53,7 @@ class ServiceOptions { ServiceOptions& set_platform(se::Platform* platform); se::Platform* platform() const; - // Set the number of replicas to use when compiling replicated + // Set the default number of replicas to use when compiling replicated // programs. ServiceOptions& set_number_of_replicas(int number_of_replicas); int number_of_replicas() const; @@ -61,10 +62,17 @@ class ServiceOptions { ServiceOptions& set_intra_op_parallelism_threads(int num_threads); int intra_op_parallelism_threads() const; + // Sets the allowed_devices set for selectively constructing stream executors + // on the platform. + ServiceOptions& set_allowed_devices( + const absl::optional>& allowed_devices); + const absl::optional>& allowed_devices() const; + private: se::Platform* platform_ = nullptr; int number_of_replicas_ = 1; int intra_op_parallelism_threads_ = -1; + absl::optional> allowed_devices_; }; // The XLA service object, which is the same across all platforms. It maintains @@ -242,8 +250,9 @@ class Service : public ServiceInterface { // ExecutionProfile object which will be filled in with profile data. 
StatusOr ExecuteAndRegisterResult( Executable* executable, - const absl::Span> arguments, - Backend* backend, const string& result_tag, ExecutionProfile* profile); + absl::Span> arguments, + Backend* backend, const DeviceHandle& device_handle, + const string& result_tag, ExecutionProfile* profile); // Runs the given executables with the given arguments and register the result // from each executable in the allocation tracker. The handles of the result diff --git a/tensorflow/compiler/xla/service/service_executable_run_options.h b/tensorflow/compiler/xla/service/service_executable_run_options.h index dbfed628bfcabffe66bef41a82e0e2430897d80d..6bee671056552b83014367889320b748659bbfdf 100644 --- a/tensorflow/compiler/xla/service/service_executable_run_options.h +++ b/tensorflow/compiler/xla/service/service_executable_run_options.h @@ -32,12 +32,10 @@ class ServiceExecutableRunOptions { ServiceExecutableRunOptions() : ServiceExecutableRunOptions(ExecutableRunOptions()) {} - explicit ServiceExecutableRunOptions( - ExecutableRunOptions run_options, StreamBorrower borrow_stream = nullptr, - tensorflow::thread::ThreadPool* xla_intra_op_thread_pool = nullptr) + explicit ServiceExecutableRunOptions(ExecutableRunOptions run_options, + StreamBorrower borrow_stream = nullptr) : run_options_(std::move(run_options)), - borrow_stream_(std::move(borrow_stream)), - xla_intra_op_thread_pool_(xla_intra_op_thread_pool) {} + borrow_stream_(std::move(borrow_stream)) {} // Returns reference or pointer to `ExecutableRunOptions` member. const ExecutableRunOptions& run_options() const { return run_options_; } @@ -56,15 +54,9 @@ class ServiceExecutableRunOptions { : Status(tensorflow::error::UNIMPLEMENTED, "No stream cache"); } - // Returns reference to thread pool for execution of XLA ops on CPU backend. 
- tensorflow::thread::ThreadPool* xla_intra_op_thread_pool() const { - return xla_intra_op_thread_pool_; - } - private: ExecutableRunOptions run_options_; StreamBorrower borrow_stream_; - tensorflow::thread::ThreadPool* xla_intra_op_thread_pool_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 7e7282a737041458aed39b0054f901c23aa87d7a..3ddcaae193ba266f35fa6f9922fe4f3a4970cdc5 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -15,8 +15,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/shape_inference.h" -#include #include +#include #include #include #include @@ -50,7 +50,7 @@ bool AllUnique(absl::Span slice) { } Status ExpectArray(const Shape& shape, absl::string_view op_type) { - if (!ShapeUtil::IsArray(shape)) { + if (!shape.IsArray()) { return InvalidArgument("Expected array argument for %s, but got %s.", string(op_type), ShapeUtil::HumanString(shape)); } @@ -70,7 +70,7 @@ Status VerifyReducerShape(const ProgramShape& reducer_shape, const Shape& accumulator_shape = reducer_shape.result(); std::vector accumulator_subshapes; - if (ShapeUtil::IsArray(accumulator_shape)) { + if (accumulator_shape.IsArray()) { if (inputs != 1) { return InvalidArgument( "Reduction function must produce a tuple with %d elements, but " @@ -78,7 +78,7 @@ Status VerifyReducerShape(const ProgramShape& reducer_shape, inputs); } accumulator_subshapes.push_back(&accumulator_shape); - } else if (ShapeUtil::IsTuple(accumulator_shape)) { + } else if (accumulator_shape.IsTuple()) { if (ShapeUtil::TupleElementCount(accumulator_shape) != inputs) { return InvalidArgument( "Reduction function must produce a tuple with %d elements, but has " @@ -96,7 +96,7 @@ Status VerifyReducerShape(const ProgramShape& reducer_shape, } for (const Shape* element_shape : accumulator_subshapes) { - if 
(ShapeUtil::Rank(*element_shape) != 0) { + if (element_shape->rank() != 0) { return InvalidArgument( "Reduction function must return a scalar or tuple of scalars but " "returns shape: %s", @@ -156,17 +156,26 @@ Status VerifyReducerShape(const ProgramShape& reducer_shape, return Status::OK(); } +bool IsTrivialWindowDimension(const WindowDimension& window_dimension) { + return window_dimension.size() == 1 && window_dimension.stride() == 1 && + window_dimension.padding_low() == 0 && + window_dimension.padding_high() == 0 && + window_dimension.window_dilation() == 1 && + window_dimension.base_dilation() == 1; +} + StatusOr InferWindowOutputShape(const Shape& base_shape, const Window& window, PrimitiveType element_type, bool allow_negative_padding) { - if (window.dimensions_size() != ShapeUtil::Rank(base_shape)) { + if (window.dimensions_size() != base_shape.rank()) { return InvalidArgument( "Window has dimension %d but base shape has dimension %d.", - window.dimensions_size(), ShapeUtil::Rank(base_shape)); + window.dimensions_size(), base_shape.rank()); } std::vector output_dimensions(window.dimensions_size()); + std::vector output_is_dynamic(window.dimensions_size()); for (int64 i = 0; i < window.dimensions_size(); ++i) { const auto& dim = window.dimensions(i); if (dim.size() <= 0) { @@ -196,6 +205,12 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, window.DebugString()); } + if (base_shape.is_dynamic_dimension(i) && !IsTrivialWindowDimension(dim)) { + return Unimplemented( + "Dynamic shape is not supported for non trivial window: %s", + window_util::ToString(window)); + } + const int64 dilated_base = window_util::DilatedBound( ShapeUtil::GetDimension(base_shape, i), dim.base_dilation()); const int64 padded_dilated_base = @@ -205,9 +220,11 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, output_dimensions[i] = window_util::StridedBound( padded_dilated_base, dilated_window, dim.stride()); + output_is_dynamic[i] = 
base_shape.is_dynamic_dimension(i); } - return ShapeUtil::MakeValidatedShape(element_type, output_dimensions); + return ShapeUtil::MakeValidatedShape(element_type, output_dimensions, + output_is_dynamic); } } // namespace @@ -338,7 +355,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, if (arg_shapes.empty()) { return InvalidArgument("Concatenate expects at least one argument."); } - if (dimension < 0 || dimension >= ShapeUtil::Rank(*arg_shapes[0])) { + if (dimension < 0 || dimension >= arg_shapes[0]->rank()) { return InvalidArgument("Concatenate dimension out of bounds: %d.", dimension); } @@ -351,12 +368,12 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, element_type = arg_shape->element_type(); continue; } - if (ShapeUtil::Rank(*arg_shape) != ShapeUtil::Rank(*shape)) { + if (arg_shape->rank() != shape->rank()) { return InvalidArgument( "Cannot concatenate arrays with different ranks: %d (%s) vs %d " "(%s).", - ShapeUtil::Rank(*arg_shape), ShapeUtil::HumanString(*arg_shape), - ShapeUtil::Rank(*shape), ShapeUtil::HumanString(*shape)); + arg_shape->rank(), ShapeUtil::HumanString(*arg_shape), shape->rank(), + ShapeUtil::HumanString(*shape)); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shape, *shape)) { return InvalidArgument( @@ -364,8 +381,8 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, PrimitiveType_Name(arg_shape->element_type()), PrimitiveType_Name(shape->element_type())); } - for (int64 dimension_number = 0; - dimension_number < ShapeUtil::Rank(*arg_shape); ++dimension_number) { + for (int64 dimension_number = 0; dimension_number < arg_shape->rank(); + ++dimension_number) { if (arg_shape->dimensions(dimension_number) != shape->dimensions(dimension_number)) { if (dimension_number == dimension) { @@ -401,7 +418,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, ShapeUtil::HumanString(operand_shape), PrimitiveType_Name(new_element_type)); } - if (!ShapeUtil::IsArray(operand_shape) || + if 
(!operand_shape.IsArray() || !primitive_util::IsArrayType(new_element_type)) { // Note: we may want to support tuple conversions via this operation in the // future, by recursing into the tuple elements to check all sub-conversions @@ -424,7 +441,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, ShapeUtil::HumanString(operand_shape), PrimitiveType_Name(new_element_type)); } - if (!ShapeUtil::IsArray(operand_shape) || + if (!operand_shape.IsArray() || !primitive_util::IsArrayType(new_element_type)) { // Note: we may want to support tuple conversions via this operation in the // future, by recursing into the tuple elements to check all sub-conversions @@ -472,7 +489,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, /* static */ StatusOr ShapeInference::InferPadShape( const Shape& operand_shape, const Shape& padding_value_shape, const PaddingConfig& padding_config) { - if (!ShapeUtil::IsArray(operand_shape)) { + if (!operand_shape.IsArray()) { return InvalidArgument( "Pad operation does not support tuple-shape operands."); } @@ -480,7 +497,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, return InvalidArgument( "Pad operation does not support non-scalar padding values."); } - if (ShapeUtil::Rank(operand_shape) != padding_config.dimensions_size()) { + if (operand_shape.rank() != padding_config.dimensions_size()) { return InvalidArgument( "The rank of the operand and the padding configuration do not match: " "%s vs %s.", @@ -500,35 +517,40 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, padding_config.ShortDebugString()); } - std::vector dimensions(ShapeUtil::Rank(operand_shape)); + if (!padding_value_shape.is_static()) { + return InvalidArgument("Dynamic padding value is not supported"); + } + + std::vector dimensions(operand_shape.rank()); + std::vector is_dynamic(operand_shape.rank()); for (int64 i = 0; i < operand_shape.dimensions_size(); ++i) { const auto& p = padding_config.dimensions(i); + if 
(operand_shape.is_dynamic_dimension(i) && p.edge_padding_high() != 0 && + p.edge_padding_low() != 0 && p.interior_padding() != 0) { + return InvalidArgument( + "Dynamic dimension on padding dimension is not supported."); + } dimensions[i] = operand_shape.dimensions(i) + p.edge_padding_low() + p.edge_padding_high() + std::max(operand_shape.dimensions(i) - 1, 0LL) * p.interior_padding(); + is_dynamic[i] = operand_shape.is_dynamic_dimension(i); } + return ShapeUtil::MakeShape( ShapeUtil::HigherPrecisionElementType(operand_shape, padding_value_shape), - dimensions); + dimensions, is_dynamic); } // Current DotDimensionNumbers Requirements: // // Contracting Dimensions: -// *) Exactly one contracting dimension on both lhs and rhs. +// *) Same number of contracting dimensions on both lhs and rhs. // *) Contracting dimension size must be the same on both lhs and rhs. -// *) Contracting dimension numbers do not need to be the same (i.e. transposes -// are passed on to emitter implementations). // // Batch Dimensions: // *) Same number of batch dimensions on both lhs and rhs. -// *) Same batch dimension numbers (and sizes) on both lhs and rhs. -// *) Batch dimension numbers must be ordered before contracting and -// non-contracting/non-batch dimension numbers. -// -// Non-Contracting-Non-Batch Dimensions: -// *) Can be 0 (matrix-vector) or 1 (matrix-matrix). +// *) Same batch dimension sizes on both lhs and rhs. 
// namespace { @@ -541,9 +563,8 @@ Status ValidateDotDimensionNumbers( absl::Span contracting_dims, absl::Span batch_dims) -> bool { auto in_range = [&rank](int64 i) -> bool { return 0 <= i && i < rank; }; - return std::all_of(contracting_dims.begin(), contracting_dims.end(), - in_range) && - std::all_of(batch_dims.begin(), batch_dims.end(), in_range); + return absl::c_all_of(contracting_dims, in_range) && + absl::c_all_of(batch_dims, in_range); }; absl::Span lhs_contracting_dimensions = @@ -555,9 +576,9 @@ Status ValidateDotDimensionNumbers( absl::Span rhs_batch_dimensions = AsInt64Slice(dimension_numbers.rhs_batch_dimensions()); - if (!dims_in_range(ShapeUtil::Rank(lhs), lhs_contracting_dimensions, + if (!dims_in_range(lhs.rank(), lhs_contracting_dimensions, lhs_batch_dimensions) || - !dims_in_range(ShapeUtil::Rank(rhs), rhs_contracting_dimensions, + !dims_in_range(rhs.rank(), rhs_contracting_dimensions, rhs_batch_dimensions)) { return InvalidArgument("A dimension number is out of range in Dot: %s.", dimension_numbers.DebugString()); @@ -570,9 +591,8 @@ Status ValidateDotDimensionNumbers( auto is_unique = [&dim_set](int64 i) -> bool { return dim_set.insert(i).second; }; - return std::all_of(contracting_dims.begin(), contracting_dims.end(), - is_unique) && - std::all_of(batch_dims.begin(), batch_dims.end(), is_unique); + return absl::c_all_of(contracting_dims, is_unique) && + absl::c_all_of(batch_dims, is_unique); }; if (!dims_unique(lhs_contracting_dimensions, lhs_batch_dimensions) || @@ -581,36 +601,6 @@ Status ValidateDotDimensionNumbers( dimension_numbers.DebugString()); } - // Check that the count of non-contracting-non-batch dimensions is in {0, 1}. 
- const int64 lhs_non_contracting_non_batch_dims = - ShapeUtil::Rank(lhs) - - dimension_numbers.lhs_contracting_dimensions_size() - - dimension_numbers.lhs_batch_dimensions_size(); - const int64 rhs_non_contracting_non_batch_dims = - ShapeUtil::Rank(rhs) - - dimension_numbers.rhs_contracting_dimensions_size() - - dimension_numbers.rhs_batch_dimensions_size(); - if (lhs_non_contracting_non_batch_dims < 0 || - lhs_non_contracting_non_batch_dims > 1 || - rhs_non_contracting_non_batch_dims < 0 || - rhs_non_contracting_non_batch_dims > 1) { - return InvalidArgument( - "Batch and contracting dimension number mismatch with rank."); - } - - // Check that batch dimension numbers are ordered before all others, and - // that they are monotonically increasing. - std::vector batch_dim_numbers(lhs_batch_dimensions.size()); - std::iota(batch_dim_numbers.begin(), batch_dim_numbers.end(), 0); - if (!std::equal(batch_dim_numbers.begin(), batch_dim_numbers.end(), - lhs_batch_dimensions.begin()) || - !std::equal(batch_dim_numbers.begin(), batch_dim_numbers.end(), - rhs_batch_dimensions.begin())) { - return InvalidArgument( - "Batch dimension numbers must precede non-batch dimensions and be" - "monotonically increasing."); - } - return Status::OK(); } @@ -637,28 +627,33 @@ Status ValidateDotDimensionNumbers( return fail("Element types do not match."); } - if ((ShapeUtil::Rank(lhs) < 1) || (ShapeUtil::Rank(rhs) < 1)) { + if ((lhs.rank() < 1) || (rhs.rank() < 1)) { return fail("Dot only supports rank 1 or above."); } // Validate basic properties of dot dimension numbers. TF_RETURN_IF_ERROR(ValidateDotDimensionNumbers(lhs, rhs, dimension_numbers)); - // Check that there is only one contracting dimension for both lhs and rhs. + // Check that number of contracting dimensions match. 
if (dimension_numbers.lhs_contracting_dimensions_size() != - dimension_numbers.rhs_contracting_dimensions_size() || - dimension_numbers.lhs_contracting_dimensions_size() != 1) { - return fail("Must specify one contracting dimension for both lhs and rhs."); + dimension_numbers.rhs_contracting_dimensions_size()) { + return fail( + "Must specify the same number of contracting dimensions for lhs and " + "rhs."); } - // Check that contracting dimension sizes match. - const int64 lhs_contracting_dimension = - dimension_numbers.lhs_contracting_dimensions(0); - const int64 rhs_contracting_dimension = - dimension_numbers.rhs_contracting_dimensions(0); - if (lhs.dimensions(lhs_contracting_dimension) != - rhs.dimensions(rhs_contracting_dimension)) { - return fail("Contracting dimension sizes do not match."); + for (int64 i = 0; i < dimension_numbers.lhs_contracting_dimensions_size(); + ++i) { + const int64 lhs_contracting_dimension = + dimension_numbers.lhs_contracting_dimensions(i); + const int64 rhs_contracting_dimension = + dimension_numbers.rhs_contracting_dimensions(i); + if (lhs.dimensions(lhs_contracting_dimension) != + rhs.dimensions(rhs_contracting_dimension) || + lhs.is_dynamic_dimension(lhs_contracting_dimension) != + rhs.is_dynamic_dimension(rhs_contracting_dimension)) { + return fail("Contracting dimension sizes do not match."); + } } // Check that number of batch dimensions match. @@ -669,11 +664,12 @@ Status ValidateDotDimensionNumbers( // Check that batch dimension numbers and sizes match. 
for (int64 i = 0; i < dimension_numbers.lhs_batch_dimensions_size(); ++i) { - if (dimension_numbers.lhs_batch_dimensions(i) != - dimension_numbers.rhs_batch_dimensions(i) || - lhs.dimensions(dimension_numbers.lhs_batch_dimensions(i)) != - rhs.dimensions(dimension_numbers.rhs_batch_dimensions(i))) { - return fail("Batch dimension numbers and sizes must match for lhs/rhs."); + if (lhs.dimensions(dimension_numbers.lhs_batch_dimensions(i)) != + rhs.dimensions(dimension_numbers.rhs_batch_dimensions(i)) || + lhs.is_dynamic_dimension(dimension_numbers.lhs_batch_dimensions(i)) != + rhs.is_dynamic_dimension( + dimension_numbers.rhs_batch_dimensions(i))) { + return fail("Batch dimension sizes must match for lhs/rhs."); } } @@ -683,21 +679,29 @@ Status ValidateDotDimensionNumbers( // Generate the result dimensions in order, rhs dimensions followed by lhs // dimensions except the contracted and batch dimensions. std::vector dimensions; - std::unordered_set rhs_batch_dims( - dimension_numbers.rhs_batch_dimensions().begin(), - dimension_numbers.rhs_batch_dimensions().end()); - for (int64 i = 0; i < ShapeUtil::Rank(lhs); i++) { - if (i != lhs_contracting_dimension) { + std::vector is_dynamic; + for (int64 lhs_dim : dimension_numbers.lhs_batch_dimensions()) { + dimensions.push_back(lhs.dimensions(lhs_dim)); + is_dynamic.push_back(lhs.is_dynamic_dimension(lhs_dim)); + } + for (int64 i = 0; i < lhs.rank(); i++) { + if (!absl::c_linear_search(dimension_numbers.lhs_contracting_dimensions(), + i) && + !absl::c_linear_search(dimension_numbers.lhs_batch_dimensions(), i)) { dimensions.push_back(lhs.dimensions(i)); + is_dynamic.push_back(lhs.is_dynamic_dimension(i)); } } - for (int64 i = 0; i < ShapeUtil::Rank(rhs); i++) { - if (i != rhs_contracting_dimension && rhs_batch_dims.count(i) == 0) { + for (int64 i = 0; i < rhs.rank(); i++) { + if (!absl::c_linear_search(dimension_numbers.rhs_contracting_dimensions(), + i) && + !absl::c_linear_search(dimension_numbers.rhs_batch_dimensions(), i)) 
{ dimensions.push_back(rhs.dimensions(i)); + is_dynamic.push_back(rhs.is_dynamic_dimension(i)); } } Shape result = ShapeUtil::MakeShape( - ShapeUtil::HigherPrecisionElementType(lhs, rhs), dimensions); + ShapeUtil::HigherPrecisionElementType(lhs, rhs), dimensions, is_dynamic); TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(result)); VLOG(2) << "inferred dot shape: " << ShapeUtil::HumanString(result); @@ -708,20 +712,24 @@ Status ValidateDotDimensionNumbers( ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, const Shape& lhs, const Shape& rhs) { - TF_RET_CHECK(ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs)); + TF_RET_CHECK(lhs.rank() == rhs.rank()); // The shapes have to be compatible. That is, if some dimension d has a // different size in the two shapes, one of them has to be 1 (a "degenerate" // dimension). In that case, the output shape has the non-1 dimension size // from the lhs/rhs pair in every index. - std::vector output_dimensions(ShapeUtil::Rank(lhs)); - for (int64 i = 0; i < ShapeUtil::Rank(lhs); ++i) { + std::vector output_dimensions(lhs.rank()); + std::vector output_dimensions_is_dynamic(lhs.rank()); + for (int64 i = 0; i < lhs.rank(); ++i) { if (lhs.dimensions(i) == rhs.dimensions(i)) { output_dimensions[i] = lhs.dimensions(i); + output_dimensions_is_dynamic[i] = lhs.is_dynamic_dimension(i); } else if (lhs.dimensions(i) == 1) { output_dimensions[i] = rhs.dimensions(i); + output_dimensions_is_dynamic[i] = rhs.is_dynamic_dimension(i); } else if (rhs.dimensions(i) == 1) { output_dimensions[i] = lhs.dimensions(i); + output_dimensions_is_dynamic[i] = lhs.is_dynamic_dimension(i); } else { return InvalidArgument( "Binary op %s with incompatible shapes: %s and %s.", @@ -730,7 +738,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } } return ShapeUtil::MakeShape(ShapeUtil::HigherPrecisionElementType(lhs, rhs), - output_dimensions); + output_dimensions, output_dimensions_is_dynamic); } /* static */ 
StatusOr ShapeInference::InferInDimBroadcastShape( @@ -743,13 +751,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument("Automatic shape inference not supported: %s and %s", ShapeUtil::HumanString(smaller_shape), ShapeUtil::HumanString(larger_shape)); - } else if (broadcast_dimensions.size() != ShapeUtil::Rank(smaller_shape)) { + } else if (broadcast_dimensions.size() != smaller_shape.rank()) { return InvalidArgument( "Size of broadcast_dimensions has to match lower-rank operand's " "rank; " " lower-rank operand's rank is %d, size of broadcast_dimensions is " "%u.", - ShapeUtil::Rank(smaller_shape), broadcast_dimensions.size()); + smaller_shape.rank(), broadcast_dimensions.size()); } // broadcast_dimensions is a sequence of dimensions; its length is equal to @@ -809,6 +817,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } int64 small_dimension_size = smaller_shape.dimensions(i); int64 large_dimension_size = larger_shape.dimensions(dimension_to_match); + bool small_is_dynamic = smaller_shape.is_dynamic_dimension(i); + bool large_is_dynamic = + larger_shape.is_dynamic_dimension(dimension_to_match); // Dimension sizes must be compatible: match or be degenerate (degenerate // case is handled by degenerate dimension broadcasting which occurs after // InDim broadcasting). @@ -820,6 +831,17 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, ShapeUtil::HumanString(smaller_shape), ShapeUtil::HumanString(larger_shape)); } + if (small_is_dynamic != large_is_dynamic) { + if ((small_dimension_size == 1 && !small_is_dynamic) || + (large_dimension_size == 1 && !large_is_dynamic)) { + // Do nothing. It's OK when the size-1 dimension is not static. 
+ } else { + return InvalidArgument( + "Broadcast dimension %d dynamism mismatch: %s and %s.", i, + ShapeUtil::HumanString(smaller_shape), + ShapeUtil::HumanString(larger_shape)); + } + } // Make sure the broadcast dimensions are listed in a strictly increasing // order. if (i > 0 && broadcast_dimensions.at(i - 1) >= dimension_to_match) { @@ -829,6 +851,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } output_shape.set_dimensions(dimension_to_match, small_dimension_size); + output_shape.set_dynamic_dimension(dimension_to_match, small_is_dynamic); } return output_shape; @@ -847,8 +870,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, ShapeUtil::HumanString(rhs)); } - if (ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs)) { - std::vector identity_dims(ShapeUtil::Rank(lhs)); + if (lhs.rank() == rhs.rank()) { + std::vector identity_dims(lhs.rank()); std::iota(identity_dims.begin(), identity_dims.end(), 0); if (!broadcast_dimensions.empty() && broadcast_dimensions != identity_dims) { @@ -865,15 +888,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, lhs, ShapeUtil::HigherPrecisionElementType(lhs, rhs)); } - if (ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs)) { + if (lhs.rank() == rhs.rank()) { return InferDegenerateDimensionBroadcastShape(operation, lhs, rhs); } else { // Ranks do not match, so perform InDim broadcasting using // broadcast_dimensions. Scalar broadcasting is a special case of this. - const Shape& larger_shape = - ShapeUtil::Rank(lhs) > ShapeUtil::Rank(rhs) ? lhs : rhs; - const Shape& smaller_shape = - ShapeUtil::Rank(lhs) > ShapeUtil::Rank(rhs) ? rhs : lhs; + const Shape& larger_shape = lhs.rank() > rhs.rank() ? lhs : rhs; + const Shape& smaller_shape = lhs.rank() > rhs.rank() ? rhs : lhs; // After InDim broadcasting, perform degenerate dimensions broadcasting. 
TF_ASSIGN_OR_RETURN(Shape indim_broadcast_shape, @@ -942,6 +963,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, broadcast_dimensions)); if (lhs.element_type() == F32 && rhs.element_type() == F32) { return ShapeUtil::ChangeElementType(shape, C64); + } else if (lhs.element_type() == F64 && rhs.element_type() == F64) { + return ShapeUtil::ChangeElementType(shape, C128); } else { return Unimplemented("Complex component type is not implemented."); } @@ -1162,12 +1185,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape) == Status::OK()); - if (feature_index >= ShapeUtil::Rank(operand_shape)) { + if (feature_index >= operand_shape.rank()) { return InvalidArgument( "Expected feature_index of batch-norm-training to be " "smaller than the rank of operand_shape; " "got feature_index %d, and rank %d.", - feature_index, ShapeUtil::Rank(operand_shape)); + feature_index, operand_shape.rank()); } if (feature_index < 0) { @@ -1177,25 +1200,25 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, feature_index); } - if (ShapeUtil::Rank(operand_shape) < 1) { + if (operand_shape.rank() < 1) { return InvalidArgument( "Expected the rank of operand to " "batch-norm-training to be at least 1; got %d.", - ShapeUtil::Rank(operand_shape)); + operand_shape.rank()); } - if (ShapeUtil::Rank(offset_shape) != 1) { + if (offset_shape.rank() != 1) { return InvalidArgument( "Offset input of batch-norm-training must have" " rank 1, but has rank %d.", - ShapeUtil::Rank(offset_shape)); + offset_shape.rank()); } - if (ShapeUtil::Rank(scale_shape) != 1) { + if (scale_shape.rank() != 1) { return InvalidArgument( "Scale input of batch-norm-training must have" " rank 1, but has rank %d.", - ShapeUtil::Rank(scale_shape)); + scale_shape.rank()); } if (!ShapeUtil::ElementIsFloating(operand_shape)) { @@ -1272,12 +1295,12 @@ 
ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(variance_shape) == Status::OK()); - if (feature_index >= ShapeUtil::Rank(operand_shape)) { + if (feature_index >= operand_shape.rank()) { return InvalidArgument( "Expected feature_index of batch-norm-inference to be " "smaller than the rank of operand_shape; " "got feature_index %d, and rank %d.", - feature_index, ShapeUtil::Rank(operand_shape)); + feature_index, operand_shape.rank()); } if (feature_index < 0) { @@ -1287,25 +1310,25 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, feature_index); } - if (ShapeUtil::Rank(operand_shape) < 1) { + if (operand_shape.rank() < 1) { return InvalidArgument( "Expected the rank of operand to " "batch-norm-inference to be at least 1; got %d.", - ShapeUtil::Rank(operand_shape)); + operand_shape.rank()); } - if (ShapeUtil::Rank(offset_shape) != 1) { + if (offset_shape.rank() != 1) { return InvalidArgument( "Offset input of batch-norm-inference must have" " rank 1, but has rank %d.", - ShapeUtil::Rank(offset_shape)); + offset_shape.rank()); } - if (ShapeUtil::Rank(scale_shape) != 1) { + if (scale_shape.rank() != 1) { return InvalidArgument( "Scale input of batch-norm-inference must have" " rank 1, but has rank %d.", - ShapeUtil::Rank(scale_shape)); + scale_shape.rank()); } if (!ShapeUtil::ElementIsFloating(operand_shape)) { @@ -1417,41 +1440,41 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, TF_RETURN_IF_ERROR( ShapeUtil::ValidateShapeWithOptionalLayout(output_grad_shape)); - if (feature_index >= ShapeUtil::Rank(operand_shape)) { + if (feature_index >= operand_shape.rank()) { return InvalidArgument( "Expected feature_index of batch-norm-grad to be " "smaller than the rank of operand_shape; " "got feature_index %d, and rank %d.", - feature_index, ShapeUtil::Rank(operand_shape)); + feature_index, operand_shape.rank()); } - if 
(ShapeUtil::Rank(operand_shape) != ShapeUtil::Rank(output_grad_shape)) { + if (operand_shape.rank() != output_grad_shape.rank()) { return InvalidArgument( "Expected operand_shape of batch-norm-grad to have the same rank as" " output_grad_shape; got rank(oprand_shape) %d, and" " rank(output_grad_shape) %d.", - ShapeUtil::Rank(operand_shape), ShapeUtil::Rank(output_grad_shape)); + operand_shape.rank(), output_grad_shape.rank()); } - if (ShapeUtil::Rank(mean_shape) != 1) { + if (mean_shape.rank() != 1) { return InvalidArgument( "Mean input of batch-norm-grad must have" " rank 1, but has rank %d.", - ShapeUtil::Rank(mean_shape)); + mean_shape.rank()); } - if (ShapeUtil::Rank(scale_shape) != 1) { + if (scale_shape.rank() != 1) { return InvalidArgument( "Scale input of batch-norm-grad must have" " rank 1, but has rank %d.", - ShapeUtil::Rank(scale_shape)); + scale_shape.rank()); } - if (ShapeUtil::Rank(var_shape) != 1) { + if (var_shape.rank() != 1) { return InvalidArgument( "Var input of batch-norm-grad must have" " rank 1, but has rank %d.", - ShapeUtil::Rank(var_shape)); + var_shape.rank()); } if (!ShapeUtil::ElementIsFloating(operand_shape)) { @@ -1538,7 +1561,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } // Verify operand_shape and output_grad_shape have same bounds. 
- for (int64 i = 0; i < ShapeUtil::Rank(operand_shape); ++i) { + for (int64 i = 0; i < operand_shape.rank(); ++i) { if (ShapeUtil::GetDimension(operand_shape, i) != ShapeUtil::GetDimension(output_grad_shape, i)) { return InvalidArgument( @@ -1556,7 +1579,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, /* static */ StatusOr ShapeInference::InferConvolveShape( const Shape& lhs, const Shape& rhs, int64 feature_group_count, - const Window& window, const ConvolutionDimensionNumbers& dnums) { + int64 batch_group_count, const Window& window, + const ConvolutionDimensionNumbers& dnums) { TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of convolution")); TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of convolution")); @@ -1565,6 +1589,20 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, "feature_group_count must be a positive number, got %d", feature_group_count); } + + if (batch_group_count <= 0) { + return InvalidArgument( + "batch_group_count must be a positive number, got %d", + batch_group_count); + } + + if (batch_group_count > 1 && feature_group_count > 1) { + return InvalidArgument( + "both batch_group_count %d and feature_group_count %d cannot be " + "greater than 1", + batch_group_count, feature_group_count); + } + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) { return InvalidArgument( "Convolution with different element types: %s and %s.", @@ -1595,12 +1633,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } const int num_dims = num_spatial_dims + 2; - if (ShapeUtil::Rank(lhs) != num_dims) { + if (lhs.rank() != num_dims) { return InvalidArgument( "The LHS argument to a convolution should have rank %d; lhs: %s.", num_dims, ShapeUtil::HumanString(lhs)); } - if (ShapeUtil::Rank(rhs) != num_dims) { + if (rhs.rank() != num_dims) { return InvalidArgument( "The RHS argument to a convolution should have rank %d; rhs: %s.", num_dims, ShapeUtil::HumanString(rhs)); @@ -1615,29 
+1653,29 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, input_dnums[1] = dnums.input_feature_dimension(); std::copy(dnums.input_spatial_dimensions().begin(), dnums.input_spatial_dimensions().end(), input_dnums.begin() + 2); - std::sort(input_dnums.begin(), input_dnums.end()); + absl::c_sort(input_dnums); std::vector window_dnums(num_dims); window_dnums[0] = dnums.kernel_input_feature_dimension(); window_dnums[1] = dnums.kernel_output_feature_dimension(); std::copy(dnums.kernel_spatial_dimensions().begin(), dnums.kernel_spatial_dimensions().end(), window_dnums.begin() + 2); - std::sort(window_dnums.begin(), window_dnums.end()); + absl::c_sort(window_dnums); std::vector output_dnums(num_dims); output_dnums[0] = dnums.output_batch_dimension(); output_dnums[1] = dnums.output_feature_dimension(); std::copy(dnums.output_spatial_dimensions().begin(), dnums.output_spatial_dimensions().end(), output_dnums.begin() + 2); - std::sort(output_dnums.begin(), output_dnums.end()); + absl::c_sort(output_dnums); std::vector expected_dnums(num_dims); std::iota(expected_dnums.begin(), expected_dnums.end(), 0); const auto in_range = [num_dims](int64 i) { return 0 <= i && i < num_dims; }; - if (!std::all_of(input_dnums.begin(), input_dnums.end(), in_range) || - !std::all_of(window_dnums.begin(), window_dnums.end(), in_range) || - !std::all_of(output_dnums.begin(), output_dnums.end(), in_range)) { + if (!absl::c_all_of(input_dnums, in_range) || + !absl::c_all_of(window_dnums, in_range) || + !absl::c_all_of(output_dnums, in_range)) { return InvalidArgument( "A dimension number is out of range in convolution: %s.", dnums.DebugString()); @@ -1678,6 +1716,17 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, const int64 kernel_output_features = rhs.dimensions(dnums.kernel_output_feature_dimension()); + if (batch_group_count > 1 && input_batch % kernel_output_features != 0) { + return InvalidArgument( + "Expected output feature 
dimension (value %d) to be divisible by " + "input_batch (value %d) for batch group count %d; " + "got (%s, %s)\n" + "Dimension numbers: {%s}.", + kernel_output_features, input_batch, batch_group_count, + ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs), + dnums.DebugString()); + } + if (input_features % feature_group_count != 0 || input_features / feature_group_count != kernel_input_features) { return InvalidArgument( @@ -1700,6 +1749,17 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs), dnums.DebugString()); } + + if (input_batch % batch_group_count > 0) { + return InvalidArgument( + "Expected input batch dimension (value %d) to be divisible by " + "batch_group_count (value %d); " + "got (%s, %s)\n" + "Dimension numbers: {%s}.", + input_batch, batch_group_count, ShapeUtil::HumanString(lhs), + ShapeUtil::HumanString(rhs), dnums.DebugString()); + } + std::vector window_dims(num_spatial_dims); for (int i = 0; i < num_spatial_dims; ++i) { window_dims[i] = window.dimensions(i).size(); @@ -1722,14 +1782,39 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, /*allow_negative_padding=*/true)); std::vector dimensions(num_dims); - dimensions[dnums.output_batch_dimension()] = input_batch; + dimensions[dnums.output_batch_dimension()] = input_batch / batch_group_count; dimensions[dnums.output_feature_dimension()] = kernel_output_features; for (int i = 0; i < num_spatial_dims; ++i) { dimensions[dnums.output_spatial_dimensions(i)] = window_output_shape.dimensions(i); } + std::vector is_dynamic(num_dims); + for (int i = 0; i < num_dims; i++) { + if (lhs.is_dynamic_dimension(i)) { + if (i == dnums.input_batch_dimension()) { + is_dynamic[dnums.output_batch_dimension()] = true; + } else if (i == dnums.input_feature_dimension()) { + // Input feature dimension is a contracting dimension, which does not + // affect the output dimension size. 
So we need to do nothing. + } else { + return InvalidArgument( + "Dynamic Spatial Convolution is not supported: lhs shape is %s ", + lhs.ToString()); + } + } + if (rhs.is_dynamic_dimension(i)) { + if (i == dnums.kernel_input_feature_dimension()) { + // Kernel feature dimension does not affect the output dimension size. + // So we need to do nothing. + } else { + return InvalidArgument( + "Dynamic Spatial Convolution is not supported: rhs shape is %s ", + rhs.ToString()); + } + } + } return ShapeUtil::MakeShape(ShapeUtil::HigherPrecisionElementType(lhs, rhs), - dimensions); + dimensions, is_dynamic); } /* static */ StatusOr ShapeInference::InferFftShape( @@ -1750,7 +1835,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, case FFT: case IFFT: if (in.element_type() != C64) { - return InvalidArgument("%s requires C64 input type, found %s.", + return InvalidArgument("%s requires complex input type, found %s.", FftType_Name(fft_type), PrimitiveType_Name(in.element_type())); } @@ -1814,7 +1899,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, #undef RET_CHECK_RANK } -/* static */ StatusOr ShapeInference::InferCrossReplicaSumShape( +/* static */ StatusOr ShapeInference::InferAllReduceShape( absl::Span operand_shapes) { for (const Shape* operand_shape : operand_shapes) { TF_RETURN_IF_ERROR( @@ -1834,12 +1919,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, const Shape& shape, int64 split_dimension, int64 concat_dimension, int64 split_count) { TF_RET_CHECK(split_count > 0); - if (split_dimension >= ShapeUtil::Rank(shape) || split_dimension < 0) { + if (split_dimension >= shape.rank() || split_dimension < 0) { return InvalidArgument( "AllToAll split_dimension %d is out-of-bounds in shape %s.", split_dimension, ShapeUtil::HumanString(shape)); } - if (concat_dimension >= ShapeUtil::Rank(shape) || concat_dimension < 0) { + if (concat_dimension >= shape.rank() || concat_dimension < 0) { 
return InvalidArgument( "AllToAll concat_dimension %d is out-of-bounds in shape %s.", concat_dimension, ShapeUtil::HumanString(shape)); @@ -1877,7 +1962,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, /* static */ StatusOr ShapeInference::InferCollectivePermuteShape( const Shape& shape) { - TF_RET_CHECK(ShapeUtil::IsArray(shape)); + TF_RET_CHECK(shape.IsArray()); return shape; } @@ -1901,7 +1986,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, for (int64 i = 1; i < num_reduced_args; ++i) { if (!ShapeUtil::SameDimensions(*reduced_args[0], *reduced_args[i])) { return InvalidArgument( - "All reduced tensors must have the sime dimension. Tensor 0 has " + "All reduced tensors must have the same dimension. Tensor 0 has " "shape %s, Tensor %d has shape %s", ShapeUtil::HumanString(*reduced_args[0]), i, ShapeUtil::HumanString(*reduced_args[i])); @@ -1913,7 +1998,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, // doesn't matter which one we choose. 
const Shape& arg = *reduced_args[0]; for (int64 dimension : dimensions_to_reduce) { - if (dimension >= ShapeUtil::Rank(arg) || dimension < 0) { + if (dimension >= arg.rank() || dimension < 0) { return InvalidArgument("Reducing out-of-bounds dimension %d in shape %s.", dimension, ShapeUtil::HumanString(arg)); } @@ -1930,20 +2015,22 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, std::set dimensions_to_reduce_set(dimensions_to_reduce.begin(), dimensions_to_reduce.end()); std::vector new_dimensions; - for (int i = 0; i < ShapeUtil::Rank(arg); ++i) { + std::vector new_is_dynamic; + for (int i = 0; i < arg.rank(); ++i) { if (dimensions_to_reduce_set.find(i) == dimensions_to_reduce_set.end()) { new_dimensions.push_back(arg.dimensions(i)); + new_is_dynamic.push_back(arg.is_dynamic_dimension(i)); } } if (ShapeUtil::IsScalar(to_apply.result())) { return ShapeUtil::MakeShape(to_apply.result().element_type(), - new_dimensions); + new_dimensions, new_is_dynamic); } else { std::vector result_subshapes; for (const Shape& subshape : to_apply.result().tuple_shapes()) { - result_subshapes.push_back( - ShapeUtil::MakeShape(subshape.element_type(), new_dimensions)); + result_subshapes.push_back(ShapeUtil::MakeShape( + subshape.element_type(), new_dimensions, new_is_dynamic)); } return ShapeUtil::MakeTupleShape(result_subshapes); } @@ -2017,12 +2104,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, ShapeUtil::HumanString(source_shape), ShapeUtil::HumanString(window_result_shape)); } + return operand_shape; } /* static */ StatusOr ShapeInference::InferGetDimensionSizeShape( const Shape& shape, int64 dimension) { - if (dimension < 0 || dimension >= ShapeUtil::Rank(shape)) { + if (dimension < 0 || dimension >= shape.rank()) { return InvalidArgument("GetDimensionSize dimension out of bounds: %d.", dimension); } @@ -2064,10 +2152,10 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, starts.size(), 
strides.size())); } - if (starts.size() != ShapeUtil::Rank(arg)) { + if (starts.size() != arg.rank()) { return InvalidArgument( "Slice index count does not match argument rank: %u vs %d.", - starts.size(), ShapeUtil::Rank(arg)); + starts.size(), arg.rank()); } std::vector sizes; @@ -2102,41 +2190,87 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } /* static */ StatusOr ShapeInference::InferDynamicSliceShape( - const Shape& operand_shape, const Shape& start_indices_shape, - absl::Span slice_sizes) { + const Shape& operand_shape, absl::Span start_index_shapes, + absl::Span slice_sizes, bool allow_scalar_indices) { TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of dynamic slice")); - TF_RETURN_IF_ERROR( - ExpectArray(start_indices_shape, "start indices of dynamic slice")); + auto number_of_indices = start_index_shapes.size(); + // TODO(b/118437727): Remove this path. + if (!allow_scalar_indices || + (number_of_indices >= 1 && start_index_shapes[0].rank() == 1)) { + if (number_of_indices != 1) { + return InvalidArgument( + "Dynamic slice should have exactly 1 index operand, has %d.", + number_of_indices); + } - VLOG(2) << StrFormat( - "slicing shape %s at dynamic start_indices %s with slice_sizes={%s}", - ShapeUtil::HumanString(operand_shape), - ShapeUtil::HumanString(start_indices_shape), StrJoin(slice_sizes, ", ")); + const Shape& start_indices_shape = start_index_shapes[0]; + VLOG(2) << StrFormat( + "slicing shape %s at dynamic start_indices %s with slice_sizes={%s}", + ShapeUtil::HumanString(operand_shape), + ShapeUtil::HumanString(start_indices_shape), + StrJoin(slice_sizes, ", ")); - if (ShapeUtil::Rank(start_indices_shape) != 1) { - return InvalidArgument( - "Dynamic slice start indices of rank %d must be rank1.", - ShapeUtil::Rank(start_indices_shape)); - } + TF_RETURN_IF_ERROR( + ExpectArray(start_indices_shape, "start indices of dynamic slice")); - if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) { - return 
InvalidArgument( - "Dynamic slice start indices must be of integral type."); - } + if (start_indices_shape.rank() != 1) { + return InvalidArgument( + "Dynamic slice start indices of rank %d must be rank1.", + start_indices_shape.rank()); + } - const int64 start_num_dims = start_indices_shape.dimensions(0); - if (ShapeUtil::Rank(operand_shape) != start_num_dims) { - return InvalidArgument( - "Dynamic slice start number of dimensions %d (%s) must match rank " - "%d of slice input (%s).", - start_num_dims, ShapeUtil::HumanString(start_indices_shape), - ShapeUtil::Rank(operand_shape), ShapeUtil::HumanString(operand_shape)); + if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) { + return InvalidArgument( + "Dynamic slice start indices must be of integral type."); + } + + const int64 start_num_dims = start_indices_shape.dimensions(0); + if (operand_shape.rank() != start_num_dims) { + return InvalidArgument( + "Dynamic slice start number of dimensions %d (%s) must match rank " + "%d of slice input (%s).", + start_num_dims, ShapeUtil::HumanString(start_indices_shape), + operand_shape.rank(), ShapeUtil::HumanString(operand_shape)); + } + } else { + VLOG(2) << StrFormat("slicing shape %s a with slice_sizes={%s}", + ShapeUtil::HumanString(operand_shape), + StrJoin(slice_sizes, ", ")); + + if (operand_shape.rank() != number_of_indices) { + return InvalidArgument( + "Dynamic slice start number of dimensions %d must match rank " + "%d of slice input (%s).", + number_of_indices, operand_shape.rank(), + ShapeUtil::HumanString(operand_shape)); + } + + if (number_of_indices > 0) { + const Shape& first_index_shape = start_index_shapes[0]; + if (!ShapeUtil::IsScalar(first_index_shape)) { + return InvalidArgument("Dynamic slice indices must be scalar, not %s.", + ShapeUtil::HumanString(first_index_shape)); + } + if (!ShapeUtil::ElementIsIntegral(first_index_shape)) { + return InvalidArgument( + "Dynamic slice start indices must be of integral type."); + } + for (const Shape& 
index_shape : start_index_shapes) { + if (!ShapeUtil::Compatible(first_index_shape, index_shape)) { + return InvalidArgument( + "Dynamic slice start indices must all have the same shape, got " + "mismatching indices with shapes %s and %s.", + ShapeUtil::HumanString(first_index_shape), + ShapeUtil::HumanString(index_shape)); + } + } + } } - if (slice_sizes.size() != ShapeUtil::Rank(operand_shape)) { + if (slice_sizes.size() != operand_shape.rank()) { return InvalidArgument( "Dynamic slice index count does not match argument rank: %u vs %d.", - slice_sizes.size(), ShapeUtil::Rank(operand_shape)); + slice_sizes.size(), operand_shape.rank()); } for (int64 dim = 0; dim < slice_sizes.size(); ++dim) { @@ -2159,46 +2293,92 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, /* static */ StatusOr ShapeInference::InferDynamicUpdateSliceShape( const Shape& operand_shape, const Shape& update_shape, - const Shape& start_indices_shape) { + absl::Span start_index_shapes, bool allow_scalar_indices) { TF_RETURN_IF_ERROR( ExpectArray(operand_shape, "operand of dynamic update slice")); TF_RETURN_IF_ERROR( ExpectArray(update_shape, "update of dynamic update slice")); - TF_RETURN_IF_ERROR(ExpectArray(start_indices_shape, - "start indices of dynamic update slice")); - VLOG(2) << StrFormat( - "updating slice of shape %s at dynamic start_indices %s with update " - "shape %s", - ShapeUtil::HumanString(operand_shape), - ShapeUtil::HumanString(start_indices_shape), - ShapeUtil::HumanString(update_shape)); + auto number_of_indices = start_index_shapes.size(); + // TODO(b/118437727): Remove this path. 
+ if (!allow_scalar_indices || + (number_of_indices >= 1 && start_index_shapes[0].rank() == 1)) { + if (number_of_indices != 1) { + return InvalidArgument( + "Dynamic update slice should have exactly 1 index operand, has %d.", + number_of_indices); + } + const Shape& start_indices_shape = start_index_shapes[0]; + TF_RETURN_IF_ERROR(ExpectArray(start_indices_shape, + "start indices of dynamic update slice")); - if (ShapeUtil::Rank(start_indices_shape) != 1) { - return InvalidArgument( - "Dynamic update slice start indices of rank %d must be rank1.", - ShapeUtil::Rank(start_indices_shape)); - } + VLOG(2) << StrFormat( + "updating slice of shape %s at dynamic start_indices %s with update " + "shape %s", + ShapeUtil::HumanString(operand_shape), + ShapeUtil::HumanString(start_indices_shape), + ShapeUtil::HumanString(update_shape)); - if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) { - return InvalidArgument( - "Dynamic update slice start indices must be of integral type."); - } + if (start_indices_shape.rank() != 1) { + return InvalidArgument( + "Dynamic update slice start indices of rank %d must be rank1.", + start_indices_shape.rank()); + } - const int64 start_num_dims = start_indices_shape.dimensions(0); - if (ShapeUtil::Rank(operand_shape) != start_num_dims) { - return InvalidArgument( - "Dynamic update slice start number of dimensions %d (%s) must match " - "rank %d of slice input (%s).", - start_num_dims, ShapeUtil::HumanString(start_indices_shape), - ShapeUtil::Rank(operand_shape), ShapeUtil::HumanString(operand_shape)); + if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) { + return InvalidArgument( + "Dynamic update slice start indices must be of integral type."); + } + + const int64 start_num_dims = start_indices_shape.dimensions(0); + if (operand_shape.rank() != start_num_dims) { + return InvalidArgument( + "Dynamic update slice start number of dimensions %d (%s) must match " + "rank %d of slice input (%s).", + start_num_dims, 
ShapeUtil::HumanString(start_indices_shape), + operand_shape.rank(), ShapeUtil::HumanString(operand_shape)); + } + } else { + VLOG(2) << StrFormat("updating slice of shape %s with update shape %s", + ShapeUtil::HumanString(operand_shape), + ShapeUtil::HumanString(update_shape)); + + if (operand_shape.rank() != number_of_indices) { + return InvalidArgument( + "Dynamic update slice start number of dimensions %d must match rank " + "%d of slice input (%s).", + number_of_indices, operand_shape.rank(), + ShapeUtil::HumanString(operand_shape)); + } + + if (number_of_indices > 0) { + const Shape& first_index_shape = start_index_shapes[0]; + if (!ShapeUtil::IsScalar(first_index_shape)) { + return InvalidArgument( + "Dynamic update slice indices must be scalar, not %s.", + ShapeUtil::HumanString(first_index_shape)); + } + if (!ShapeUtil::ElementIsIntegral(first_index_shape)) { + return InvalidArgument( + "Dynamic update slice start indices must be of integral type."); + } + for (const Shape& index_shape : start_index_shapes) { + if (!ShapeUtil::Compatible(first_index_shape, index_shape)) { + return InvalidArgument( + "Dynamic update slice start indices must all have the same " + "shape, got mismatching indices with shapes %s and %s.", + ShapeUtil::HumanString(first_index_shape), + ShapeUtil::HumanString(index_shape)); + } + } + } } - if (ShapeUtil::Rank(update_shape) != ShapeUtil::Rank(operand_shape)) { + if (update_shape.rank() != operand_shape.rank()) { return InvalidArgument( "Dynamic update slice update rank does not match argument rank: " "%d vs %d.", - ShapeUtil::Rank(update_shape), ShapeUtil::Rank(operand_shape)); + update_shape.rank(), operand_shape.rank()); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(operand_shape, @@ -2210,7 +2390,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, PrimitiveType_Name(update_shape.element_type())); } - for (int64 dim = 0; dim < ShapeUtil::Rank(operand_shape); ++dim) { + for (int64 dim = 0; dim 
< operand_shape.rank(); ++dim) { const int64 input_dim_size = operand_shape.dimensions(dim); const int64 update_dim_size = update_shape.dimensions(dim); if (update_dim_size < 0) { @@ -2236,7 +2416,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument("a dimension number is duplicated in reverse"); } for (int64 dimension : dimensions) { - if (dimension >= ShapeUtil::Rank(operand_shape) || dimension < 0) { + if (dimension >= operand_shape.rank() || dimension < 0) { return InvalidArgument( "One of the reverse dimensions (%d) is out-of-bounds in shape %s.", dimension, ShapeUtil::HumanString(operand_shape)); @@ -2247,7 +2427,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, /* static */ StatusOr ShapeInference::InferGetTupleElementShape( const Shape& arg, int64 index) { - if (!ShapeUtil::IsTuple(arg)) { + if (!arg.IsTuple()) { return InvalidArgument( "Cannot infer shape: attempting to index into non-tuple: %s.", ShapeUtil::HumanString(arg)); @@ -2283,7 +2463,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, }; // Check the shapes of computation parameters and return types. 
- if (!ShapeUtil::ShapeIs(condition.result(), PRED, {})) { + if (!ShapeUtil::Equal(condition.result(), ShapeUtil::MakeShape(PRED, {}))) { return InvalidArgument("Condition must return a boolean; got %s.", shape_string()); } @@ -2303,7 +2483,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, const Shape& predicate, const Shape& true_operand, const Shape& false_operand, const ProgramShape& true_computation, const ProgramShape& false_computation) { - if (!ShapeUtil::ShapeIs(predicate, PRED, {})) { + if (!ShapeUtil::Equal(predicate, ShapeUtil::MakeShape(PRED, {}))) { return InvalidArgument("Predicate must be a boolean; got %s.", ShapeUtil::HumanString(predicate)); } @@ -2378,8 +2558,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, absl::Span broadcast_dimensions) { TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of broadcast")); TF_RETURN_IF_ERROR(ExpectArray(output_shape, "operand of broadcast")); - const int64 operand_rank = ShapeUtil::Rank(operand_shape); - const int64 output_rank = ShapeUtil::Rank(output_shape); + const int64 operand_rank = operand_shape.rank(); + const int64 output_rank = output_shape.rank(); if (operand_rank > output_rank) { return InvalidArgument( "InDim style broadcast must be to an equal or higher ranked shape; " @@ -2407,6 +2587,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, i, operand_shape.dimensions(i), broadcast_dimensions[i], output_shape.dimensions(broadcast_dimensions[i])); } + if (operand_shape.is_dynamic_dimension(i) != + output_shape.is_dynamic_dimension(broadcast_dimensions[i])) { + return InvalidArgument( + "Broadcast input and output dynamism mismatch: %s and %s", + operand_shape.ToString(), output_shape.ToString()); + } // Make sure the broadcast dimensions are listed in a strictly increasing // order. 
if (i > 0 && broadcast_dimensions[i - 1] >= broadcast_dimensions[i]) { @@ -2438,9 +2624,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, ShapeUtil::HumanString(inferred_shape)); } - std::vector indices(ShapeUtil::Rank(operand)); + std::vector indices(operand.rank()); std::iota(indices.begin(), indices.end(), 0); - if (dimensions.size() != ShapeUtil::Rank(operand) || + if (dimensions.size() != operand.rank() || !std::is_permutation(dimensions.begin(), dimensions.end(), indices.begin())) { return InvalidArgument( @@ -2449,6 +2635,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, StrJoin(dimensions, ","), ShapeUtil::HumanString(operand)); } + std::vector> unmodified_dims = + ShapeUtil::DimensionsUnmodifiedByReshape(operand, inferred_shape); + for (auto& unmodified : unmodified_dims) { + if (operand.is_dynamic_dimension(unmodified.first)) { + inferred_shape.set_dynamic_dimension(unmodified.second, true); + } + } + return inferred_shape; } @@ -2456,9 +2650,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, const Shape& operand, absl::Span dimensions) { TF_RETURN_IF_ERROR(ExpectArray(operand, "transpose")); - std::vector indices(ShapeUtil::Rank(operand)); + std::vector indices(operand.rank()); std::iota(indices.begin(), indices.end(), 0); - if (dimensions.size() != ShapeUtil::Rank(operand) || + if (dimensions.size() != operand.rank() || !std::is_permutation(dimensions.begin(), dimensions.end(), indices.begin())) { return InvalidArgument( @@ -2522,19 +2716,31 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, "Select's pred operand must have PRED element type; got %s.", ShapeUtil::HumanString(pred)); } - if (ShapeUtil::CompatibleIgnoringElementType(pred, on_true) || + if (Shape::Equal() + .IgnoreElementType() + .IgnoreLayout() + .IgnoreDynamicDimension()(pred, on_true) || ShapeUtil::IsScalar(pred)) { // By this stage we know that pred's element type is 
PRED. Therefore, this // check restricts pred to be a PRED scalar, or a PRED array with the same // dimensions as on_true and on_false. - return ShapeUtil::ChangeElementType( + Shape inferred_shape = ShapeUtil::ChangeElementType( on_true, ShapeUtil::HigherPrecisionElementType(on_true, on_false)); - } else { - return InvalidArgument( - "Select operation with non-scalar predicate with dimensionality " - " different from the other operands: %s.", - ShapeUtil::HumanString(pred)); + + // Propagate dynamic dimensions if pred is not a scalar. + if (!ShapeUtil::IsScalar(pred)) { + for (int i = 0; i < inferred_shape.rank(); i++) { + if (pred.is_dynamic_dimension(i)) { + inferred_shape.set_dynamic_dimension(i, true); + } + } + } + return inferred_shape; } + return InvalidArgument( + "Select operation with non-scalar predicate with dimensionality " + "different from the other operands: %s.", + ShapeUtil::HumanString(pred)); } /* static */ StatusOr ShapeInference::InferTupleSelectShape( @@ -2810,7 +3016,7 @@ Status ValidateScatterDimensionNumbers( "update_window_dims in scatter op must not repeat; got: %s.", StrJoin(dim_numbers.update_window_dims(), ", ")); } - const int64 updates_rank = ShapeUtil::Rank(updates_shape); + const int64 updates_rank = updates_shape.rank(); for (int64 window_dim : dim_numbers.update_window_dims()) { if (window_dim < 0 || window_dim >= updates_rank) { return InvalidArgument( @@ -2844,10 +3050,10 @@ Status ValidateScatterDimensionNumbers( // Validate window size. auto window_size = dim_numbers.update_window_dims_size() + dim_numbers.inserted_window_dims_size(); - if (window_size != ShapeUtil::Rank(operand_shape)) { + if (window_size != operand_shape.rank()) { return InvalidArgument( "Scatter op has window of size %d; doesn't match operand of rank %d.", - window_size, ShapeUtil::Rank(operand_shape)); + window_size, operand_shape.rank()); } // Validate scatter_dims_to_operand_dims in ScatterDimensionNumbers. 
@@ -2932,10 +3138,9 @@ Status ValidateScatterDimensionNumbers( int64 expected_updates_rank = expanded_scatter_indices_shape.size() - 1 + scatter_dim_numbers.update_window_dims_size(); - if (ShapeUtil::Rank(updates_shape) != expected_updates_rank) { + if (updates_shape.rank() != expected_updates_rank) { return InvalidArgument("Updates tensor must be of rank %d; got %d.", - expected_updates_rank, - ShapeUtil::Rank(updates_shape)); + expected_updates_rank, updates_shape.rank()); } TF_RETURN_IF_ERROR(ValidateScatterDimensionNumbers( @@ -2966,7 +3171,7 @@ Status ValidateScatterDimensionNumbers( } int64 scatter_dims_seen = 0; - for (int64 i = 0; i < ShapeUtil::Rank(updates_shape); ++i) { + for (int64 i = 0; i < updates_shape.rank(); ++i) { bool is_update_window_dim = absl::c_binary_search(scatter_dim_numbers.update_window_dims(), i); if (is_update_window_dim) { diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index d94385a04d50baff8156570a09620fd458547936..7d39ef38e05abf0a81683c1fb0f3999908b27d23 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -109,7 +109,7 @@ class ShapeInference { // filter (rhs) to lhs in the way specified by the fields on window. static StatusOr InferConvolveShape( const Shape& lhs, const Shape& rhs, int64 feature_group_count, - const Window& window, + int64 batch_group_count, const Window& window, const ConvolutionDimensionNumbers& dimension_numbers); // Infers the shape produced by the given FFT type on the given operand. @@ -118,7 +118,7 @@ class ShapeInference { // Infers the shape produced by a cross replica sum with the given operand // shapes. 
- static StatusOr InferCrossReplicaSumShape( + static StatusOr InferAllReduceShape( absl::Span operand_shapes); // Infers final shape of an Alltoall operation that is created by the xla @@ -176,14 +176,15 @@ class ShapeInference { // Infers the shape produced by a dynamic slice operation of size specified // in 'slice_sizes', with dynamic start indices shape 'start_indices_shape'. static StatusOr InferDynamicSliceShape( - const Shape& operand_shape, const Shape& start_indices_shape, - absl::Span slice_sizes); + const Shape& operand_shape, absl::Span start_index_shapes, + absl::Span slice_sizes, bool allow_scalar_indices = true); // Infers the shape produced by a dynamic update slice operation based // on the shape of operand and update. static StatusOr InferDynamicUpdateSliceShape( const Shape& operand_shape, const Shape& update_shape, - const Shape& start_indices_shape); + absl::Span start_index_shapes, + bool allow_scalar_indices = true); // Infers the shape produced by doing a compile-time-constant indexing into // the given input shape. This is essential for operations on tuples, because diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index 4639e32db4d59080a9e85e46983fac61d9e76be9..26120a06b823c9fddf378991cec434a880fb888d 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -35,6 +35,7 @@ class ShapeInferenceTest : public ::testing::Test { protected: // Some handy scalar shapes. const Shape s32_ = ShapeUtil::MakeShape(S32, {}); + const Shape f16_ = ShapeUtil::MakeShape(F16, {}); const Shape f32_ = ShapeUtil::MakeShape(F32, {}); const Shape f64_ = ShapeUtil::MakeShape(F64, {}); const Shape pred_ = ShapeUtil::MakeShape(PRED, {}); @@ -260,8 +261,8 @@ TEST_F(ShapeInferenceTest, Complex) { ASSERT_FALSE(complex_shape(pred_, pred_, {}).ok()); // Component types must match. 
ASSERT_FALSE(complex_shape(f32_, f64_, {}).ok()); - // Only F32->C64 supported. - ASSERT_FALSE(complex_shape(f64_, f64_, {}).ok()); + // Only F32->C64 and F64->C128 supported. + ASSERT_FALSE(complex_shape(f16_, f16_, {}).ok()); // Validate correct uses. Shape c64_32 = ShapeUtil::MakeShape(C64, {32}); TF_ASSERT_OK_AND_ASSIGN(Shape result, complex_shape(f32_, f32_, {})); @@ -285,6 +286,9 @@ TEST_F(ShapeInferenceTest, Complex) { ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64)); TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(matrix_32_64_, f32_, {})); ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64)); + + TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(f64_, f64_, {})); + ASSERT_TRUE(ShapeUtil::Equal(result, ShapeUtil::MakeShape(C128, {}))); } TEST_F(ShapeInferenceTest, VariadicOpTuplify) { @@ -420,7 +424,8 @@ TEST_F(ShapeInferenceTest, Convolve) { dim1->set_window_dilation(1); dim1->set_base_dilation(1); auto inferred_status = ShapeInference::InferConvolveShape( - lhs_shape, rhs_shape, /*feature_group_count=*/1, window, dnums); + lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1, + window, dnums); ASSERT_IS_OK(inferred_status.status()); Shape inferred_shape = inferred_status.ValueOrDie(); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 2, 3}), @@ -465,7 +470,8 @@ TEST_F(ShapeInferenceTest, ConvolveWithWindowDilation) { dim1->set_window_dilation(2); dim1->set_base_dilation(1); auto inferred_status = ShapeInference::InferConvolveShape( - lhs_shape, rhs_shape, /*feature_group_count=*/1, window, dnums); + lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1, + window, dnums); ASSERT_IS_OK(inferred_status.status()); Shape inferred_shape = inferred_status.ValueOrDie(); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 31, 5}), @@ -510,7 +516,8 @@ TEST_F(ShapeInferenceTest, ConvolveWithBaseDilation) { dim1->set_window_dilation(1); dim1->set_base_dilation(2); auto inferred_status = 
ShapeInference::InferConvolveShape( - lhs_shape, rhs_shape, /*feature_group_count=*/1, window, dnums); + lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1, + window, dnums); ASSERT_IS_OK(inferred_status.status()); Shape inferred_shape = inferred_status.ValueOrDie(); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 4, 9}), @@ -548,7 +555,8 @@ TEST_F(ShapeInferenceTest, ConvolveDimensionNumbersOverlapError) { dim1->set_padding_low(1); dim1->set_padding_high(1); auto inferred_status = ShapeInference::InferConvolveShape( - lhs_shape, rhs_shape, /*feature_group_count=*/1, window, dnums); + lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1, + window, dnums); ASSERT_FALSE(inferred_status.ok()); ASSERT_THAT(inferred_status.status().error_message(), HasSubstr("each dimension exactly once")); @@ -1002,9 +1010,9 @@ TEST_F(ShapeInferenceTest, DotWithRankHigherThanTwo) { dot_dnums.add_rhs_contracting_dimensions(0); auto inferred_status = ShapeInference::InferDotOpShape( ShapeUtil::MakeShape(F32, {32, 32, 32}), matrix_32_64_, dot_dnums); - ASSERT_FALSE(inferred_status.ok()); - ASSERT_THAT(inferred_status.status().error_message(), - HasSubstr("Batch and contracting dimension number mismatch")); + EXPECT_TRUE(inferred_status.ok()); + EXPECT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), + ShapeUtil::MakeShape(F32, {32, 32, 64}))); } // vector vector -> scalar @@ -1096,7 +1104,6 @@ TEST_F(ShapeInferenceTest, DotGeneral) { TEST_F(ShapeInferenceTest, DotWithTwoContractingDimsFails) { Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3, 2}); Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 14}); - Shape output_shape = ShapeUtil::MakeShape(F32, {2, 11, 14}); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(2); @@ -1110,8 +1117,28 @@ TEST_F(ShapeInferenceTest, DotWithTwoContractingDimsFails) { ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); ASSERT_FALSE(inferred_status.ok()); 
ASSERT_THAT(inferred_status.status().error_message(), - HasSubstr("Must specify one contracting dimension for both " - "lhs and rhs")); + HasSubstr("Must specify the same number of contracting " + "dimensions for lhs and rhs.")); +} + +TEST_F(ShapeInferenceTest, DotWithTwoContractingDimsPasses) { + Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3, 2}); + Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 2, 14}); + Shape output_shape = ShapeUtil::MakeShape(F32, {2, 11, 14}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(2); + dot_dnums.add_lhs_contracting_dimensions(3); + dot_dnums.add_lhs_batch_dimensions(0); + + dot_dnums.add_rhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(2); + dot_dnums.add_rhs_batch_dimensions(0); + + auto inferred_status = + ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); + EXPECT_TRUE(inferred_status.ok()); + EXPECT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), output_shape)); } // BatchMatMul with different batch dimension sizes fails. @@ -1130,11 +1157,11 @@ TEST_F(ShapeInferenceTest, DotWithMisatchedBatchDimSizesFails) { ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); ASSERT_FALSE(inferred_status.ok()); ASSERT_THAT(inferred_status.status().error_message(), - HasSubstr("Batch dimension numbers and sizes must match")); + HasSubstr("Batch dimension sizes must match")); } -// BatchMatMul with different batch dimension numbers fails. 
-TEST_F(ShapeInferenceTest, DotWithMisatchedBatchDimNumbersFails) { +// BatchMatMul with different batch dimension numbers passes +TEST_F(ShapeInferenceTest, DotWithMisatchedBatchDimNumbersPasses) { Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3}); Shape rhs_shape = ShapeUtil::MakeShape(F32, {3, 2, 14}); @@ -1147,9 +1174,9 @@ TEST_F(ShapeInferenceTest, DotWithMisatchedBatchDimNumbersFails) { auto inferred_status = ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); - ASSERT_FALSE(inferred_status.ok()); - ASSERT_THAT(inferred_status.status().error_message(), - HasSubstr("Batch dimension numbers must precede non-batch")); + ASSERT_TRUE(inferred_status.ok()); + ASSERT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), + ShapeUtil::MakeShape(F32, {2, 11, 14}))); } // BatchMatMul with out-of-range dimension numbers fails. diff --git a/tensorflow/compiler/xla/service/shaped_buffer.cc b/tensorflow/compiler/xla/service/shaped_buffer.cc index 28a30b5ee2dbcb5012804578d4d037c241045309..d90dde3b13d3aa9e1de10dd9e1d11a8e6da170de 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.cc +++ b/tensorflow/compiler/xla/service/shaped_buffer.cc @@ -85,7 +85,7 @@ string ShapedBuffer::ToString() const { on_device_shape(), [this, &s](const Shape& subshape, const ShapeIndex& index) { string shape_str; - if (ShapeUtil::IsTuple(subshape)) { + if (subshape.IsTuple()) { shape_str = "tuple"; } else { shape_str = ShapeUtil::HumanStringWithLayout(subshape); diff --git a/tensorflow/compiler/xla/service/sort_simplifier.cc b/tensorflow/compiler/xla/service/sort_simplifier.cc new file mode 100644 index 0000000000000000000000000000000000000000..4a00e8d7b227f14d462ca53f695189f3f48754ee --- /dev/null +++ b/tensorflow/compiler/xla/service/sort_simplifier.cc @@ -0,0 +1,126 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/sort_simplifier.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/statusor.h" + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" + +namespace xla { +namespace { + +// If the sort instruction has a tuple shape then looks for unused output +// values and removes them from the sort instruction. Returns true if the +// graph has been modified. +StatusOr RemoveUnusedOperandFromSort(HloInstruction* sort) { + if (!sort->shape().IsTuple()) { + return false; + } + + HloComputation* computation = sort->parent(); + + if (computation->root_instruction() == sort) { + // Can't analyse users of the root instruction. + return false; + } + + // Index 0 is the sorting key used by the sort HLO itself. + absl::flat_hash_set used_indices{0}; + for (const HloInstruction* user : sort->users()) { + if (user->opcode() != HloOpcode::kGetTupleElement) { + // Can't analyse users other then get-tuple-element. + return false; + } + used_indices.insert(user->tuple_index()); + } + + if (used_indices.size() == sort->operand_count()) { + // All operands are used. 
+ return false; + } + + std::vector operands{sort->mutable_operand(0)}; + std::vector new_shapes{sort->operand(0)->shape()}; + for (int64 i = 1; i < sort->operand_count(); ++i) { + if (used_indices.count(i)) { + operands.push_back(sort->mutable_operand(i)); + new_shapes.push_back(sort->operand(i)->shape()); + } + } + + Shape new_sort_shape = new_shapes.size() == 1 + ? new_shapes[0] + : ShapeUtil::MakeTupleShape(new_shapes); + HloInstruction* new_sort = computation->AddInstruction( + sort->CloneWithNewOperands(new_sort_shape, operands)); + + // Map from original get-tuple-element tuple index to new HLO instruction + absl::flat_hash_map result_map; + if (new_sort->shape().IsTuple()) { + // Old sort key maps to new sort key. + int64 new_index = 0; + for (int64 i = 0; i < sort->operand_count(); ++i) { + if (used_indices.count(i)) { + result_map[i] = + computation->AddInstruction(HloInstruction::CreateGetTupleElement( + new_shapes[new_index], new_sort, new_index)); + ++new_index; + } + } + } else { + result_map[0] = new_sort; + } + std::vector users(sort->users().begin(), + sort->users().end()); + for (HloInstruction* user : users) { + TF_RETURN_IF_ERROR( + user->ReplaceAllUsesWith(result_map.at(user->tuple_index()))); + TF_RETURN_IF_ERROR(computation->RemoveInstructionAndUnusedOperands(user)); + } + return true; +} +} // namespace + +StatusOr SortSimplifier::Run(HloModule* module) { + VLOG(2) << "HLO module before SortSimplifier:"; + XLA_VLOG_LINES(2, module->ToString()); + + bool changed = false; + std::vector sort_instrs; + for (auto* comp : module->MakeNonfusionComputations()) { + absl::c_copy_if(comp->instructions(), std::back_inserter(sort_instrs), + [](const HloInstruction* instr) { + return instr->opcode() == HloOpcode::kSort; + }); + } + + for (HloInstruction* sort_instr : sort_instrs) { + TF_ASSIGN_OR_RETURN(bool result, RemoveUnusedOperandFromSort(sort_instr)); + changed |= result; + } + + if (changed) { + VLOG(2) << "HLO module after SortSimplifier:"; + 
XLA_VLOG_LINES(2, module->ToString()); + } else { + VLOG(2) << "HLO module unchanged after SortSimplifier"; + } + + return changed; +} +} // namespace xla diff --git a/tensorflow/compiler/xla/service/sort_simplifier.h b/tensorflow/compiler/xla/service/sort_simplifier.h new file mode 100644 index 0000000000000000000000000000000000000000..8c6f313aa04f51e14a14450bc72fc622d74133a4 --- /dev/null +++ b/tensorflow/compiler/xla/service/sort_simplifier.h @@ -0,0 +1,35 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SORT_SIMPLIFIER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_SORT_SIMPLIFIER_H_ + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + +// HLO pass which removes unused operands from sort, where an unused operand is +// defined as an operand at some index 'x' at which the output is not used. 
+class SortSimplifier : public HloModulePass { + public: + absl::string_view name() const override { return "simplify-sorts"; } + StatusOr Run(HloModule* module) override; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_SORT_SIMPLIFIER_H_ diff --git a/tensorflow/compiler/xla/service/sort_simplifier_test.cc b/tensorflow/compiler/xla/service/sort_simplifier_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..cd05fcf830d32e8bac4f8b260d3dd143ab98ad7b --- /dev/null +++ b/tensorflow/compiler/xla/service/sort_simplifier_test.cc @@ -0,0 +1,102 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/sort_simplifier.h" + +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +namespace m = match; + +using SortSimplifierTest = HloTestBase; + +TEST_F(SortSimplifierTest, RemoveUnusedSortOperandArrayResult) { + const char* hlo_string = R"( + HloModule permutation_sort + + ENTRY sort_computation { + keys = f32[64,8732]{1,0} parameter(0) + values = s32[64,8732]{1,0} parameter(1) + sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values), + dimensions={1} + ROOT gte = f32[64,8732]{1,0} get-tuple-element(sort), index=0 + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + SortSimplifier simplifier; + uint64 num_executions = 0; + do { + num_executions++; + } while (simplifier.Run(module.get()).ValueOrDie()); + EXPECT_EQ(num_executions, 2); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::Sort(m::Parameter(0)))); +} + +TEST_F(SortSimplifierTest, RemoveUnusedSortOperandTuple) { + const char* hlo_string = R"( + HloModule permutation_sort + + ENTRY sort_computation { + keys = f32[64,87] parameter(0) + values.0 = s32[64,87] parameter(1) + values.1 = u32[64,87] parameter(2) + sort = (f32[64,87], s32[64,87], u32[64,87]) sort( + keys, values.0, values.1), + dimensions={1} + gte.0 = f32[64,87] get-tuple-element(sort), index=0 + gte.1 = u32[64,87] get-tuple-element(sort), index=2 + ROOT tuple = (f32[64,87], u32[64,87]) tuple(gte.0, gte.1) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + 
ParseAndReturnVerifiedModule(hlo_string)); + + SortSimplifier simplifier; + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, + GmockMatch(m::Tuple( + m::GetTupleElement(m::Sort(m::Parameter(0), m::Parameter(2)), 0), + m::GetTupleElement(m::Sort(m::Parameter(0), m::Parameter(2)), 1)))); +} + +TEST_F(SortSimplifierTest, DontRemoveUnusedSortKey) { + const char* hlo_string = R"( + HloModule permutation_sort + + ENTRY sort_computation { + keys = f32[64,8732]{1,0} parameter(0) + values = s32[64,8732]{1,0} parameter(1) + sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values), dimensions={1} + ROOT gte = s32[64,8732]{1,0} get-tuple-element(sort), index=1 + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + SortSimplifier simplifier; + EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); +} +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/transfer_manager.cc b/tensorflow/compiler/xla/service/transfer_manager.cc index a21e586efadb85d18e88e44999283b28f7f65eac..15ef623cc7b2dbc31e9cba5c4783c39b8805a5aa 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.cc +++ b/tensorflow/compiler/xla/service/transfer_manager.cc @@ -142,7 +142,7 @@ Status TransferManager::TransferArrayToDeviceAsync( se::Stream* stream, const LiteralSlice& literal, const se::DeviceMemoryBase& dest) { const Shape on_device_shape = HostShapeToDeviceShape(literal.shape()); - TF_RET_CHECK(ShapeUtil::IsArray(on_device_shape)) + TF_RET_CHECK(on_device_shape.IsArray()) << "On-device representation of " << ShapeUtil::HumanString(literal.shape()) << " is not an array: " << ShapeUtil::HumanString(on_device_shape); @@ -227,7 +227,7 @@ Status TransferManager::WriteTupleIndexTablesAsync( return ShapeUtil::ForEachSubshapeWithStatus( device_buffer.on_device_shape(), [&](const Shape& device_subshape, const ShapeIndex& index) -> Status { - 
if (ShapeUtil::IsTuple(device_subshape)) { + if (device_subshape.IsTuple()) { se::DeviceMemoryBase device_memory = device_buffer.buffer(index); TF_RET_CHECK(GetByteSizeRequirement(device_subshape) == device_memory.size()); @@ -248,6 +248,22 @@ Status TransferManager::WriteTupleIndexTablesAsync( }); } +Status TransferManager::WriteRootTupleIndexTable( + se::Stream* stream, const ShapedBuffer& device_buffer) { + TF_RET_CHECK(device_buffer.on_device_shape().IsTuple()); + se::DeviceMemoryBase device_memory = device_buffer.buffer({}); + TF_RET_CHECK(GetByteSizeRequirement(device_buffer.on_device_shape()) == + device_memory.size()); + + std::vector elements; + for (int64 i = 0; + i < ShapeUtil::TupleElementCount(device_buffer.on_device_shape()); ++i) { + elements.push_back(device_buffer.buffer({i})); + } + return WriteSingleTupleIndexTable( + stream, elements, device_buffer.on_device_shape(), &device_memory); +} + Status TransferManager::TransferBufferFromDevice( se::Stream* stream, const se::DeviceMemoryBase& source, int64 size, void* destination) { diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h index 49f0b8f8b72001f07200d3e94828f60fcb0fa8fb..43a50487c636da75224547286a31625db3f91330 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.h +++ b/tensorflow/compiler/xla/service/transfer_manager.h @@ -146,6 +146,12 @@ class TransferManager { Status WriteTupleIndexTablesAsync(se::Stream* stream, const ShapedBuffer& device_buffer); + // Writes a tuple index buffer for the root of 'device_buffer', which must + // be a tuple. Unlike WriteTupleIndexTables, only writes the root buffer, + // rather than writing all subbuffers. This method is always asynchronous. + Status WriteRootTupleIndexTable(se::Stream* stream, + const ShapedBuffer& device_buffer); + // Determines the byte size requirement for the given shape on the underlying // architecture. 
This will be used to allocate an appropriately sized memory // region for a host-to-device transfer. diff --git a/tensorflow/compiler/xla/service/transpose_folding.cc b/tensorflow/compiler/xla/service/transpose_folding.cc index 7c1f4b5cc67dd2a84271b4f2b8015fdb2ff6e846..a95ca2bf2a8fcd700eb9234cafbfce9b62f2370c 100644 --- a/tensorflow/compiler/xla/service/transpose_folding.cc +++ b/tensorflow/compiler/xla/service/transpose_folding.cc @@ -45,7 +45,7 @@ TransposeFolding::OperandIndices CanFoldOperandsIntoDot( auto& operand = *dot.operand(i); if (operand.IsRank2Transpose()) { operand_set.push_back(i); - } else if (ShapeUtil::Rank(operand.shape()) != 2) { + } else if (operand.shape().rank() != 2) { return {}; } } @@ -130,8 +130,7 @@ bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) { HloInstruction* new_lhs; const int64 kLhsIdx = 0; - if (std::find(operand_indices.begin(), operand_indices.end(), kLhsIdx) != - operand_indices.end()) { + if (absl::c_linear_search(operand_indices, kLhsIdx)) { HloInstruction& transpose = *convolution.mutable_operand(kLhsIdx); const auto& transpose_dimensions = transpose.dimensions(); HloInstruction& transpose_operand = *transpose.mutable_operand(0); @@ -154,8 +153,7 @@ bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) { HloInstruction* new_rhs; const int64 kRhsIdx = 1; - if (std::find(operand_indices.begin(), operand_indices.end(), kRhsIdx) != - operand_indices.end()) { + if (absl::c_linear_search(operand_indices, kRhsIdx)) { HloInstruction& transpose = *convolution.mutable_operand(kRhsIdx); const auto& transpose_dimensions = transpose.dimensions(); HloInstruction& transpose_operand = *transpose.mutable_operand(0); @@ -178,7 +176,8 @@ bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) { auto new_conv = HloInstruction::CreateConvolve( convolution.shape(), new_lhs, new_rhs, convolution.feature_group_count(), - convolution.window(), new_dnums, convolution.precision_config()); + 
convolution.batch_group_count(), convolution.window(), new_dnums, + convolution.precision_config()); TF_CHECK_OK(convolution.parent()->ReplaceWithNewInstruction( &convolution, std::move(new_conv))); diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc index 17cdaa74fc328d156292f5af828d4222a9a01f1f..f8a5fa0215007310d6bec35d20fc643afc824dda 100644 --- a/tensorflow/compiler/xla/service/transpose_folding_test.cc +++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc @@ -139,9 +139,9 @@ TEST_F(TransposeFoldingTest, FoldDotTransposeConstant) { HloModule FoldDotTransposeConstant ENTRY entry_computation { - constant = f32[2,1]{1,0} constant(f32[2,1] { { 1 }, { 2 } }) + constant = f32[2,1]{1,0} constant({ { 1 }, { 2 } }) transpose = f32[1,2]{1,0} transpose(constant), dimensions={1,0} - constant.1 = f32[3,2]{1,0} constant(f32[3,2] { { 1, 2 }, { 3, 4 }, { 5, 6 } }) + constant.1 = f32[3,2]{1,0} constant({ { 1, 2 }, { 3, 4 }, { 5, 6 } }) transpose.1 = f32[2,3]{1,0} transpose(constant.1), dimensions={1,0} ROOT dot = f32[1,3]{1,0} dot(transpose, transpose.1), lhs_contracting_dims={1}, rhs_contracting_dims={0} } @@ -240,12 +240,13 @@ TEST_F(TransposeFoldingTest, FoldConvDimSwapTransposeRhs) { transpose_y->shape().dimensions(dnums.kernel_spatial_dimensions(i))); } StatusOr conv_shape = ShapeInference::InferConvolveShape( - x->shape(), transpose_y->shape(), /*feature_group_count=*/1, window, - dnums); + x->shape(), transpose_y->shape(), /*feature_group_count=*/1, + /*batch_group_count=*/1, window, dnums); EXPECT_IS_OK(conv_shape); HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( conv_shape.ValueOrDie(), x, transpose_y, - /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); + /*feature_group_count=*/1, /*batch_group_count=*/1, window, dnums, + DefaultPrecisionConfig(2))); auto module = CreateNewVerifiedModule("test_module"); HloComputation* 
entry_computation = @@ -295,12 +296,13 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeRhs) { transpose_y->shape().dimensions(dnums.kernel_spatial_dimensions(i))); } StatusOr conv_shape = ShapeInference::InferConvolveShape( - x->shape(), transpose_y->shape(), /*feature_group_count=*/1, window, - dnums); + x->shape(), transpose_y->shape(), /*feature_group_count=*/1, + /*batch_group_count=*/1, window, dnums); EXPECT_IS_OK(conv_shape); HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( conv_shape.ValueOrDie(), x, transpose_y, - /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); + /*feature_group_count=*/1, /*batch_group_count=*/1, window, dnums, + DefaultPrecisionConfig(2))); auto module = CreateNewVerifiedModule("test_module"); HloComputation* entry_computation = @@ -355,12 +357,13 @@ TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) { dim->set_size(y->shape().dimensions(dnums.kernel_spatial_dimensions(i))); } StatusOr conv_shape = ShapeInference::InferConvolveShape( - transpose_x->shape(), y->shape(), /*feature_group_count=*/1, window, - dnums); + transpose_x->shape(), y->shape(), /*feature_group_count=*/1, + /*batch_group_count=*/1, window, dnums); EXPECT_IS_OK(conv_shape); HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( conv_shape.ValueOrDie(), transpose_x, y, - /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); + /*feature_group_count=*/1, /*batch_group_count=*/1, window, dnums, + DefaultPrecisionConfig(2))); auto module = CreateNewVerifiedModule("test_module"); HloComputation* entry_computation = @@ -421,12 +424,13 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeLhs) { dim->set_size(y->shape().dimensions(dnums.kernel_spatial_dimensions(i))); } StatusOr conv_shape = ShapeInference::InferConvolveShape( - transpose_x->shape(), y->shape(), /*feature_group_count=*/1, window, - dnums); + transpose_x->shape(), y->shape(), /*feature_group_count=*/1, + 
/*batch_group_count=*/1, window, dnums); EXPECT_IS_OK(conv_shape); HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( conv_shape.ValueOrDie(), transpose_x, y, - /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); + /*feature_group_count=*/1, /*batch_group_count=*/1, window, dnums, + DefaultPrecisionConfig(2))); auto module = CreateNewVerifiedModule("test_module"); HloComputation* entry_computation = diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc index 50d51eaeb762e208004c1dae3dcc27503f3f94e9..5e505aaf02f157d0cba9dff42b1a9b89a6691504 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" @@ -55,11 +56,10 @@ bool PointsToSet::IsAmbiguous() const { bool PointsToSet::IsDistinct() const { bool distinct = true; - std::set all_points_to; - ForEachElement([&distinct, &all_points_to](const ShapeIndex& /*index*/, - const BufferList& points_to) { + absl::flat_hash_set all_points_to; + ForEachElement([&](const ShapeIndex& /*index*/, const BufferList& points_to) { for (auto& buffer : points_to) { - if (all_points_to.count(buffer) != 0) { + if (all_points_to.contains(buffer)) { distinct = false; } all_points_to.insert(buffer); @@ -87,9 +87,7 @@ bool PointsToSet::ContainsBuffer(const LogicalBuffer& buffer) const { bool found = false; ForEachElement([&found, &buffer](const ShapeIndex& /*index*/, const BufferList& pointed_to_buffers) { - if (!found && - std::find(pointed_to_buffers.begin(), pointed_to_buffers.end(), - &buffer) != pointed_to_buffers.end()) { + if (!found && absl::c_linear_search(pointed_to_buffers, &buffer)) { found = true; } }); @@ -99,8 
+97,7 @@ bool PointsToSet::ContainsBuffer(const LogicalBuffer& buffer) const { bool PointsToSet::ContainsBufferAtIndex(const LogicalBuffer& buffer, const ShapeIndex& index) const { const auto& pointed_to_buffers = element(index); - return std::find(pointed_to_buffers.begin(), pointed_to_buffers.end(), - &buffer) != pointed_to_buffers.end(); + return absl::c_linear_search(pointed_to_buffers, &buffer); } void PointsToSet::AddPointedToBuffer(const LogicalBuffer& buffer, @@ -210,7 +207,7 @@ Status TuplePointsToAnalysis::DefaultAction(HloInstruction* hlo_instruction) { &logical_buffer_analysis_->GetBuffer(hlo_instruction, index)); }); - if (ShapeUtil::IsTuple(hlo_instruction->shape())) { + if (hlo_instruction->shape().IsTuple()) { // If the hlo instruction is a tuple-shaped, then trivially the instruction // itself is the source of the tuple. points_to_set.add_tuple_source({}, hlo_instruction); @@ -604,9 +601,8 @@ bool TuplePointsToAnalysis::DoesNotUseOperandBuffer( } else if (user->opcode() == HloOpcode::kFusion && user->fusion_kind() == HloInstruction::FusionKind::kLoop) { // Find fusion parameter associated with 'operand'. - auto it = std::find_if( - user->fused_parameters().begin(), user->fused_parameters().end(), - [=](HloInstruction* fused_param) { + auto it = absl::c_find_if( + user->fused_parameters(), [&](HloInstruction* fused_param) { return user->operand(fused_param->parameter_number()) == operand; }); CHECK(it != user->fused_parameters().end()); @@ -672,9 +668,8 @@ bool TuplePointsToAnalysis::HasUniqueFusedUseOfOperandAt( } // Find fusion parameter associated with 'operand'. 
const auto& fused_params = fusion->fused_parameters(); - auto fused_param_it = std::find_if( - fused_params.begin(), fused_params.end(), - [&](HloInstruction* fused_param) { + auto fused_param_it = + absl::c_find_if(fused_params, [&](HloInstruction* fused_param) { return fusion->operand(fused_param->parameter_number()) == operand; }); if (fused_param_it == fused_params.end()) { @@ -743,11 +738,10 @@ bool TuplePointsToAnalysis::CanShareOperandBufferWithUser( // Check if one operand of kAdd fused root is kDot or kConvolution. auto* add = user->fused_expression_root(); auto add_operand_it = - std::find_if(add->operands().begin(), add->operands().end(), - [&](HloInstruction* operand) { - return operand->opcode() == HloOpcode::kConvolution || - operand->opcode() == HloOpcode::kDot; - }); + absl::c_find_if(add->operands(), [&](HloInstruction* operand) { + return operand->opcode() == HloOpcode::kConvolution || + operand->opcode() == HloOpcode::kDot; + }); if (add_operand_it == add->operands().end()) { return false; } diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc index 561762b5d424ed5f537665be9d67a81dc8bdd56e..fd5759e44230db8223822d6ae0f511027f73d8f9 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc @@ -623,7 +623,7 @@ class FusionPointsToAnalysisTest : public TuplePointsToAnalysisTest { void Run(const bool add_additional_gte0_user) { Shape input_shape = ShapeUtil::MakeShape(F32, {8}); Shape update_shape = ShapeUtil::MakeShape(F32, {3}); - Shape starts_shape = ShapeUtil::MakeShape(S32, {1}); + Shape starts_shape = ShapeUtil::MakeShape(S32, {}); Shape tuple_shape = ShapeUtil::MakeTupleShape({input_shape, update_shape, starts_shape}); @@ -657,7 +657,7 @@ class FusionPointsToAnalysisTest : public TuplePointsToAnalysisTest { HloInstruction::CreateGetTupleElement(starts_shape, 
tuple_param0, 2)); // Update 'input' with 'update' at dynamic 'starts' indices. builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - input_shape, input, update, starts)); + input_shape, input, update, {starts})); // Build computation and add it to module as entry computation. BuildModule(builder.Build()); @@ -721,9 +721,8 @@ class FusionPointsToAnalysisTest : public TuplePointsToAnalysisTest { // to fusion 'operand'. HloInstruction* GetFusionParameterForOperand(HloInstruction* fusion, HloInstruction* operand) { - auto it = std::find_if( - fusion->fused_instructions().begin(), - fusion->fused_instructions().end(), [=](const HloInstruction* fused) { + auto it = absl::c_find_if( + fusion->fused_instructions(), [&](const HloInstruction* fused) { return fused->opcode() == HloOpcode::kParameter && fusion->operand(fused->parameter_number()) == operand; }); @@ -734,7 +733,7 @@ class FusionPointsToAnalysisTest : public TuplePointsToAnalysisTest { // Returns all users of 'fusion_paran' at 'tuple_index'. std::vector GetFusionParameterUsersAt( HloInstruction* fusion_param, int64 tuple_index) { - CHECK(ShapeUtil::IsTuple(fusion_param->shape())); + CHECK(fusion_param->shape().IsTuple()); std::vector users_at_tuple_index; for (auto user : fusion_param->users()) { CHECK_EQ(HloOpcode::kGetTupleElement, user->opcode()); @@ -883,12 +882,12 @@ TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) { // Create a DynamicUpdateSlice instruction of tuple element 1. 
auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({2}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2))); auto update = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({2.f, 2.f, 2.f}))); auto dynamic_update_slice = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - data_shape, gte1, update, starts)); + data_shape, gte1, update, {starts})); builder.AddInstruction( HloInstruction::CreateTuple({gte0, dynamic_update_slice})); @@ -977,12 +976,12 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) { // Create a DynamicUpdateSlice instruction of tuple element 1. auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({2}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2))); auto update = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({2.f, 2.f, 2.f}))); auto dynamic_update_slice = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - data_shape, gte1, update, starts)); + data_shape, gte1, update, {starts})); builder.AddInstruction( HloInstruction::CreateTuple({gte0, dynamic_update_slice})); @@ -1004,7 +1003,7 @@ TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) { Shape data_shape = ShapeUtil::MakeShape(F32, {8}); Shape update_shape = ShapeUtil::MakeShape(F32, {4}); - Shape starts_shape = ShapeUtil::MakeShape(S32, {1}); + Shape starts_shape = ShapeUtil::MakeShape(S32, {}); auto data = builder.AddInstruction( HloInstruction::CreateParameter(0, data_shape, "data")); auto update = builder.AddInstruction( @@ -1012,7 +1011,7 @@ TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) { auto starts = builder.AddInstruction( HloInstruction::CreateParameter(2, starts_shape, "starts")); auto dus = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - data_shape, data, update, starts)); + data_shape, data, update, {starts})); 
BuildModuleAndRunAnalysis(builder.Build()); diff --git a/tensorflow/compiler/xla/service/tuple_util.cc b/tensorflow/compiler/xla/service/tuple_util.cc index cfb0c787d09557fd1aec3517eb9698cfec323369..90ea79ec263a038556ccbd2cd345b337c5a5dcf3 100644 --- a/tensorflow/compiler/xla/service/tuple_util.cc +++ b/tensorflow/compiler/xla/service/tuple_util.cc @@ -21,7 +21,7 @@ namespace xla { /*static*/ HloInstruction* TupleUtil::ExtractPrefix(HloInstruction* input_tuple, int64 elements) { - CHECK(ShapeUtil::IsTuple(input_tuple->shape())); + CHECK(input_tuple->shape().IsTuple()); HloComputation* computation = input_tuple->parent(); const Shape& input_shape = input_tuple->shape(); @@ -41,7 +41,7 @@ namespace xla { /*static*/ HloInstruction* TupleUtil::AppendSuffix( HloInstruction* input_tuple, absl::Span trailing_values) { - CHECK(ShapeUtil::IsTuple(input_tuple->shape())); + CHECK(input_tuple->shape().IsTuple()); HloComputation* computation = input_tuple->parent(); const Shape& input_shape = input_tuple->shape(); diff --git a/tensorflow/compiler/xla/service/while_loop_analysis.cc b/tensorflow/compiler/xla/service/while_loop_analysis.cc index 68e2569f66bea9ec1223e454d1ead0efc7b9498e..c93a9ba3176002a34fe84a29e62075de4d19168f 100644 --- a/tensorflow/compiler/xla/service/while_loop_analysis.cc +++ b/tensorflow/compiler/xla/service/while_loop_analysis.cc @@ -301,7 +301,7 @@ optional ComputeWhileLoopTripCountUpperBound(HloInstruction* while_op) { /*dest_shape_index=*/{indvar_index}, /*src_shape_index=*/{})); StatusOr eval_result = - evaluator.Evaluate(*while_cond, {std::move(fake_input)}); + evaluator.Evaluate(*while_cond, {std::move(fake_input)}); if (!eval_result.ok()) { VLOG(2) << "Couldn't evaluate while loop condition."; diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc b/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc index 75d406435b6f58faecc86b82c33e9e2dd6bccbea..3bcf5c38309a86e9e3cab3268f3f065005f7a923 100644 --- 
a/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc @@ -129,7 +129,7 @@ condition { ENTRY entry { const_0 = f32[2] constant({1, 2}) - const_1 = (f32[2], f32[2]) constant((f32[2], f32[2]) ({2, 1},{3,1})) + const_1 = (f32[2], f32[2]) constant(({2, 1},{3,1})) while_init = (f32[2],(f32[2],f32[2])) tuple(const_0, const_1) ROOT while = (f32[2],(f32[2],f32[2])) while(while_init), condition=condition, body=body } @@ -206,8 +206,8 @@ body { p_body.0 = f32[2] get-tuple-element((f32[2],f32[2]) p_body), index=0 p_body.1 = f32[2] get-tuple-element((f32[2],f32[2]) p_body), index=1 - token = token[] after-all() - outfeed = token[] outfeed(p_body.0, token) + token0 = token[] after-all() + outfeed = token[] outfeed(p_body.0, token0) ROOT root = (f32[2],f32[2],f32[2]) tuple(p_body.0, p_body.1, p_body.1) } @@ -305,7 +305,7 @@ condition { ENTRY entry { const_0 = f32[] constant(0) - const_1 = (f32[], f32[]) constant((f32[], f32[]) (1, 10)) + const_1 = (f32[], f32[]) constant((1, 10)) while_init = (f32[],(f32[],f32[])) tuple(const_0, const_1) ROOT while = (f32[],(f32[],f32[])) while(while_init), condition=condition, body=body } diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc index 41011176ffa91e885bc58364d1fb19617d3518ad..69cc8feb3f31ad782b9d3437d81d0ab8ce10aadb 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc @@ -89,7 +89,7 @@ static void CreateLoopInvariantCopy( HloInstruction* next_operand = frame->instruction->mutable_operand(frame->operand_index++); - if (hoisted_instructions->count(next_operand) || + if (hoisted_instructions->contains(next_operand) || next_operand == while_body_param) { continue; } @@ -127,7 +127,7 @@ 
WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody( HloInstruction* while_instr) { auto print_no_metadata = HloPrintOptions{}.set_print_metadata(false); - if (!ShapeUtil::IsTuple(while_instr->shape())) { + if (!while_instr->shape().IsTuple()) { // This restriction leaves one interesting pattern on the table: // // while_body(f32[1024, 1024] %param) { @@ -168,7 +168,7 @@ WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody( // is no benefit to hoisting them unless something that uses it is also // hoisted. for (auto* instr : WhileUtil::GetInvariantGTEsForWhileBody(*while_body)) { - if (ShapeUtil::IsArray(instr->shape())) { + if (instr->shape().IsArray()) { // TODO(b/79147885): We should try to generalize this to tuples for // uniformity's sake, if nothing else. InsertOrDie(&unhoisted_invariant_instructions, instr); @@ -221,7 +221,7 @@ WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody( ShapeUtil::ForEachSubshape( operand->shape(), [&input_size](const Shape& subshape, const ShapeIndex& /*index*/) { - if (ShapeUtil::IsArray(subshape)) { + if (subshape.IsArray()) { input_size += ShapeUtil::ByteSizeOfElements(subshape); } }); @@ -229,7 +229,7 @@ WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody( ShapeUtil::ForEachSubshape( instruction->shape(), [&output_size](const Shape& subshape, const ShapeIndex& /*index*/) { - if (ShapeUtil::IsArray(subshape)) { + if (subshape.IsArray()) { output_size += ShapeUtil::ByteSizeOfElements(subshape); } }); @@ -241,7 +241,7 @@ WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody( auto is_invariant = [&](HloInstruction* op) { return hoisted_instructions.find(op) != hoisted_instructions.end() || - unhoisted_invariant_instructions.count(op) || + unhoisted_invariant_instructions.contains(op) || op->opcode() == HloOpcode::kConstant; }; diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc 
b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc index 8e7c4bc8828552e197b41f874c070d496b85a382..3587c016b4420163a607422b1acc838646fab83a 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc @@ -299,7 +299,7 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistBitcastAlone) { // bitcast either. auto m = CreateNewVerifiedModule(); auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); - auto scalar_f32 = ShapeUtil::MakeShape(F32, {}); + auto effective_scalar_s32 = ShapeUtil::MakeShape(S32, {1}); auto token_shape = ShapeUtil::MakeTokenShape(); Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, token_shape}); @@ -314,10 +314,12 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistBitcastAlone) { HloInstruction::CreateGetTupleElement(scalar_s32, param, 1)); HloInstruction* in_token = builder.AddInstruction( HloInstruction::CreateGetTupleElement(token_shape, param, 2)); - HloInstruction* bitcast_inst = builder.AddInstruction( - HloInstruction::CreateUnary(scalar_f32, HloOpcode::kBitcast, gte_0)); - HloInstruction* out_token = builder.AddInstruction( - HloInstruction::CreateOutfeed(scalar_f32, bitcast_inst, in_token, "")); + HloInstruction* bitcast_inst = + builder.AddInstruction(HloInstruction::CreateUnary( + effective_scalar_s32, HloOpcode::kBitcast, gte_0)); + HloInstruction* out_token = + builder.AddInstruction(HloInstruction::CreateOutfeed( + effective_scalar_s32, bitcast_inst, in_token, "")); builder.AddInstruction( HloInstruction::CreateTuple({gte_0, gte_1, out_token})); @@ -352,9 +354,9 @@ TEST_F(WhileLoopInvariantCodeMotionTest, HoistBitcastIfNeeded) { // The bitcast's user can be hoisted, so hoist the bitcast too. 
auto m = CreateNewVerifiedModule(); auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); - auto scalar_f32 = ShapeUtil::MakeShape(F32, {}); - Shape while_shape = - ShapeUtil::MakeTupleShape({scalar_s32, scalar_f32, scalar_f32}); + auto effective_scalar_s32 = ShapeUtil::MakeShape(S32, {1}); + Shape while_shape = ShapeUtil::MakeTupleShape( + {scalar_s32, effective_scalar_s32, effective_scalar_s32}); HloComputation* while_body = [&]() { HloComputation::Builder builder(TestName() + ".while_body"); @@ -363,12 +365,13 @@ TEST_F(WhileLoopInvariantCodeMotionTest, HoistBitcastIfNeeded) { HloInstruction* gte_0 = builder.AddInstruction( HloInstruction::CreateGetTupleElement(scalar_s32, param, 0)); HloInstruction* gte_1 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(scalar_f32, param, 1)); - HloInstruction* bitcast_inst = builder.AddInstruction( - HloInstruction::CreateUnary(scalar_f32, HloOpcode::kBitcast, gte_0)); + HloInstruction::CreateGetTupleElement(effective_scalar_s32, param, 1)); + HloInstruction* bitcast_inst = + builder.AddInstruction(HloInstruction::CreateUnary( + effective_scalar_s32, HloOpcode::kBitcast, gte_0)); HloInstruction* add_inst = builder.AddInstruction(HloInstruction::CreateBinary( - scalar_f32, HloOpcode::kAdd, bitcast_inst, gte_1)); + effective_scalar_s32, HloOpcode::kAdd, bitcast_inst, gte_1)); builder.AddInstruction( HloInstruction::CreateTuple({gte_0, gte_1, add_inst})); diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc index d30f67dd8110b88166fe807762fb653190ec00bc..386ffb995477ff1b4aef73080b6a6fd988dd1980 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc @@ -58,7 +58,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { HloComputation* while_body = while_op->while_body(); HloInstruction* while_body_root = while_body->root_instruction(); - if 
(!ShapeUtil::IsTuple(while_init->shape())) { + if (!while_init->shape().IsTuple()) { VLOG(2) << "While op's carried value isn't tuple shaped."; return false; } @@ -109,8 +109,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { // operand appears in, but it may appear more than once! if (user->user_count() == 1 && user->users().front() == while_body_root && while_body_root->operand_index(user) == user->tuple_index() && - std::count(while_body_root->operands().begin(), - while_body_root->operands().end(), user) == 1) { + absl::c_count(while_body_root->operands(), user) == 1) { continue; } @@ -127,7 +126,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { // through to the while body's root, count that element as "used", since // removing that element would be observable. for (int64 i = 0; i < while_body_root->operand_count(); ++i) { - if (used_tuple_indices.count(i)) { + if (used_tuple_indices.contains(i)) { continue; } @@ -158,7 +157,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { // Build up maps from the old/new to the new/old tuple indices. std::vector new_to_old_tuple_idx(used_tuple_indices.begin(), used_tuple_indices.end()); - std::sort(new_to_old_tuple_idx.begin(), new_to_old_tuple_idx.end()); + absl::c_sort(new_to_old_tuple_idx); absl::flat_hash_map old_to_new_tuple_idx; for (int64 new_idx = 0; new_idx < new_to_old_tuple_idx.size(); ++new_idx) { @@ -181,7 +180,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { // replace the old instructions after we remove unused elements from the while // tuple. 
auto make_while_computation_replacements = [&](const HloComputation* comp) { - std::unordered_map> + absl::flat_hash_map> replacements; auto* param = comp->parameter_instruction(0); @@ -233,7 +232,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { while_cond->CloneWithReplacements( make_while_computation_replacements(while_cond)); - std::unordered_map> + absl::flat_hash_map> while_body_replacements = make_while_computation_replacements(while_body); std::vector new_while_body_root_elems; new_while_body_root_elems.reserve(new_to_old_tuple_idx.size()); @@ -583,8 +582,7 @@ static StatusOr TryPropagateConstant(HloInstruction* while_op) { static std::unique_ptr UnflattenTupleInstr( absl::Span instrs, const Shape& desired_shape, std::vector>* new_instrs) { - CHECK(ShapeUtil::IsTuple(desired_shape)) - << ShapeUtil::HumanString(desired_shape); + CHECK(desired_shape.IsTuple()) << ShapeUtil::HumanString(desired_shape); // For each child shape in `desired_shape`, slice out the correct number of // `instrs` and call UnflattenTupleInstr recursively. 
At each step we remove @@ -593,7 +591,7 @@ static std::unique_ptr UnflattenTupleInstr( std::vector elems; for (int64 i = 0; i < desired_shape.tuple_shapes_size(); ++i) { const Shape& subshape = desired_shape.tuple_shapes(i); - if (!ShapeUtil::IsTuple(subshape)) { + if (!subshape.IsTuple()) { elems.push_back(instrs[0]); instrs.remove_prefix(1); continue; @@ -603,7 +601,7 @@ static std::unique_ptr UnflattenTupleInstr( int64 num_leaves = 0; ShapeUtil::ForEachSubshape( subshape, [&](const Shape& s, const ShapeIndex& /*index*/) { - if (!ShapeUtil::IsTuple(s)) { + if (!s.IsTuple()) { ++num_leaves; } }); @@ -625,7 +623,7 @@ static std::vector GetFlatTupleElems( HloInstruction* instr, std::vector>* new_instrs) { const auto& shape = instr->shape(); - if (!ShapeUtil::IsTuple(shape)) { + if (!shape.IsTuple()) { return {instr}; } std::vector elems; @@ -665,7 +663,7 @@ static StatusOr TryFlattenNestedTuples(HloInstruction* while_op) { std::vector flattened_shape_elems; ShapeUtil::ForEachSubshape(while_shape, [&](const Shape& s, const ShapeIndex& /*index*/) { - if (!ShapeUtil::IsTuple(s)) { + if (!s.IsTuple()) { flattened_shape_elems.push_back(s); } }); diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc index 4950e8269e9cf0723d717bd1734518d104c0c9f2..ecca76b1e86d833c73fbb9bad6a341660a7d2669 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc @@ -22,6 +22,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_dce.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/tuple_simplifier.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" @@ -406,13 +407,12 @@ TEST_F(WhileLoopSimplifierTest, RemoveUnusedLoopOperands) { // The original while instruction is still left in the module as a dead // instruction, find a while instruction with a different name as the new // while instruction. + const auto& instrs = m->entry_computation()->instructions(); HloInstruction* new_while_op = - *std::find_if(m->entry_computation()->instructions().begin(), - m->entry_computation()->instructions().end(), - [&](const HloInstruction* instr) { - return (instr->opcode() == HloOpcode::kWhile && - instr->name() != "while"); - }); + *absl::c_find_if(instrs, [&](const HloInstruction* instr) { + return (instr->opcode() == HloOpcode::kWhile && + instr->name() != "while"); + }); auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); EXPECT_TRUE( @@ -554,8 +554,7 @@ TEST_F(WhileLoopSimplifierTest, FlattenNestedTuple) { HloInstruction* new_while = FindFirstWhile(m.get()); Shape flat_tuple = - ShapeUtil::ParseShapeString("(s32[1], s32[2], s32[3], s32[4])") - .ValueOrDie(); + ParseShape("(s32[1], s32[2], s32[3], s32[4])").ValueOrDie(); SCOPED_TRACE(m->ToString()); EXPECT_TRUE(ShapeUtil::Equal(new_while->shape(), flat_tuple)); EXPECT_TRUE(ShapeUtil::Equal( @@ -567,8 +566,7 @@ TEST_F(WhileLoopSimplifierTest, FlattenNestedTuple) { flat_tuple)); EXPECT_TRUE(ShapeUtil::Equal( m->entry_computation()->root_instruction()->shape(), - ShapeUtil::ParseShapeString("((s32[1]), (s32[2], s32[3], (s32[4])))") - .ValueOrDie())); + ParseShape("((s32[1]), (s32[2], s32[3], (s32[4])))").ValueOrDie())); } // Edge-case: All elements of the loop carry are constants which can be 
removed, @@ -641,8 +639,7 @@ TEST_F(WhileLoopSimplifierTest, RemoveConstantFromLoopCarry) { EXPECT_TRUE(TupleSimplifier().Run(m.get()).ok()); HloInstruction* new_while = FindFirstWhile(m.get()); - Shape new_while_shape = - ShapeUtil::ParseShapeString("(s32[1], s32[3])").ValueOrDie(); + Shape new_while_shape = ParseShape("(s32[1], s32[3])").ValueOrDie(); EXPECT_TRUE(ShapeUtil::Equal(new_while->shape(), new_while_shape)); EXPECT_TRUE(ShapeUtil::Equal( new_while->while_body()->root_instruction()->shape(), new_while_shape)); @@ -652,9 +649,9 @@ TEST_F(WhileLoopSimplifierTest, RemoveConstantFromLoopCarry) { EXPECT_TRUE(ShapeUtil::Equal( new_while->while_condition()->parameter_instruction(0)->shape(), new_while_shape)); - EXPECT_TRUE(ShapeUtil::Equal( - m->entry_computation()->root_instruction()->shape(), - ShapeUtil::ParseShapeString("(s32[1], s32[2], s32[3])").ValueOrDie())); + EXPECT_TRUE( + ShapeUtil::Equal(m->entry_computation()->root_instruction()->shape(), + ParseShape("(s32[1], s32[2], s32[3])").ValueOrDie())); EXPECT_THAT(m->entry_computation()->root_instruction(), op::Tuple(_, op::Constant(), _)); } @@ -712,7 +709,7 @@ TEST_F(WhileLoopSimplifierTest, MergeInductionVariables_Simple) { // We should have added a new loop counter for s32[] to the end of the tuple. 
SCOPED_TRACE(m->ToString()); Shape new_while_shape = - ShapeUtil::ParseShapeString("(s32[], s32[], s32[], s32[])").ValueOrDie(); + ParseShape("(s32[], s32[], s32[], s32[])").ValueOrDie(); EXPECT_TRUE(ShapeUtil::Equal(new_while->shape(), new_while_shape)); EXPECT_TRUE(ShapeUtil::Equal( new_while->while_body()->root_instruction()->shape(), new_while_shape)); diff --git a/tensorflow/compiler/xla/service/while_util.cc b/tensorflow/compiler/xla/service/while_util.cc index 039ccda7322f5efda6a827efbeda1225c3596cc0..d77386497a14b3e52be2ea7f655fa330f60e4a97 100644 --- a/tensorflow/compiler/xla/service/while_util.cc +++ b/tensorflow/compiler/xla/service/while_util.cc @@ -97,7 +97,7 @@ WidenWhileBody(HloComputation* narrow_body, const Shape& wide_shape) { WhileUtil::MakeInstructionsLiveIn( HloInstruction* while_instr, absl::Span instructions) { - CHECK(ShapeUtil::IsTuple(while_instr->shape())); + CHECK(while_instr->shape().IsTuple()); int64 elements_in_old_while_shape = while_instr->shape().tuple_shapes_size(); Shape new_while_shape = while_instr->shape(); diff --git a/tensorflow/compiler/xla/service/while_util_test.cc b/tensorflow/compiler/xla/service/while_util_test.cc index 5e6941933330fde29bc9c779aae4bb3c36914660..d92b9870f373564ae8fd904c8bf9f0d1afbff9c4 100644 --- a/tensorflow/compiler/xla/service/while_util_test.cc +++ b/tensorflow/compiler/xla/service/while_util_test.cc @@ -180,8 +180,8 @@ body { cond { param.c = (s32[], s32[]) parameter(0) - token = token[] after-all() - infeed = (pred[], token[]) infeed(token) + token0 = token[] after-all() + infeed = (pred[], token[]) infeed(token0) ROOT condition = pred[] get-tuple-element(infeed), index=0 } diff --git a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc index 83d696fe0915086c3c98b6d7cbdaeaeb4d9d0bdb..661b7aa7d99ca549da6a509812760a1665d60919 100644 --- a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc +++ 
b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc @@ -31,16 +31,21 @@ StatusOr ZeroSizedHloElimination::Run(HloModule* module) { bool changed = false; for (HloComputation* comp : module->MakeNonfusionComputations()) { for (HloInstruction* instruction : comp->MakeInstructionPostOrder()) { - if (instruction->HasSideEffect() || - !ShapeUtil::IsArray(instruction->shape()) || + if (instruction->HasSideEffect() || !instruction->shape().IsArray() || instruction->opcode() == HloOpcode::kConstant) { continue; } if (comp->IsRemovable(instruction) && ShapeUtil::IsZeroElementArray(instruction->shape())) { + // If the instruction doesn't have a layout, use a default layout for + // the literal. + Shape shape = instruction->shape(); + if (!LayoutUtil::HasLayout(shape)) { + LayoutUtil::SetToDefaultLayout(&shape); + } TF_RETURN_IF_ERROR(comp->ReplaceWithNewInstruction( - instruction, HloInstruction::CreateConstant( - Literal::CreateFromShape(instruction->shape())))); + instruction, + HloInstruction::CreateConstant(Literal::CreateFromShape(shape)))); changed = true; } } diff --git a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc index a546a6d39cc55d1f327b8449c7d26cd4c95dbf98..572a79609e7a912277af0fd2ba43f9a1e14a6f52 100644 --- a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc +++ b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc @@ -82,5 +82,18 @@ TEST_F(ZeroSizedHloEliminationTest, DoesNotEliminateConstant) { EXPECT_FALSE(changed); } +TEST_F(ZeroSizedHloEliminationTest, ZeroSizedInstructionWithoutLayoutFolded) { + Shape op_shape = ShapeUtil::MakeShape(F32, {4, 0}); + op_shape.clear_layout(); + HloInstruction* param1 = builder_.AddInstruction( + HloInstruction::CreateParameter(1, op_shape, "zero sized param 1")); + HloInstruction* param2 = builder_.AddInstruction( + HloInstruction::CreateParameter(2, op_shape, "zero sized param 2")); + 
builder_.AddInstruction( + HloInstruction::CreateBinary(op_shape, HloOpcode::kAdd, param1, param2)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunZeroSizedElimination()); + EXPECT_TRUE(changed); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/shape.cc b/tensorflow/compiler/xla/shape.cc index 746ab9e9977b1b10cdb0cb57197027d65bd50f55..a36d3547a0987422c2658b0f3046f7b1f83369c6 100644 --- a/tensorflow/compiler/xla/shape.cc +++ b/tensorflow/compiler/xla/shape.cc @@ -27,12 +27,27 @@ Shape::Shape(const ShapeProto& shape_proto) { for (const int64 dimension : shape_proto.dimensions()) { add_dimensions(dimension); } + // A malformed proto may have different is_dynamic_dimension_size and + // dimensions_size. Since C++ is evil, and we have no good way of bailing out + // in a constructor, conservatively trim the is_dynamic_dimension size. + // TODO(b/120111794): Make this a hard error when we have a factory method + // instead of a constructor. + if (shape_proto.dimensions_size() != + shape_proto.is_dynamic_dimension_size()) { + LOG(ERROR) << "Malformed shape proto: number of is_dynamic_dimension " + "fields does not match number of dimension fields"; + } + int64 num_dynamic_dimension_fields = std::min( + shape_proto.dimensions_size(), shape_proto.is_dynamic_dimension_size()); + for (int i = 0; i < num_dynamic_dimension_fields; i++) { + dynamic_dimensions_[i] = shape_proto.is_dynamic_dimension(i); + } tuple_shapes_.reserve(shape_proto.tuple_shapes_size()); for (const ShapeProto& element_shape : shape_proto.tuple_shapes()) { *add_tuple_shapes() = Shape(element_shape); } if (shape_proto.has_layout()) { - *mutable_layout() = shape_proto.layout(); + *mutable_layout() = Layout::CreateFromProto(shape_proto.layout()); } } @@ -43,12 +58,15 @@ ShapeProto Shape::ToProto() const { for (const int64 dimension : dimensions()) { proto.add_dimensions(dimension); } + for (const bool dynamic : dynamic_dimensions_) { + proto.add_is_dynamic_dimension(dynamic); + } 
proto.mutable_tuple_shapes()->Reserve(tuple_shapes_size()); for (const Shape& shape : tuple_shapes()) { *proto.add_tuple_shapes() = shape.ToProto(); } if (has_layout()) { - *proto.mutable_layout() = layout(); + *proto.mutable_layout() = layout().ToProto(); } return proto; } @@ -61,6 +79,112 @@ string Shape::ToString(bool print_layout) const { } } +bool Shape::is_static() const { + if (IsTuple()) { + for (const Shape& subshape : tuple_shapes_) { + if (!subshape.is_static()) { + return false; + } + } + } + return !absl::c_any_of(dynamic_dimensions_, [](bool b) { return b; }); +} + +void Shape::DeleteDimension(int64 dim_to_delete) { + CHECK(IsArray()); + CHECK_GE(dim_to_delete, 0); + CHECK_LT(dim_to_delete, dimensions_.size()); + dimensions_.erase(dimensions_.begin() + dim_to_delete); + dynamic_dimensions_.erase(dynamic_dimensions_.begin() + dim_to_delete); + if (LayoutUtil::HasLayout(*this)) { + layout_.set_format(DENSE); + for (int64 i = 0; i < layout_.minor_to_major().size();) { + if (layout_.minor_to_major(i) == dim_to_delete) { + layout_.mutable_minor_to_major()->erase( + layout_.mutable_minor_to_major()->begin() + i); + continue; + } + if (layout_.minor_to_major(i) > dim_to_delete) { + (*layout_.mutable_minor_to_major())[i] -= 1; + } + ++i; + } + } +} + +bool Shape::Equal::operator()(const Shape& lhs, const Shape& rhs) { + if (lhs.IsTuple()) { + return rhs.IsTuple() && + absl::c_equal( + lhs.tuple_shapes(), rhs.tuple_shapes(), + [=](const Shape& l, const Shape& r) { return (*this)(l, r); }); + } else if (!lhs.IsArray()) { + // Non-tuple, non-array tupes such as opaque and token types are trivially + // the same. 
+ return lhs.element_type() == rhs.element_type(); + } + + if (!rhs.IsArray()) { + return false; + } + + if (!ignore_element_type_) { + if ((ignore_fp_precision_ && + !ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) || + (!ignore_fp_precision_ && !ShapeUtil::SameElementType(lhs, rhs))) { + VLOG(3) << "CompareShapes: lhs element type != rhs element type"; + return false; + } + } + + if (!ignore_layout_) { + if (lhs.layout().format() != rhs.layout().format()) { + VLOG(3) << "CompareShapes: lhs layout format != rhs layout format"; + return false; + } + if (LayoutUtil::IsDenseArray(lhs)) { + if (!absl::c_equal(LayoutUtil::MinorToMajor(lhs), + LayoutUtil::MinorToMajor(rhs))) { + VLOG(3) << "CompareShapes: lhs layout != rhs layout"; + return false; + } + + const auto& lhs_tiles = lhs.layout().tiles(); + const auto& rhs_tiles = rhs.layout().tiles(); + if (lhs_tiles.size() != rhs_tiles.size()) { + return false; + } + for (int64 i = 0; i < lhs_tiles.size(); i++) { + if (!absl::c_equal(lhs_tiles[i].dimensions(), + rhs_tiles[i].dimensions())) { + return false; + } + } + + if (lhs.layout().element_size_in_bits() != + rhs.layout().element_size_in_bits()) { + return false; + } + } + } + + if (!ShapeUtil::SameDimensions(lhs, rhs)) { + VLOG(3) << "CompareShapes: lhs dimensions != rhs dimensions"; + return false; + } + + if (!ignore_dynamic_dimension_) { + for (int i = 0; i < lhs.rank(); ++i) { + if (lhs.is_dynamic_dimension(i) != rhs.is_dynamic_dimension(i)) { + VLOG(3) + << "CompareShapes: lhs and rhs have different dynamic dimensions."; + return false; + } + } + } + return true; +} + std::ostream& operator<<(std::ostream& out, const Shape& shape) { out << shape.ToString(/*print_layout=*/true); return out; diff --git a/tensorflow/compiler/xla/shape.h b/tensorflow/compiler/xla/shape.h index 7f6b14ab4286c696dce64d2250a3fe8a57e4865b..1d594904e0b9e6f1779674e75b41b7a597788bac 100644 --- a/tensorflow/compiler/xla/shape.h +++ b/tensorflow/compiler/xla/shape.h @@ -20,6 +20,8 @@ 
limitations under the License. #include #include "absl/types/optional.h" +#include "tensorflow/compiler/xla/layout.h" +#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/types.h" @@ -43,6 +45,43 @@ class Shape { // without layout. e.g. "F32[42,12] {0, 1}" or "F32[64]". string ToString(bool print_layout = false) const; + // Returns the rank (number of dimensions) of the given shape. Shape must be + // an array. + int64 rank() const { + CHECK(IsArray()) << "Non-arrays do not have a rank, shape: " << ToString(); + return dimensions_.size(); + } + + // Returns whether the shape is of the specified type (array, tuple, etc). + bool IsArray() const { return primitive_util::IsArrayType(element_type()); } + bool IsTuple() const { return element_type() == TUPLE; } + bool IsToken() const { return element_type() == TOKEN; } + bool IsOpaque() const { return element_type() == OPAQUE; } + + // Returns true if no array dimension in the shape is dynamically sized. Tuple + // shapes are traversed recursively. + bool is_static() const; + + // Returns true if the given dimension is dynamically-sized. + bool is_dynamic_dimension(int dimension) const { + return dynamic_dimensions_.at(dimension); + } + + // Sets whether or not the given dimension is dynamically-sized. + void set_dynamic_dimension(int dimension, bool is_dynamic) { + dynamic_dimensions_[dimension] = is_dynamic; + } + + const std::vector& dynamic_dimensions() const { + return dynamic_dimensions_; + } + + // Add dimension_upper_bound(). + + // Removes the given dimension form the shape. Layout, if it exists, is + // adjusted to match the modified shape. + void DeleteDimension(int64 dim_to_delete); + // The following methods mirror the protobuf generated code interface for the // message ShapeProto. This enabled easy migration of this data structure // from a proto to a proper C++ class. 
@@ -57,10 +96,16 @@ class Shape { int dimensions_size() const { return dimensions_.size(); } int64 dimensions(int index) const { return dimensions_.at(index); } void set_dimensions(int index, int64 value) { dimensions_.at(index) = value; } - void add_dimensions(int64 value) { dimensions_.push_back(value); } - void clear_dimensions() { dimensions_.clear(); } + void add_dimensions(int64 value) { + dimensions_.push_back(value); + dynamic_dimensions_.push_back(false); + } + void clear_dimensions() { + dimensions_.clear(); + dynamic_dimensions_.clear(); + } const std::vector& dimensions() const { return dimensions_; } - std::vector* mutable_dimensions() { return &dimensions_; } + absl::Span mutable_dimensions() { return absl::MakeSpan(dimensions_); } // Methods for accessing the tuple subshapes. This field only non-empty for // tuple shapes. @@ -76,21 +121,10 @@ class Shape { std::vector* mutable_tuple_shapes() { return &tuple_shapes_; } // Methods for accessing the layout field. - bool has_layout() const { return layout_.has_value(); } - const Layout& layout() const { - if (layout_.has_value()) { - return *layout_; - } else { - return Layout::default_instance(); - } - } - Layout* mutable_layout() { - if (!layout_.has_value()) { - layout_ = Layout(); - } - return &layout_.value(); - } - void clear_layout() { layout_.reset(); } + bool has_layout() const { return layout_.format() != INVALID_FORMAT; } + const Layout& layout() const { return layout_; } + Layout* mutable_layout() { return &layout_; } + void clear_layout() { layout_.Clear(); } void Swap(Shape* other) { using std::swap; @@ -101,25 +135,74 @@ class Shape { element_type_ = PRIMITIVE_TYPE_INVALID; dimensions_.clear(); tuple_shapes_.clear(); - layout_.reset(); + clear_layout(); } string SerializeAsString() const { return ToProto().SerializeAsString(); } string ShortDebugString() const { return ToProto().ShortDebugString(); } string DebugString() const { return ToProto().DebugString(); } - public: + // Equal is a 
configurable functor to check the equality of two shapes. + // + // Examples: + // + // - Comparing two shapes ignoring they layout difference: + // Equal().IgnoreLayout()(shape1, shape2); + // + // - Comparing two shapes ignoring they layout and element type difference: + // Equal().IgnoreLayout().IgnoreElementType()(shape1, shape2); + class Equal { + public: + Equal() = default; + + bool operator()(const Shape& lhs, const Shape& rhs); + + Equal& IgnoreLayout() { + ignore_layout_ = true; + return *this; + } + Equal& IgnoreElementType() { + ignore_element_type_ = true; + return *this; + } + Equal& IgnoreFpPrecision() { + ignore_fp_precision_ = true; + return *this; + } + Equal& IgnoreDynamicDimension() { + ignore_dynamic_dimension_ = true; + return *this; + } + + public: + bool ignore_layout_ = false; + bool ignore_element_type_ = false; + bool ignore_fp_precision_ = false; + bool ignore_dynamic_dimension_ = false; + }; + + // Test that all fields of the shape are the same, equivalent to Equal(). + bool operator==(const Shape& other) const { return Equal()(*this, other); } + bool operator!=(const Shape& other) const { return !(*this == other); } + + private: // The element type of this shape (tuple, array, etc). PrimitiveType element_type_ = PRIMITIVE_TYPE_INVALID; - // The array bounds of the dimensions. This is nonempty only for array shapes. + // The array bounds of the dimensions. This is nonempty only for array + // shapes. For a dynamically-sized dimension, the respective value in this + // vector is an inclusive upper limit of the array bound. std::vector dimensions_; + // This vector is the same size as 'dimensions_' and indicates whether the + // respective dimension is dynamically sized. + std::vector dynamic_dimensions_; + // The tuple element subshapes. This is nonempty only for tuple shapes. std::vector tuple_shapes_; - // The array layout of the shape. This is present only for array shapes. - absl::optional layout_; + // The layout of the shape. 
Only relevant for arrays. + Layout layout_; }; // Shape of the parameters and output of an XLA computation. This is analogous diff --git a/tensorflow/compiler/xla/shape_layout.cc b/tensorflow/compiler/xla/shape_layout.cc index d44db89d571891ecef554cd45c050017833982bb..a000886d60d06a4a598910c901accb6dfd0a8f1a 100644 --- a/tensorflow/compiler/xla/shape_layout.cc +++ b/tensorflow/compiler/xla/shape_layout.cc @@ -52,7 +52,7 @@ bool ShapeLayout::MatchesLayoutInShape(const Shape& shape) const { const Layout& ShapeLayout::layout() const { CHECK(LayoutIsSet()); - CHECK(!ShapeUtil::IsTuple(shape_)); + CHECK(!shape_.IsTuple()); return shape_.layout(); } @@ -61,15 +61,15 @@ void ShapeLayout::Clear() { LayoutUtil::ClearLayout(&shape_); } bool ShapeLayout::LayoutIsSet() const { return LayoutUtil::HasLayout(shape_); } void ShapeLayout::ResetLayout(const Layout& layout) { - CHECK(!ShapeUtil::IsTuple(shape_)); - CHECK(!ShapeUtil::IsOpaque(shape_)); + CHECK(!shape_.IsTuple()); + CHECK(!shape_.IsOpaque()); *shape_.mutable_layout() = layout; TF_CHECK_OK(ShapeUtil::ValidateShape(shape_)); } void ShapeLayout::ResetLayout(const Layout& layout, ShapeIndexView shape_index) { - CHECK(ShapeUtil::IsTuple(shape_)); + CHECK(shape_.IsTuple()); *ShapeUtil::GetMutableSubshape(&shape_, shape_index)->mutable_layout() = layout; TF_CHECK_OK(ShapeUtil::ValidateShape(shape_)); diff --git a/tensorflow/compiler/xla/shape_test.cc b/tensorflow/compiler/xla/shape_test.cc index e396897eeebc2e7bdc2dc49300c8906710608b05..526abafea5cc244418a4ec05db7da6203716b483 100644 --- a/tensorflow/compiler/xla/shape_test.cc +++ b/tensorflow/compiler/xla/shape_test.cc @@ -41,11 +41,13 @@ class ShapeTest : public ::testing::Test { ShapeUtil::MakeTupleShape({opaque_, scalar_, matrix_, matrix2_}); const Shape nested_tuple_ = ShapeUtil::MakeTupleShape({tuple_, matrix_, token_}); + const Shape dyanmic_matrix_ = + ShapeUtil::MakeShape(S32, {5, 2}, {true, false}); }; TEST_F(ShapeTest, ShapeToFromProto) { - for (const Shape& shape 
: - {opaque_, token_, scalar_, matrix_, matrix2_, tuple_, nested_tuple_}) { + for (const Shape& shape : {opaque_, token_, scalar_, matrix_, matrix2_, + tuple_, nested_tuple_, dyanmic_matrix_}) { Shape shape_copy(shape.ToProto()); EXPECT_TRUE(ShapeUtil::Equal(shape, shape_copy)) << shape << " != " << shape_copy; @@ -74,6 +76,65 @@ TEST_F(ShapeTest, ShapeToString) { nested_tuple_.ToString(/*print_layout=*/true)); } +TEST_F(ShapeTest, DynamicShapeToString) { + Shape array_shape = + ShapeUtil::MakeShape(F32, {23, 44, 55}, {true, false, true}); + EXPECT_EQ("f32[<=23,44,<=55]", array_shape.ToString()); + + array_shape.set_dynamic_dimension(2, false); + EXPECT_EQ("f32[<=23,44,55]", array_shape.ToString()); +} + +TEST_F(ShapeTest, EqualityTest) { + // Different layouts. + EXPECT_NE(ShapeUtil::MakeShapeWithLayout(F32, {23, 44}, {1, 0}), + ShapeUtil::MakeShapeWithLayout(F32, {23, 44}, {0, 1})); + + // Different dims. + EXPECT_NE(ShapeUtil::MakeShapeWithLayout(F32, {44, 23}, {1, 0}), + ShapeUtil::MakeShapeWithLayout(F32, {23, 44}, {1, 0})); + + // Different elements. + EXPECT_NE(ShapeUtil::MakeShapeWithLayout(S32, {44, 23}, {1, 0}), + ShapeUtil::MakeShapeWithLayout(F32, {23, 44}, {1, 0})); + + // Equal shapes. 
+ EXPECT_EQ(ShapeUtil::MakeShapeWithLayout(F32, {23, 44}, {1, 0}), + ShapeUtil::MakeShapeWithLayout(F32, {23, 44}, {1, 0})); +} + +TEST_F(ShapeTest, IsStatic) { + EXPECT_TRUE(opaque_.is_static()); + EXPECT_TRUE(token_.is_static()); + EXPECT_TRUE(matrix_.is_static()); + EXPECT_TRUE(tuple_.is_static()); + EXPECT_TRUE(nested_tuple_.is_static()); + + Shape dynamic_matrix = matrix_; + EXPECT_TRUE(dynamic_matrix.is_static()); + dynamic_matrix.set_dynamic_dimension(1, true); + EXPECT_FALSE(dynamic_matrix.is_static()); + + Shape dynamic_tuple = tuple_; + EXPECT_TRUE(dynamic_tuple.is_static()); + ShapeUtil::GetMutableSubshape(&dynamic_tuple, {2}) + ->set_dynamic_dimension(1, true); + EXPECT_FALSE(dynamic_tuple.is_static()); +} + +TEST_F(ShapeTest, IsDynamicDimension) { + Shape dynamic_matrix = matrix_; + dynamic_matrix.set_dynamic_dimension(1, true); + EXPECT_FALSE(dynamic_matrix.is_dynamic_dimension(0)); + EXPECT_TRUE(dynamic_matrix.is_dynamic_dimension(1)); + + Shape dynamic_tuple = tuple_; + EXPECT_TRUE(dynamic_tuple.is_static()); + ShapeUtil::GetMutableSubshape(&dynamic_tuple, {2}) + ->set_dynamic_dimension(1, true); + EXPECT_FALSE(dynamic_tuple.is_static()); +} + TEST_F(ShapeTest, ProgramShapeToFromProto) { ProgramShape program_shape; *program_shape.add_parameters() = ShapeUtil::MakeShape(F32, {1, 2, 3}); diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h index 7bf97729165bef98fabc29040e02203eee68a53c..089120179e2a77518eb5b18c11a35670b03e9b77 100644 --- a/tensorflow/compiler/xla/shape_tree.h +++ b/tensorflow/compiler/xla/shape_tree.h @@ -395,7 +395,7 @@ class ShapeTreeIterator template int64 ShapeTree::CountSubshapes(const Shape& shape) { int64 current_count = 1; - if (ShapeUtil::IsTuple(shape)) { + if (shape.IsTuple()) { int64 count = ShapeUtil::TupleElementCount(shape); for (int i = 0; i < count; ++i) { current_count += CountSubshapes(shape.tuple_shapes(i)); @@ -407,7 +407,7 @@ int64 ShapeTree::CountSubshapes(const Shape& shape) 
{ template void ShapeTree::InitChildren(const Shape& shape, const T& init_value, Node* node, Index* index) { - if (ShapeUtil::IsTuple(shape)) { + if (shape.IsTuple()) { const int64 size = ShapeUtil::TupleElementCount(shape); #ifndef NDEBUG index->children_count = size; @@ -443,7 +443,7 @@ void ShapeTree::InitChildren(const Shape& shape, const T& init_value, template void ShapeTree::InitChildren(const Shape& shape, Node* node, Index* index) { - if (ShapeUtil::IsTuple(shape)) { + if (shape.IsTuple()) { const int64 size = ShapeUtil::TupleElementCount(shape); #ifndef NDEBUG index->children_count = size; diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index b95fabf488291b0a7f393cb9f7f4a5dc9eb7c7eb..1ada4bc0362f86bc770d4adfcd4d4b0ff7379c77 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -81,73 +81,10 @@ bool ShapeIndexView::StartsWith(ShapeIndexView prefix) const { /* static */ bool ShapeUtil::IsArrayPrimitiveType( PrimitiveType primitive_type) { - return primitive_type != PRIMITIVE_TYPE_INVALID && primitive_type != TUPLE && - primitive_type != OPAQUE && primitive_type != TOKEN; + return primitive_util::IsArrayType(primitive_type); } namespace { - -// Recursive helper for comparing the equality of two shapes. Returns true if -// the shapes are the same. If compare_layouts is true, then layouts must also -// match. 
-bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts, - bool ignore_fp_precision) { - if ((ignore_fp_precision && - !ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) || - (!ignore_fp_precision && !ShapeUtil::SameElementType(lhs, rhs))) { - VLOG(3) << "CompareShapes: lhs element type != rhs element type"; - return false; - } - - if (ShapeUtil::IsTuple(lhs)) { - return absl::c_equal(lhs.tuple_shapes(), rhs.tuple_shapes(), - [=](const Shape& l, const Shape& r) { - return CompareShapes(l, r, compare_layouts, - ignore_fp_precision); - }); - } else if (!ShapeUtil::IsArray(lhs)) { - // Non-tuple, non-array tupes such as opaque and token types are trivially - // the same. - return true; - } - - if (compare_layouts) { - if (lhs.layout().format() != rhs.layout().format()) { - return false; - } - if (LayoutUtil::IsDenseArray(lhs)) { - if (!absl::c_equal(LayoutUtil::MinorToMajor(lhs), - LayoutUtil::MinorToMajor(rhs))) { - VLOG(3) << "CompareShapes: lhs layout != rhs layout"; - return false; - } - - const auto& lhs_tiles = lhs.layout().tiles(); - const auto& rhs_tiles = rhs.layout().tiles(); - if (lhs_tiles.size() != rhs_tiles.size()) { - return false; - } - for (int64 i = 0; i < lhs_tiles.size(); i++) { - if (!absl::c_equal(lhs_tiles[i].dimensions(), - rhs_tiles[i].dimensions())) { - return false; - } - } - - if (lhs.layout().element_size_in_bits() != - rhs.layout().element_size_in_bits()) { - return false; - } - } - } - - if (!ShapeUtil::SameDimensions(lhs, rhs)) { - VLOG(3) << "CompareShapes: lhs dimensions != rhs dimensions"; - return false; - } - return true; -} - // Constructs and returns the new shape with the given minor_to_major order in // its Layout. 
StatusOr MakeShapeWithLayoutInternal( @@ -164,9 +101,9 @@ StatusOr MakeShapeWithLayoutInternal( TF_ASSIGN_OR_RETURN(Shape shape, ShapeUtil::MakeValidatedShape(element_type, dimensions)); auto min2maj = shape.mutable_layout()->mutable_minor_to_major(); - min2maj->Clear(); + min2maj->clear(); for (int64 value : minor_to_major) { - min2maj->Add(value); + min2maj->push_back(value); } if (!shape.has_layout()) { return InvalidArgument("Shape has no layout."); @@ -174,12 +111,11 @@ StatusOr MakeShapeWithLayoutInternal( TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(shape)); return shape; } - } // namespace /* static */ bool ShapeUtil::Equal(const Shape& lhs, const Shape& rhs) { - bool equal = CompareShapes(lhs, rhs, /*compare_layouts=*/true, - /*ignore_fp_precision=*/false); + bool equal = Shape::Equal()(lhs, rhs); + if (!equal && VLOG_IS_ON(3)) { VLOG(3) << "ShapeUtil::Equal differ: lhs = " << lhs.ShortDebugString() << ", rhs = " << rhs.ShortDebugString(); @@ -190,8 +126,7 @@ StatusOr MakeShapeWithLayoutInternal( /* static */ bool ShapeUtil::EqualIgnoringFpPrecision(const Shape& lhs, const Shape& rhs) { - bool equal = CompareShapes(lhs, rhs, /*compare_layouts=*/true, - /*ignore_fp_precision=*/true); + bool equal = Shape::Equal().IgnoreFpPrecision()(lhs, rhs); if (!equal && VLOG_IS_ON(3)) { VLOG(3) << "ShapeUtil::EqualIgnoringFpPrecision differ: lhs = " << lhs.ShortDebugString() << ", rhs = " << rhs.ShortDebugString(); @@ -200,12 +135,6 @@ StatusOr MakeShapeWithLayoutInternal( return equal; } -/* static */ int64 ShapeUtil::Rank(const Shape& shape) { - CHECK(ShapeUtil::IsArray(shape)) - << "Non-arrays do not have a rank, shape: " << shape; - return shape.dimensions_size(); -} - /* static */ int64 ShapeUtil::TrueRank(const Shape& shape) { int64 accum = 0; for (int64 dimension : shape.dimensions()) { @@ -232,14 +161,32 @@ StatusOr MakeShapeWithLayoutInternal( return MakeValidatedShape(element_type, dimensions).ValueOrDie(); } +/* static */ Shape ShapeUtil::MakeShape( + 
PrimitiveType element_type, absl::Span dimensions, + const std::vector& dynamic_dimensions) { + return MakeValidatedShape(element_type, dimensions, dynamic_dimensions) + .ValueOrDie(); +} + /* static */ StatusOr ShapeUtil::MakeValidatedShape( PrimitiveType element_type, absl::Span dimensions) { - CHECK(IsArrayPrimitiveType(element_type)); + CHECK(IsArrayPrimitiveType(element_type)) << element_type; Shape result; TF_RETURN_IF_ERROR(PopulateShape(element_type, dimensions, &result)); return result; } +/* static */ StatusOr ShapeUtil::MakeValidatedShape( + PrimitiveType element_type, absl::Span dimensions, + const std::vector& dynamic_dimensions) { + TF_ASSIGN_OR_RETURN(Shape shape, + MakeValidatedShape(element_type, dimensions)); + for (int i = 0; i < dynamic_dimensions.size(); ++i) { + shape.set_dynamic_dimension(i, dynamic_dimensions[i]); + } + return shape; +} + /* static */ Shape ShapeUtil::MakeShapeWithLayout( PrimitiveType element_type, absl::Span dimensions, absl::Span minor_to_major) { @@ -319,7 +266,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( /* static */ void ShapeUtil::AppendMajorDimension(int bound, Shape* shape) { CHECK(LayoutUtil::IsDenseArray(*shape)); - shape->mutable_layout()->add_minor_to_major(Rank(*shape)); + shape->mutable_layout()->add_minor_to_major(shape->rank()); shape->add_dimensions(bound); TF_DCHECK_OK(ValidateShape(*shape)); } @@ -334,7 +281,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( } /* static */ bool ShapeUtil::ElementHasBitWidth(const Shape& shape, int bits) { - if (!IsArray(shape)) { + if (!shape.IsArray()) { return false; } return primitive_util::BitWidth(shape.element_type()) == bits; @@ -358,6 +305,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( case U32: case U64: case C64: + case C128: case TUPLE: case OPAQUE: case TOKEN: @@ -376,27 +324,24 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( return 
primitive_util::IsFloatingPointType(shape.element_type()); } -/* static */ bool ShapeUtil::IsArray(const Shape& shape) { - return IsArrayPrimitiveType(shape.element_type()); -} - /* static */ bool ShapeUtil::IsNestedTuple(const Shape& shape) { - return IsTuple(shape) && std::any_of(shape.tuple_shapes().begin(), - shape.tuple_shapes().end(), IsTuple); + return shape.IsTuple() && + absl::c_any_of(shape.tuple_shapes(), + [](const Shape& s) { return s.IsTuple(); }); } /* static */ bool ShapeUtil::IsEmptyTuple(const Shape& shape) { - return IsTuple(shape) && TupleElementCount(shape) == 0; + return shape.IsTuple() && TupleElementCount(shape) == 0; } /* static */ int64 ShapeUtil::TupleElementCount(const Shape& shape) { - CHECK(IsTuple(shape)) << HumanString(shape); + CHECK(shape.IsTuple()) << HumanString(shape); return shape.tuple_shapes_size(); } /* static */ const Shape& ShapeUtil::GetTupleElementShape(const Shape& shape, int64 index) { - CHECK(IsTuple(shape)); + CHECK(shape.IsTuple()); CHECK_GT(TupleElementCount(shape), index); TF_DCHECK_OK(ValidateShapeWithOptionalLayout(shape.tuple_shapes(index))); return shape.tuple_shapes(index); @@ -412,7 +357,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( /* static */ Shape ShapeUtil::SliceTuple(const Shape& tuple, int64 start, int64 limit) { TF_DCHECK_OK(ValidateShapeWithOptionalLayout(tuple)); - CHECK(IsTuple(tuple)); + CHECK(tuple.IsTuple()); CHECK_LE(start, TupleElementCount(tuple)); CHECK_LE(limit, TupleElementCount(tuple)); @@ -429,15 +374,9 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( complex_shape.element_type())); } -/* static */ bool ShapeUtil::ShapeIs(const Shape& shape, - PrimitiveType element_type, - std::initializer_list dimensions) { - return Equal(shape, MakeShape(element_type, dimensions)); -} - /* static */ int64 ShapeUtil::ElementsIn(const Shape& shape) { - DCHECK(IsArray(shape)) << ShapeUtil::HumanString(shape); - DCHECK_EQ(shape.dimensions_size(), Rank(shape)); + 
DCHECK(shape.IsArray()) << ShapeUtil::HumanString(shape); + DCHECK_EQ(shape.dimensions_size(), shape.rank()); if (shape.dimensions().size() == 1) { return shape.dimensions()[0]; } @@ -447,8 +386,8 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( } /* static */ int64 ShapeUtil::ElementsInRecursive(const Shape& shape) { - CHECK(IsArray(shape) || IsTuple(shape)); - if (IsArray(shape)) { + CHECK(shape.IsArray() || shape.IsTuple()); + if (shape.IsArray()) { return ElementsIn(shape); } int64 count = 0; @@ -472,7 +411,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( } /* static */ bool ShapeUtil::IsZeroElementArray(const Shape& shape) { - return ShapeUtil::IsArray(shape) && ElementsIn(shape) == 0; + return shape.IsArray() && ElementsIn(shape) == 0; } /* static */ bool ShapeUtil::IsScalarWithElementType( @@ -480,56 +419,8 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( return IsScalar(shape) && shape.element_type() == element_type; } -namespace { - -// Class to memoize the computation of -// absl::AsciiStrToLower(PrimitiveType_Name(p)) -// for all PrimitiveType values "p" -class PrimitiveTypeNameGenerator { - public: - PrimitiveTypeNameGenerator() { - for (int i = 0; i < PrimitiveType_ARRAYSIZE; i++) { - if (PrimitiveType_IsValid(i)) { - lowercase_name_[i] = absl::AsciiStrToLower( - PrimitiveType_Name(static_cast(i))); - } - } - } - const string& LowercaseName(PrimitiveType t) { - return lowercase_name_[static_cast(t)]; - } - - private: - string lowercase_name_[PrimitiveType_ARRAYSIZE]; -}; - -const string& LowercasePrimitiveTypeName(PrimitiveType s) { - static PrimitiveTypeNameGenerator* gen = new PrimitiveTypeNameGenerator(); - return gen->LowercaseName(s); -} - -StatusOr StringToPrimitiveType(const string& name) { - static std::unordered_map* name_to_type = [] { - static auto* map = new std::unordered_map; - for (int i = 0; i < PrimitiveType_ARRAYSIZE; i++) { - if (PrimitiveType_IsValid(i)) { - auto value = 
static_cast(i); - (*map)[LowercasePrimitiveTypeName(value)] = value; - } - } - return map; - }(); - auto found = name_to_type->find(name); - if (found == name_to_type->end()) { - return InvalidArgument("Invalid element type string: \"%s\".", name); - } - return found->second; -} - -} // namespace - /* static */ string ShapeUtil::HumanString(const Shape& shape) { - if (IsTuple(shape)) { + if (shape.IsTuple()) { string text = "("; const char* prefix = ""; for (const Shape& elem_shape : shape.tuple_shapes()) { @@ -539,12 +430,21 @@ StatusOr StringToPrimitiveType(const string& name) { text += ")"; return text; } - return StrCat(LowercasePrimitiveTypeName(shape.element_type()), "[", - absl::StrJoin(shape.dimensions(), ","), "]"); + std::vector dim_elements; + for (int i = 0; i < shape.dimensions_size(); ++i) { + if (shape.is_dynamic_dimension(i)) { + dim_elements.push_back(StrCat("<=", shape.dimensions(i))); + } else { + dim_elements.push_back(StrCat(shape.dimensions(i))); + } + } + return StrCat( + primitive_util::LowercasePrimitiveTypeName(shape.element_type()), "[", + absl::StrJoin(dim_elements, ","), "]"); } /* static */ string ShapeUtil::HumanStringWithLayout(const Shape& shape) { - if (IsTuple(shape)) { + if (shape.IsTuple()) { string text = "("; const char* prefix = ""; for (const Shape& elem_shape : shape.tuple_shapes()) { @@ -554,12 +454,14 @@ StatusOr StringToPrimitiveType(const string& name) { text += ")"; return text; } - string result = StrCat(LowercasePrimitiveTypeName(shape.element_type()), "["); + string result = StrCat( + primitive_util::LowercasePrimitiveTypeName(shape.element_type()), "["); for (int i = 0; i < shape.dimensions().size(); i++) { - StrAppend(&result, (i > 0) ? "," : "", shape.dimensions(i)); + StrAppend(&result, (i > 0) ? "," : "", + shape.is_dynamic_dimension(i) ? 
"<=" : "", shape.dimensions(i)); } result += "]"; - if (!IsScalar(shape) && IsArray(shape)) { + if (!IsScalar(shape) && shape.IsArray()) { if (LayoutUtil::HasLayout(shape)) { StrAppend(&result, LayoutUtil::HumanString(shape.layout())); } @@ -580,155 +482,25 @@ StatusOr StringToPrimitiveType(const string& name) { HumanString(program_shape.result())); } -namespace { -// Parses shapes with simple recursive descent structure -- consumes from the -// front of s and passes that view recursively as required. -StatusOr ParseShapeStringInternal(absl::string_view* s) { - *s = StripLeadingAsciiWhitespace(*s); - - if (absl::ConsumePrefix(s, "(")) { // Tuple. - std::vector shapes; - bool must_end = false; - while (true) { - if (absl::ConsumePrefix(s, ")")) { - break; - } else if (must_end) { - return InvalidArgument("Expected end of tuple; got: \"%s\"", *s); - } - shapes.emplace_back(); - TF_ASSIGN_OR_RETURN(shapes.back(), ParseShapeStringInternal(s)); - *s = StripLeadingAsciiWhitespace(*s); - must_end = !absl::ConsumePrefix(s, ","); - } - return ShapeUtil::MakeTupleShape(shapes); - } - - string element_type_string; - string dimensions_string; - string format_string; - string layout_string; - // absl::string_view is not compatible with internal RE2 StringPiece, so - // we convert in to the RE2-consumable type and then consume the corresponding - // amount from our string_view type. 
- static LazyRE2 shape_pattern = { - "^(\\w*\\d*)\\[([\\d,\\s]*)\\](?:\\s*(dense|sparse)?\\s*{([\\d,\\s]+)})" - "?"}; - tensorflow::RegexpStringPiece s_consumable(s->data(), s->size()); - if (RE2::Consume(&s_consumable, *shape_pattern, &element_type_string, - &dimensions_string, &format_string, &layout_string)) { - size_t consumed = s->size() - s_consumable.size(); - s->remove_prefix(consumed); - auto string_to_int64 = [&s](absl::string_view input) -> StatusOr { - int64 element; - if (!absl::SimpleAtoi(input, &element)) { - return InvalidArgument( - "Invalid s64 value in parsed shape string: \"%s\" in \"%s\"", input, - *s); - } - return element; - }; - - auto comma_list_to_int64s = - [string_to_int64](const string& input) -> StatusOr> { - std::vector results; - for (const auto& piece : absl::StrSplit(input, ',', absl::SkipEmpty())) { - TF_ASSIGN_OR_RETURN(int64 element, string_to_int64(piece)); - results.push_back(element); - } - return results; - }; - - // Extract the dimensions. - TF_ASSIGN_OR_RETURN(std::vector dimensions, - comma_list_to_int64s(dimensions_string)); - - // Extract the primitive element type. - TF_ASSIGN_OR_RETURN(const PrimitiveType primitive_type, - StringToPrimitiveType(element_type_string)); - if (primitive_type == PRIMITIVE_TYPE_INVALID || primitive_type == TUPLE) { - return InvalidArgument("Invalid element type string: \"%s\".", - element_type_string); - } - - Shape result; - if (primitive_type == OPAQUE) { - result = ShapeUtil::MakeOpaqueShape(); - } else if (primitive_type == TOKEN) { - result = ShapeUtil::MakeTokenShape(); - } else if (format_string.empty() && layout_string.empty()) { - // Create a shape without a layout set. 
- TF_ASSIGN_OR_RETURN( - result, ShapeUtil::MakeValidatedShape(primitive_type, dimensions)); - } else if (format_string == "sparse") { - TF_ASSIGN_OR_RETURN(int64 max_elements, string_to_int64(layout_string)); - result = ShapeUtil::MakeShapeWithSparseLayout(primitive_type, dimensions, - max_elements); - } else if (format_string.empty() || format_string == "dense") { - // Extract the layout minor-to-major and set it. - TF_ASSIGN_OR_RETURN(std::vector min2maj, - comma_list_to_int64s(layout_string)); - TF_ASSIGN_OR_RETURN(result, MakeShapeWithLayoutInternal( - primitive_type, dimensions, min2maj)); - } else { - // This should not be reached. - LOG(FATAL) << "Unhandled condition when parsing shape; format: \"" - << format_string << "\", layout: \"" << layout_string << "\""; - } - TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(result)); - return std::move(result); - } - - return InvalidArgument("Invalid shape string to parse: \"%s\"", *s); -} -} // namespace - -/* static */ StatusOr ShapeUtil::ParseShapeString(absl::string_view s) { - TF_ASSIGN_OR_RETURN(Shape shape, ParseShapeStringInternal(&s)); - if (!s.empty()) { - return InvalidArgument("Invalid shape string to parse: \"%s\"", s); - } - return shape; -} - /* static */ bool ShapeUtil::SameDimensions(const Shape& lhs, const Shape& rhs) { - CHECK(ShapeUtil::IsArray(lhs)); - CHECK(ShapeUtil::IsArray(rhs)); + CHECK(lhs.IsArray()); + CHECK(rhs.IsArray()); return absl::c_equal(lhs.dimensions(), rhs.dimensions()); } /* static */ bool ShapeUtil::Compatible(const Shape& lhs, const Shape& rhs) { - return CompareShapes(lhs, rhs, /*compare_layouts=*/false, - /*ignore_fp_precision=*/false); + return Shape::Equal().IgnoreLayout()(lhs, rhs); } /* static */ bool ShapeUtil::CompatibleIgnoringElementType(const Shape& lhs, const Shape& rhs) { - if (IsArray(lhs)) { - return IsArray(rhs) && SameDimensions(lhs, rhs); - } else if (lhs.element_type() == TUPLE) { - return rhs.element_type() == TUPLE && - absl::c_equal(lhs.tuple_shapes(), 
rhs.tuple_shapes(), - CompatibleIgnoringElementType); - } else { - // Opaque, token, etc types are vacuously compatible. - return lhs.element_type() == rhs.element_type(); - } + return Shape::Equal().IgnoreElementType().IgnoreLayout()(lhs, rhs); } /* static */ bool ShapeUtil::CompatibleIgnoringFpPrecision(const Shape& lhs, const Shape& rhs) { - if (IsArray(lhs)) { - return IsArray(rhs) && SameElementTypeIgnoringFpPrecision(lhs, rhs) && - CompatibleIgnoringElementType(lhs, rhs); - } else if (lhs.element_type() == TUPLE) { - return rhs.element_type() == TUPLE && - absl::c_equal(lhs.tuple_shapes(), rhs.tuple_shapes(), - CompatibleIgnoringFpPrecision); - } else { - // Opaque, token, etc types are vacuously compatible. - return lhs.element_type() == rhs.element_type(); - } + return Shape::Equal().IgnoreFpPrecision().IgnoreLayout()(lhs, rhs); } /* static */ int64 ShapeUtil::GetDimension(const Shape& shape, @@ -739,7 +511,7 @@ StatusOr ParseShapeStringInternal(absl::string_view* s) { /* static */ int64 ShapeUtil::GetDimensionNumber(const Shape& shape, int64 dimension_number) { if (dimension_number < 0) { - dimension_number += Rank(shape); + dimension_number += shape.rank(); } CHECK_GE(dimension_number, 0); return dimension_number; @@ -776,6 +548,8 @@ StatusOr ParseShapeStringInternal(absl::string_view* s) { return sizeof(double); case C64: return sizeof(complex64); + case C128: + return sizeof(complex128); case TOKEN: // Tokens require no space. 
return 0; @@ -793,7 +567,7 @@ StatusOr ParseShapeStringInternal(absl::string_view* s) { TF_DCHECK_OK(ValidateShape(shape)); if (shape.element_type() == TUPLE) { return ByteSizeOfTupleIndexTable(shape, pointer_size); - } else if (IsArray(shape)) { + } else if (shape.IsArray()) { int64 byte_size = ByteSizeOfElements(shape); if (LayoutUtil::IsSparseArray(shape)) { byte_size += ByteSizeOfSparseIndices(shape); @@ -819,7 +593,7 @@ StatusOr ParseShapeStringInternal(absl::string_view* s) { /* static */ int64 ShapeUtil::ByteSizeOfElements(const Shape& shape) { TF_DCHECK_OK(ValidateShape(shape)); - CHECK(ShapeUtil::IsArray(shape)); + CHECK(shape.IsArray()); int64 allocated_element_count; if (LayoutUtil::IsSparseArray(shape)) { @@ -835,8 +609,8 @@ StatusOr ParseShapeStringInternal(absl::string_view* s) { /* static */ int64 ShapeUtil::ByteSizeOfSparseIndices(const Shape& shape) { TF_DCHECK_OK(ValidateShape(shape)); CHECK(LayoutUtil::IsSparseArray(shape)); - return LayoutUtil::MaxSparseElements(shape.layout()) * - ShapeUtil::Rank(shape) * sizeof(int64); + return LayoutUtil::MaxSparseElements(shape.layout()) * shape.rank() * + sizeof(int64); } /* static */ Status ShapeUtil::ValidateShapeWithOptionalLayoutInternal( @@ -867,22 +641,22 @@ StatusOr ParseShapeStringInternal(absl::string_view* s) { if (shape.dimensions_size() != 0) { return InvalidArgument( "shape has %s element type, but has dimensions field: %s", - LowercasePrimitiveTypeName(shape.element_type()), + primitive_util::LowercasePrimitiveTypeName(shape.element_type()), shape.ShortDebugString()); } if (shape.has_layout()) { return InvalidArgument( "shape has %s element type, but has layout field: %s", - LowercasePrimitiveTypeName(shape.element_type()), + primitive_util::LowercasePrimitiveTypeName(shape.element_type()), shape.ShortDebugString()); } return Status::OK(); } - if (LayoutUtil::IsSparseArray(shape) && Rank(shape) == 0) { + if (LayoutUtil::IsSparseArray(shape) && shape.rank() == 0) { return 
InvalidArgument("sparse arrays must have rank > 0"); } - for (int64 i = 0; i < Rank(shape); ++i) { + for (int64 i = 0; i < shape.rank(); ++i) { int64 dimension = shape.dimensions(i); if (dimension < 0) { return InvalidArgument( @@ -898,7 +672,7 @@ StatusOr ParseShapeStringInternal(absl::string_view* s) { /* static */ Status ShapeUtil::ValidateShapeSize(const Shape& shape) { VLOG(3) << "Validating shape size: " << ShapeUtil::HumanString(shape); - if (!IsArray(shape)) { + if (!shape.IsArray()) { return Status::OK(); } @@ -919,7 +693,7 @@ StatusOr ParseShapeStringInternal(absl::string_view* s) { return sparse_elements_size; } int64 sparse_indices_size = - MultiplyWithoutOverflow(max_sparse_elements, ShapeUtil::Rank(shape)); + MultiplyWithoutOverflow(max_sparse_elements, shape.rank()); if (sparse_indices_size < 0) { return sparse_indices_size; } @@ -991,7 +765,7 @@ StatusOr ParseShapeStringInternal(absl::string_view* s) { ShapeIndexView index) { const Shape* subshape = &shape; for (auto i : index) { - if (!IsTuple(*subshape) || i >= subshape->tuple_shapes_size() || i < 0) { + if (!subshape->IsTuple() || i >= subshape->tuple_shapes_size() || i < 0) { return false; } subshape = &subshape->tuple_shapes(i); @@ -1003,7 +777,7 @@ StatusOr ParseShapeStringInternal(absl::string_view* s) { ShapeIndexView index) { const Shape* return_shape = &shape; for (auto i : index) { - CHECK(IsTuple(*return_shape)) + CHECK(return_shape->IsTuple()) << "Invalid index " << index << " for shape " << shape; return_shape = &return_shape->tuple_shapes(i); } @@ -1014,7 +788,7 @@ StatusOr ParseShapeStringInternal(absl::string_view* s) { const Shape& shape, ShapeIndexView index) { const Shape* return_shape = &shape; for (auto i : index) { - if (!IsTuple(*return_shape) || i < 0 || + if (!return_shape->IsTuple() || i < 0 || i >= return_shape->tuple_shapes_size()) { return InvalidArgument( "Shape index %s not a valid subshape index for tuple with shape %s", @@ -1029,7 +803,7 @@ StatusOr 
ParseShapeStringInternal(absl::string_view* s) { ShapeIndexView index) { Shape* return_shape = shape; for (auto i : index) { - CHECK(IsTuple(*return_shape)); + CHECK(return_shape->IsTuple()); return_shape = return_shape->mutable_tuple_shapes(i); } return return_shape; @@ -1037,11 +811,11 @@ StatusOr ParseShapeStringInternal(absl::string_view* s) { /* static */ bool ShapeUtil::IsLeafIndex(const Shape& shape, const ShapeIndex& index) { - return !IsTuple(GetSubshape(shape, index)); + return !GetSubshape(shape, index).IsTuple(); } /* static */ int64 ShapeUtil::GetLeafCount(const Shape& shape) { - if (!IsTuple(shape)) { + if (!shape.IsTuple()) { return 1; } int64 count = 0; @@ -1063,10 +837,15 @@ bool ShapeUtil::IsLeafIndex(const Shape& shape, const ShapeIndex& index) { } /* static */ bool ShapeUtil::HasDegenerateDimensions(const Shape& shape) { - CHECK(ShapeUtil::IsArray(shape)); + CHECK(shape.IsArray()); return absl::c_linear_search(shape.dimensions(), 1); } +/* static */ Shape ShapeUtil::DropDegenerateDimensions(const Shape& shape) { + return FilterDimensions( + [&](int64 dim) -> bool { return shape.dimensions()[dim] != 1; }, shape); +} + namespace { // Helper for ForEachSubshape which visits the subshapes of the given shape in @@ -1075,7 +854,7 @@ Status ForEachSubshapeHelper(const Shape& shape, const ShapeUtil::StatusVisitorFunction& func, ShapeIndex* index) { TF_RETURN_IF_ERROR(func(shape, *index)); - if (ShapeUtil::IsTuple(shape)) { + if (shape.IsTuple()) { for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { index->push_back(i); TF_RETURN_IF_ERROR(ForEachSubshapeHelper( @@ -1092,7 +871,7 @@ Status ForEachMutableSubshapeHelper( Shape* shape, const ShapeUtil::MutatingStatusVisitorFunction& func, ShapeIndex* index) { TF_RETURN_IF_ERROR(func(shape, *index)); - if (ShapeUtil::IsTuple(*shape)) { + if (shape->IsTuple()) { for (int64 i = 0; i < ShapeUtil::TupleElementCount(*shape); ++i) { index->push_back(i); TF_RETURN_IF_ERROR(ForEachMutableSubshapeHelper( 
@@ -1150,6 +929,10 @@ Status ForEachMutableSubshapeHelper( for (auto dim : Permute(permutation, shape.dimensions())) { new_shape.add_dimensions(dim); } + for (int64 i = 0; i < shape.rank(); i++) { + new_shape.set_dynamic_dimension(permutation[i], + shape.is_dynamic_dimension(i)); + } // If `shape` has a layout, by contract we choose a new layout such that the // transpose defined by this permutation is a bitcast. @@ -1168,7 +951,7 @@ Status ForEachMutableSubshapeHelper( // Let the argument `permutation` be P. This is a permutation over `shape`'s // dimensions, so our return value will be a shape with dims P.I = P. Our // goal is to construct a layout permutation L* that we can apply to P such - // that that the physical dimension ordering of the returned shape is the same + // that the physical dimension ordering of the returned shape is the same // as that of the original shape, namely L'. // // Our returned shape has dims P and layout L*, so its in-memory layout is @@ -1200,8 +983,8 @@ Status ForEachMutableSubshapeHelper( /* static */ std::tuple, std::vector> ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre, const Shape& shape_post) { - CHECK(IsArray(shape_pre)); - CHECK(IsArray(shape_post)); + CHECK(shape_pre.IsArray()); + CHECK(shape_post.IsArray()); auto nil = std::make_tuple(false, std::vector(), std::vector()); @@ -1248,7 +1031,7 @@ ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre, auto unmodified_dim_pair = i < unmodified_dims.size() ? 
unmodified_dims[i] - : std::make_pair(Rank(shape_pre), Rank(shape_post)); + : std::make_pair(shape_pre.rank(), shape_post.rank()); if (!check_modified_dims(prior_unmodified_dim_pair, unmodified_dim_pair)) { return nil; } @@ -1260,8 +1043,8 @@ ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre, /* static */ std::vector> ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, const Shape& output_shape) { - CHECK(IsArray(input_shape)); - CHECK(IsArray(output_shape)); + CHECK(input_shape.IsArray()); + CHECK(output_shape.IsArray()); // Unmodified dimensions are merely common factors of rank 1. auto common_factors = CommonFactors(AsInt64Slice(input_shape.dimensions()), @@ -1311,8 +1094,8 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, /* static */ bool ShapeUtil::ReshapeIsBitcast(const Shape& input_shape, const Shape& output_shape) { - CHECK(IsArray(input_shape)); - CHECK(IsArray(output_shape)); + CHECK(input_shape.IsArray()); + CHECK(output_shape.IsArray()); CHECK(LayoutUtil::HasLayout(input_shape)); CHECK(LayoutUtil::HasLayout(output_shape)); @@ -1440,12 +1223,12 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, Shape output_shape_dim0_major = MakeShapeWithDescendingLayout( output_shape.element_type(), AsInt64Slice(output_shape.dimensions())); - for (int64 input_dim = 0; input_dim < Rank(input_shape); ++input_dim) { + for (int64 input_dim = 0; input_dim < input_shape.rank(); ++input_dim) { if (input_shape.dimensions(input_dim) <= 1) { continue; } - std::vector input_unit_index(Rank(input_shape), 0); + std::vector input_unit_index(input_shape.rank(), 0); input_unit_index[input_dim] = 1; int64 logical_linear_index = IndexUtil::MultidimensionalIndexToLinearIndex(input_shape_dim0_major, @@ -1471,11 +1254,11 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, /* static */ absl::optional ShapeUtil::AlignLayouts( const Shape& input_shape, const Shape& output_shape) { - 
CHECK(IsArray(input_shape)); - CHECK(IsArray(output_shape)); + CHECK(input_shape.IsArray()); + CHECK(output_shape.IsArray()); - int64 input_rank = Rank(input_shape); - int64 output_rank = Rank(output_shape); + int64 input_rank = input_shape.rank(); + int64 output_rank = output_shape.rank(); // First, calculate an alignment of the dimensions. A consecutive sequence of // input dimensions and output dimensions belong to the same alignment part if @@ -1612,30 +1395,14 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, /* static */ Shape ShapeUtil::DeleteDimension(int64 dim_to_delete, Shape shape) { - CHECK(IsArray(shape)); - shape.mutable_dimensions()->erase(shape.mutable_dimensions()->begin() + - dim_to_delete); - if (LayoutUtil::HasLayout(shape)) { - Layout* layout = shape.mutable_layout(); - layout->set_format(DENSE); - for (size_t i = 0; i < layout->minor_to_major().size();) { - if (layout->minor_to_major(i) == dim_to_delete) { - layout->mutable_minor_to_major()->erase( - layout->minor_to_major().begin() + i); - continue; - } - if (layout->minor_to_major(i) > dim_to_delete) { - (*layout->mutable_minor_to_major())[i] -= 1; - } - ++i; - } - } + CHECK(shape.IsArray()); + shape.DeleteDimension(dim_to_delete); return shape; } /* static */ Shape ShapeUtil::FilterDimensions( const std::function& p, Shape shape) { - CHECK(IsArray(shape)); + CHECK(shape.IsArray()); std::vector dims_to_delete; for (int64 i = shape.dimensions().size() - 1; i >= 0; --i) { if (!p(i)) { @@ -1655,8 +1422,11 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, size_t hash_value = hash()(shape.element_type()); if (shape.tuple_shapes().empty()) { - for (int64 dim : shape.dimensions()) { - hash_value = Hash64Combine(hash_value, hash()(dim)); + for (int i = 0; i < shape.dimensions_size(); ++i) { + hash_value = + Hash64Combine(hash_value, hash()(shape.dimensions(i))); + hash_value = Hash64Combine(hash_value, + hash()(shape.is_dynamic_dimension(i))); } hash_value = 
Hash64Combine(hash_value, LayoutUtil::Hash(shape.layout())); diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 84a27f662a57ba274562e2e9be57b7e971c9b477..fb6da7460e2475732d6f02758e5519fbdb7c0f8d 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -185,7 +185,7 @@ class ShapeUtil { // may not actually be able to store this number of elements. See // LayoutUtil::MaxSparseElements(shape) to obtain the maximum number of // elements that can be stored in a sparse shape. - // Precondition: IsArray(shape) + // Precondition: shape.IsArray() static int64 ElementsIn(const Shape& shape); // As ElementsIn(), but recurses through tuples. @@ -207,7 +207,7 @@ class ShapeUtil { // Returns the number of bytes used to store the primitive_type. // - // Precondition: ShapeUtil::IsArray(shape) + // Precondition: shape.IsArray() static int64 ByteSizeOfPrimitiveType(PrimitiveType primitive_type); // Returns the number of bytes required to store the tuple member pointers for @@ -241,10 +241,6 @@ class ShapeUtil { // (param_name: f32[42x12], ...) -> f32[24x42] static string HumanString(const ProgramShape& program_shape); - // Parses a ShapeUtil::HumanString-format shape string back into a shape - // object. - static StatusOr ParseShapeString(absl::string_view s); - // Returns whether the LHS and RHS shapes have the same dimensions; note: does // not check element type. // Precondition: IsArray(lhs) && IsArray(rhs) @@ -266,7 +262,7 @@ class ShapeUtil { } // Returns the higher-precision element type if a and b are both floating - // point types; otherwise, checks that that they have the same element type + // point types; otherwise, checks that they have the same element type // and returns it. static PrimitiveType HigherPrecisionElementType(const Shape& a, const Shape& b) { @@ -294,16 +290,12 @@ class ShapeUtil { // being F32. Tuple elements are compared recursively for compatibility. 
static bool CompatibleIgnoringFpPrecision(const Shape& lhs, const Shape& rhs); - // Returns whether the lhs and rhs shapes are identical protobufs. + // Returns whether the lhs and rhs shapes are identical. static bool Equal(const Shape& lhs, const Shape& rhs); // As Equal, but allow one of lhs and rhs to be F16 while the other is F32. static bool EqualIgnoringFpPrecision(const Shape& lhs, const Shape& rhs); - // Returns the rank (number of dimensions) of the given shape. - // Precondition: !IsTuple(shape) - static int64 Rank(const Shape& shape); - // Returns the number of dimensions for which the dimension is not (trivially) // 1. e.g., f32[2x1x1] has a true rank of 1D, the other dimensions are just // fluff. Note that zero dimensions are included in the true rank, e.g., @@ -317,10 +309,10 @@ class ShapeUtil { // Scalar-specific static bool IsScalar(const Shape& shape) { - return IsArray(shape) && Rank(shape) == 0; + return shape.IsArray() && shape.rank() == 0; } static bool IsEffectiveScalar(const Shape& shape) { - return IsArray(shape) && TrueRank(shape) == 0; + return shape.IsArray() && TrueRank(shape) == 0; } // Returns whether "shape" is a scalar (array) with the given element_type. @@ -375,11 +367,24 @@ class ShapeUtil { static Shape MakeShape(PrimitiveType element_type, absl::Span dimensions); + // Constructs a new shape with the given element type and sequence of + // potentially dynamic dimensions. The argument 'dynamic_dimensions' indicates + // with a true value that the respective dimension is dynamic. If the + // dimension is dynamic then the respective value in 'dimension' is an upper + // bound on the dimension size. 'dimensions' and 'dynamic_dimensions' must be + // the same size. + static Shape MakeShape(PrimitiveType element_type, + absl::Span dimensions, + const std::vector& dynamic_dimensions); + // Constructs a new shape with the given element type and sequence of // dimensions. 
Method checks if the element type is valid and the shape's // size fits in std::numeric_limits::max(). static StatusOr MakeValidatedShape(PrimitiveType element_type, absl::Span dimensions); + static StatusOr MakeValidatedShape( + PrimitiveType element_type, absl::Span dimensions, + const std::vector& dynamic_dimensions); // Creates a Shape with element type corresponding to T and the given // dimensions @@ -447,27 +452,6 @@ class ShapeUtil { // that floating point numbers are signed. static bool ElementIsSigned(const Shape& shape); - // Returns whether the shape is a tuple. - static bool IsTuple(const Shape& shape) { - return shape.element_type() == TUPLE; - } - - // Returns whether the shape is an opaque value (i.e. an 'existential' typed - // value that is passed to CustomCall operations). - static bool IsOpaque(const Shape& shape) { - return shape.element_type() == OPAQUE; - } - - // Returns whether the shape is an token value used for ordering - // side-effecting operations. - static bool IsToken(const Shape& shape) { - return shape.element_type() == TOKEN; - } - - // Returns whether the shape is an array. Note that scalars are considered - // arrays. - static bool IsArray(const Shape& shape); - // Returns whether the given primitive type corresponds to an array shape. static bool IsArrayPrimitiveType(PrimitiveType primitive_type); @@ -497,12 +481,6 @@ class ShapeUtil { // shape. static Shape ComplexComponentShape(const Shape& complex_shape); - // Shorthand for testing whether a shape is of a given element type and - // sequence of dimensions. - ABSL_DEPRECATED("Use Equal() instead.") - static bool ShapeIs(const Shape& shape, PrimitiveType element_type, - std::initializer_list dimensions); - // Returns true if the given shape has a subshape at the given index. static bool IndexIsValid(const Shape& shape, ShapeIndexView index); @@ -551,6 +529,9 @@ class ShapeUtil { // (dimensions with bound 1). 
static bool HasDegenerateDimensions(const Shape& shape); + // Drops any degenerate dimensions (i.e. dimensions of size 1) + static Shape DropDegenerateDimensions(const Shape& shape); + // Permutes the dimensions by the given permutation, so // return_value.dimensions[permutation[i]] = argument.dimensions[i]. // @@ -694,11 +675,9 @@ class ShapeUtil { template static void ForEachIndex(const Shape& shape, const FnType& visitor_function) { - ForEachIndexWithStatus(shape, - [&](absl::Span indices) { - return StatusOr(visitor_function(indices)); - }) - .IgnoreError(); + ForEachIndexWithStatus(shape, [&](absl::Span indices) { + return StatusOr(visitor_function(indices)); + }).IgnoreError(); } // A parallel version of ForEachIndex(WithStatus). This can only be used if @@ -747,7 +726,7 @@ class ShapeUtil { if (ShapeUtil::IsZeroElementArray(shape)) { return Status::OK(); } - CHECK_EQ(Rank(shape), base.size()); + CHECK_EQ(shape.rank(), base.size()); CHECK_EQ(incr.size(), base.size()); CHECK_EQ(count.size(), base.size()); const int64 rank = LayoutUtil::MinorToMajor(shape).size(); diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index 60bdbe302045e6f3b4bae500c50bc68fb217525d..126ae58293d12182e9b6e30f779f681829729526 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -82,102 +82,6 @@ TEST(ShapeUtilTest, Rank4DimensionIndexing) { ASSERT_EQ(3, shape.dimensions(0)); } -TEST(ShapeUtilTest, ParseShapeStringR2F32) { - string shape_string = "f32[123,456]"; - TF_ASSERT_OK_AND_ASSIGN(Shape actual, - ShapeUtil::ParseShapeString(shape_string)); - Shape expected = ShapeUtil::MakeShape(F32, {123, 456}); - ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) - << "expected: " << ShapeUtil::HumanString(expected) - << "actual: " << ShapeUtil::HumanString(actual); -} - -TEST(ShapeUtilTest, ParseShapeStringTupleOfArrays) { - string shape_string = "(f32[1572864],s8[5120,1024])"; - 
TF_ASSERT_OK_AND_ASSIGN(Shape actual, - ShapeUtil::ParseShapeString(shape_string)); - Shape expected = - ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {1572864}), - ShapeUtil::MakeShape(S8, {5120, 1024})}); - ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) - << "expected: " << ShapeUtil::HumanString(expected) - << "actual: " << ShapeUtil::HumanString(actual); -} - -TEST(ShapeUtilTest, ParseShapeStringNestedTuple) { - string shape_string = "(f32[1],(f32[2], token[]), opaque[], f32[3])"; - TF_ASSERT_OK_AND_ASSIGN(Shape actual, - ShapeUtil::ParseShapeString(shape_string)); - Shape expected = ShapeUtil::MakeTupleShape({ - ShapeUtil::MakeShape(F32, {1}), - ShapeUtil::MakeTupleShape( - {ShapeUtil::MakeShape(F32, {2}), ShapeUtil::MakeTokenShape()}), - ShapeUtil::MakeOpaqueShape(), - ShapeUtil::MakeShape(F32, {3}), - }); - ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) - << "expected: " << ShapeUtil::HumanString(expected) - << "actual: " << ShapeUtil::HumanString(actual); -} - -TEST(ShapeUtilTest, ParseShapeStringWithLayout) { - string shape_string = "f32[123,456]{0,1}"; - TF_ASSERT_OK_AND_ASSIGN(Shape actual, - ShapeUtil::ParseShapeString(shape_string)); - Shape expected = ShapeUtil::MakeShapeWithLayout(F32, {123, 456}, {0, 1}); - ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) - << "expected: " << ShapeUtil::HumanString(expected) - << "actual: " << ShapeUtil::HumanString(actual); -} - -TEST(ShapeUtilTest, ParseShapeStringWithExplicitDenseLayout) { - string shape_string = "f32[123,456]dense{0,1}"; - TF_ASSERT_OK_AND_ASSIGN(Shape actual, - ShapeUtil::ParseShapeString(shape_string)); - Shape expected = ShapeUtil::MakeShapeWithLayout(F32, {123, 456}, {0, 1}); - ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) - << "expected: " << ShapeUtil::HumanString(expected) - << "actual: " << ShapeUtil::HumanString(actual); -} - -TEST(ShapeUtilTest, ParseShapeStringWithSparseLayout) { - string shape_string = "f32[123,456]sparse{10}"; - TF_ASSERT_OK_AND_ASSIGN(Shape actual, - 
ShapeUtil::ParseShapeString(shape_string)); - Shape expected = ShapeUtil::MakeShapeWithSparseLayout(F32, {123, 456}, 10); - ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) - << "expected: " << ShapeUtil::HumanString(expected) - << "actual: " << ShapeUtil::HumanString(actual); -} - -TEST(ShapeUtilTest, ParseOpaqueType) { - TF_ASSERT_OK_AND_ASSIGN(Shape actual, - ShapeUtil::ParseShapeString("opaque[]")); - Shape expected = ShapeUtil::MakeOpaqueShape(); - ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) - << "expected: " << ShapeUtil::HumanString(expected) - << "actual: " << ShapeUtil::HumanString(actual); -} - -TEST(ShapeUtilTest, ParseTokenType) { - TF_ASSERT_OK_AND_ASSIGN(Shape actual, ShapeUtil::ParseShapeString("token[]")); - Shape expected = ShapeUtil::MakeTokenShape(); - ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) - << "expected: " << ShapeUtil::HumanString(expected) - << "actual: " << ShapeUtil::HumanString(actual); -} - -TEST(ShapeUtilTest, ParseInvalidShapeString) { - string shape_strings[] = { - "f32[123,456]foobar{0,1}", "f32[123,456]sparse{0,1}", "f32[123,456]{foo}", - "f32[123,456]dense{foo}", "f32[123,456]sparse{foo}", - }; - for (const string& shape_string : shape_strings) { - StatusOr result = ShapeUtil::ParseShapeString(shape_string); - ASSERT_FALSE(result.ok()) << "shape: " << shape_string; - } -} - TEST(ShapeUtilTest, CompatibleIdenticalShapes) { Shape shape1 = ShapeUtil::MakeShape(F32, {3, 2}); Shape shape2 = ShapeUtil::MakeShape(F32, {3, 2}); @@ -272,6 +176,28 @@ TEST(ShapeUtilTest, UnequalIgnoringFpPrecision) { ShapeUtil::MakeShapeWithLayout(PRED, {4, 3}, {0, 1}))); } +TEST(ShapeUtilTest, EqualDynamicShapes) { + EXPECT_TRUE( + ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {4, 3}, {true, false}), + ShapeUtil::MakeShape(F32, {4, 3}, {true, false}))); + EXPECT_FALSE( + ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {4, 3}, {true, false}), + ShapeUtil::MakeShape(F32, {4, 3}, {false, false}))); +} + +TEST(ShapeUtilTest, CompatibleDynamicShapes) { + Shape 
shape_a = ShapeUtil::MakeShape(F32, {4, 3}, {true, false}); + *shape_a.mutable_layout() = Layout({1, 0}); + Shape shape_b = ShapeUtil::MakeShape(F32, {4, 3}, {true, false}); + *shape_b.mutable_layout() = Layout({0, 1}); + Shape shape_c = ShapeUtil::MakeShape(F32, {4, 3}, {false, true}); + *shape_c.mutable_layout() = Layout({0, 1}); + + EXPECT_TRUE(ShapeUtil::Compatible(shape_a, shape_a)); + EXPECT_TRUE(ShapeUtil::Compatible(shape_a, shape_b)); + EXPECT_FALSE(ShapeUtil::Compatible(shape_a, shape_c)); +} + TEST(ShapeUtilTest, CompatibleTuples) { Shape tuple1 = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShape(F32, {3, 2}), ShapeUtil::MakeShape(PRED, {4, 5})}); @@ -612,10 +538,6 @@ TEST(ShapeUtilTest, InsertedOrDeleted1SizedDimensions) { ShapeUtil::InsertedOrDeleted1SizedDimensions(shape0, shape2))); } -TEST(ShapeUtilTest, ShapeIs) { - EXPECT_FALSE(ShapeUtil::ShapeIs(ShapeUtil::MakeShape(PRED, {2}), PRED, {})); -} - TEST(ShapeUtilTest, ForEachIndex) { struct ShapeDimensionAndNumberInvocations { std::vector dimensions; @@ -788,6 +710,26 @@ TEST(ShapeUtilTest, PermuteDimensionsLayout) { } while (std::next_permutation(layout.begin(), layout.end())); } +TEST(ShapeUtilTest, PermuteDynamicDimensions) { + Shape shape = + ShapeUtil::MakeShape(F32, {10, 100, 1000}, + /*dynamic_dimensions*/ {false, true, true}); + SCOPED_TRACE(absl::StrCat("shape=", shape.ToString())); + + std::vector permutation(3); + std::iota(permutation.begin(), permutation.end(), 0); + do { + SCOPED_TRACE(absl::StrCat("permutation=", absl::StrJoin(permutation, ","))); + + auto permuted = ShapeUtil::PermuteDimensions(permutation, shape); + for (int i = 0; i < shape.rank(); i++) { + EXPECT_EQ(permuted.dimensions(permutation[i]), shape.dimensions(i)); + EXPECT_EQ(permuted.is_dynamic_dimension(permutation[i]), + shape.is_dynamic_dimension(i)); + } + } while (std::next_permutation(permutation.begin(), permutation.end())); +} + TEST(AlgebraicSimplifierTest, ReshapeIsBitcast_3x2x2_6x2_Dim0IsMostMinor) { 
EXPECT_FALSE(ShapeUtil::ReshapeIsBitcast( ShapeUtil::MakeShapeWithLayout(F32, {3, 2, 2}, {0, 1, 2}), diff --git a/tensorflow/compiler/xla/sparse_index_array.cc b/tensorflow/compiler/xla/sparse_index_array.cc index a40bb7875e7ea53a8959a9a67ec09ec260ba9c37..82091bdee65c709bb6020f40acc15f13d8599c1d 100644 --- a/tensorflow/compiler/xla/sparse_index_array.cc +++ b/tensorflow/compiler/xla/sparse_index_array.cc @@ -79,7 +79,7 @@ void SparseIndexArray::Resize(int64 num_indices) { } bool SparseIndexArray::Validate(const Shape& shape) const { - if (rank_ == 0 || rank_ != ShapeUtil::Rank(shape)) { + if (rank_ == 0 || rank_ != shape.rank()) { return false; } int64 num_indices = index_count(); diff --git a/tensorflow/compiler/xla/sparse_index_array.h b/tensorflow/compiler/xla/sparse_index_array.h index a96d483462efd77ae4761541e8c79b2c84fa49f3..0c25355467da3fd346d80db790d78252869975ef 100644 --- a/tensorflow/compiler/xla/sparse_index_array.h +++ b/tensorflow/compiler/xla/sparse_index_array.h @@ -135,7 +135,7 @@ void SparseIndexArray::SortWithValues(absl::Span values) { auto sort_order_less = [this](int64 lhs, int64 rhs) { return IndexUtil::CompareIndices(At(lhs), At(rhs)) < 0; }; - std::sort(sort_order.begin(), sort_order.end(), sort_order_less); + absl::c_sort(sort_order, sort_order_less); // Reorder the array elements according to sort_order. Work through the array // and follow cycles so we can do the reorder in-place. diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 2c18e2fd10105b6f0c146cad1842c7723699c8d9..e8e779fb2a3f201ae056e6385eacfe6a63503749 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -1,6 +1,13 @@ # Description: # Base testing infrastructure for XLA. 
+load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites", "generate_backend_test_macros", "xla_test", "xla_test_library") +load( + "//tensorflow/core:platform/default/build_config_root.bzl", + "tf_cuda_tests_tags", +) +load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test") + licenses(["notice"]) # Apache 2.0 package( @@ -23,17 +30,6 @@ filegroup( ]), ) -load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") -load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test_library") -load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites") -load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_test_macros") -load("//tensorflow:tensorflow.bzl", "tf_cc_binary") -load("//tensorflow:tensorflow.bzl", "tf_cc_test") -load( - "//tensorflow/core:platform/default/build_config_root.bzl", - "tf_cuda_tests_tags", -) - # Generate test_suites for all backends, named "${backend}_tests". generate_backend_suites() @@ -75,6 +71,7 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:hlo_dataflow_analysis", "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/service:transfer_manager", @@ -280,9 +277,6 @@ cc_library( xla_test( name = "bad_rng_shape_validation_test", srcs = ["bad_rng_shape_validation_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", @@ -303,10 +297,51 @@ xla_test( name = "conv_depthwise_test", timeout = "long", srcs = ["conv_depthwise_test.cc"], + shard_count = 50, + deps = [ + "//tensorflow/compiler/xla:execution_options_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/client:xla_computation", + 
"//tensorflow/compiler/xla/service:bfloat16_normalization", + "//tensorflow/compiler/xla/service:despecializer", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "@com_google_absl//absl/types:optional", + ], +) + +xla_test( + name = "conv_depthwise_backprop_filter_test", + timeout = "long", + srcs = ["conv_depthwise_backprop_filter_test.cc"], + shard_count = 6, + deps = [ + "//tensorflow/compiler/xla:execution_options_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/service:bfloat16_normalization", + "//tensorflow/compiler/xla/service:despecializer", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "@com_google_absl//absl/types:optional", + ], +) + +xla_test( + name = "grouped_convolution_test", + timeout = "long", + srcs = ["grouped_convolution_test.cc"], blacklisted_backends = [ # disabled because of a break b/119590850. - "cpu", "gpu", + # disabled because it times out. 
+ "cpu", ], shard_count = 50, deps = [ @@ -327,9 +362,6 @@ xla_test( xla_test( name = "check_execution_arity_test", srcs = ["check_execution_arity_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", @@ -350,9 +382,6 @@ xla_test( xla_test( name = "query_inferred_shape_test", srcs = ["query_inferred_shape_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", @@ -370,9 +399,6 @@ xla_test( xla_test( name = "while_test", srcs = ["while_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", @@ -396,6 +422,10 @@ xla_test( xla_test( name = "xla_hlo_profile_test", srcs = ["xla_hlo_profile_test.cc"], + blacklisted_backends = [ + # Hlo profiles are not supported on the interpreter backend. + "interpreter", + ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:shape_util", @@ -419,9 +449,6 @@ xla_test( xla_test( name = "axpy_simple_test", srcs = ["axpy_simple_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", @@ -436,7 +463,6 @@ xla_test( xla_test( name = "map_test", srcs = ["map_test.cc"], - tags = ["enable_for_xla_interpreter"], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal", @@ -489,9 +515,6 @@ xla_test( xla_test( name = "pred_test", srcs = ["pred_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla/client:local_client", @@ -507,9 +530,6 @@ xla_test( xla_test( name = "select_test", srcs = ["select_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", @@ 
-527,7 +547,6 @@ xla_test( xla_test( name = "conditional_test", srcs = ["conditional_test.cc"], - tags = ["enable_for_xla_interpreter"], deps = [ "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", @@ -545,7 +564,6 @@ xla_test( xla_test( name = "unary_op_test", srcs = ["unary_op_test.cc"], - tags = ["enable_for_xla_interpreter"], deps = [ "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", @@ -606,9 +624,6 @@ xla_test( xla_test( name = "deconstruct_tuple_test", srcs = ["deconstruct_tuple_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", @@ -631,7 +646,6 @@ xla_test( name = "array_elementwise_ops_test", srcs = ["array_elementwise_ops_test.cc"], shard_count = 25, - tags = ["enable_for_xla_interpreter"], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", @@ -656,22 +670,19 @@ xla_test( xla_test( name = "exhaustive_f32_elementwise_op_test", - size = "enormous", srcs = ["exhaustive_f32_elementwise_op_test.cc"], - backends = [ - "cpu", - "gpu", - ], + real_hardware_only = True, # Very slow on the interpreter. shard_count = 48, tags = [ - "broken", - "manual", - "notap", + "optonly", + # This is a big test that we skip for capacity reasons in OSS testing. 
+ "nooss", ], deps = [ ":client_library_test_base", ":literal_test_util", "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client/lib:math", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "@com_google_absl//absl/base", @@ -681,7 +692,6 @@ xla_test( xla_test( name = "reduce_precision_test", srcs = ["reduce_precision_test.cc"], - tags = ["enable_for_xla_interpreter"], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal", @@ -708,7 +718,6 @@ xla_test( srcs = ["dot_operation_test.cc"], shard_count = 20, tags = [ - "enable_for_xla_interpreter", "optonly", ], deps = [ @@ -718,7 +727,9 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client/lib:matrix", "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -775,7 +786,9 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client/lib:matrix", "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -789,9 +802,6 @@ xla_test( xla_test( name = "transpose_test", srcs = ["transpose_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:reference_util", @@ -811,9 +821,6 @@ xla_test( xla_test( name = "constants_test", srcs = ["constants_test.cc"], - tags = [ - 
"enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", @@ -824,7 +831,9 @@ xla_test( "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/client/lib:constants", "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -933,6 +942,11 @@ xla_test( xla_test( name = "batch_normalization_test", srcs = ["batch_normalization_test.cc"], + blacklisted_backends = [ + # BatchNorm HLOs are not handled by the interpreter backend, and the + # BatchNorm expander is not run on the interpreter. + "interpreter", + ], shard_count = 40, deps = [ ":test_utils", @@ -1024,9 +1038,6 @@ xla_test( name = "slice_test", srcs = ["slice_test.cc"], shard_count = 40, - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:reference_util", @@ -1047,9 +1058,6 @@ xla_test( xla_test( name = "multidimensional_slice_test", srcs = ["multidimensional_slice_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", @@ -1067,9 +1075,6 @@ xla_test( name = "dynamic_ops_test", timeout = "moderate", srcs = ["dynamic_ops_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:reference_util", @@ -1095,9 +1100,6 @@ xla_test( xla_test( name = "tuple_test", srcs = ["tuple_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal", @@ -1121,9 +1123,6 @@ xla_test( xla_test( name = "vector_ops_reduce_test", srcs = 
["vector_ops_reduce_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", @@ -1144,7 +1143,6 @@ xla_test( srcs = ["reduce_test.cc"], shard_count = 40, tags = [ - "enable_for_xla_interpreter", "optonly", ], deps = [ @@ -1211,7 +1209,6 @@ xla_test( srcs = [], shard_count = 20, tags = [ - "enable_for_xla_interpreter", "optonly", ], xla_test_library_deps = [":reduce_window_test_library"], @@ -1223,7 +1220,6 @@ xla_test( timeout = "long", srcs = ["select_and_scatter_test.cc"], tags = [ - "enable_for_xla_interpreter", "optonly", ], deps = [ @@ -1249,9 +1245,6 @@ xla_test( xla_test( name = "copy_test", srcs = ["copy_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ ":client_library_test_base", "//tensorflow/compiler/xla:array2d", @@ -1272,9 +1265,6 @@ xla_test( xla_test( name = "reduce_hlo_test", srcs = ["reduce_hlo_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", @@ -1288,9 +1278,6 @@ xla_test( xla_test( name = "token_hlo_test", srcs = ["token_hlo_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:hlo_verifier", @@ -1305,9 +1292,6 @@ xla_test( xla_test( name = "call_test", srcs = ["call_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", @@ -1327,6 +1311,7 @@ xla_test( xla_test( name = "custom_call_test", srcs = ["custom_call_test.cc"], + backends = ["cpu"], deps = [ "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", @@ -1349,9 +1334,6 @@ xla_test( xla_test( name = "binop_scaling_test", srcs = ["binop_scaling_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:array2d", 
"//tensorflow/compiler/xla:array4d", @@ -1369,9 +1351,6 @@ xla_test( xla_test( name = "broadcast_simple_test", srcs = ["broadcast_simple_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", @@ -1391,9 +1370,6 @@ xla_test( xla_test( name = "pad_test", srcs = ["pad_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", @@ -1415,9 +1391,6 @@ xla_test( xla_test( name = "fmax_test", srcs = ["fmax_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", @@ -1432,9 +1405,6 @@ xla_test( xla_test( name = "log_test", srcs = ["log_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", @@ -1449,9 +1419,6 @@ xla_test( xla_test( name = "matrix_ops_simple_test", srcs = ["matrix_ops_simple_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal", @@ -1498,9 +1465,6 @@ xla_test( name = "reshape_test", srcs = ["reshape_test.cc"], shard_count = 30, - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", @@ -1526,9 +1490,6 @@ xla_test( xla_test( name = "reverse_test", srcs = ["reverse_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", @@ -1547,9 +1508,6 @@ xla_test( xla_test( name = "vector_ops_simple_test", srcs = ["vector_ops_simple_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:shape_util", @@ -1573,9 +1531,6 @@ xla_test( xla_test( name = "concat_test", srcs = 
["concat_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", @@ -1596,9 +1551,6 @@ xla_test( xla_test( name = "convert_test", srcs = ["convert_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto", @@ -1616,8 +1568,12 @@ xla_test( ) xla_test( - name = "cross_replica_sum_test", - srcs = ["cross_replica_sum_test.cc"], + name = "all_reduce_test", + srcs = ["all_reduce_test.cc"], + blacklisted_backends = [ + # All reduce is not supported on the interpreter backend. + "interpreter", + ], deps = [ "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", @@ -1642,9 +1598,6 @@ xla_test( xla_test( name = "bitcast_convert_test", srcs = ["bitcast_convert_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto", @@ -1684,9 +1637,6 @@ xla_test( xla_test( name = "floor_ceil_test", srcs = ["floor_ceil_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", @@ -1748,6 +1698,10 @@ xla_test( xla_test( name = "execution_profile_test", srcs = ["execution_profile_test.cc"], + blacklisted_backends = [ + # Execution profiles are not supported on the interpreter backend. + "interpreter", + ], deps = [ ":client_library_test_base", "//tensorflow/compiler/xla/client:global_data", @@ -1762,6 +1716,10 @@ xla_test( name = "execution_profile_test_with_xla_hlo_profile", srcs = ["execution_profile_test.cc"], args = ["--xla_hlo_profile"], + blacklisted_backends = [ + # Hlo profiles are not supported on the interpreter backend. 
+ "interpreter", + ], deps = [ ":client_library_test_base", "//tensorflow/compiler/xla/client:global_data", @@ -1775,9 +1733,6 @@ xla_test( xla_test( name = "replay_test", srcs = ["replay_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:protobuf_util", @@ -1800,9 +1755,6 @@ xla_test( xla_test( name = "broadcast_test", srcs = ["broadcast_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", @@ -1864,9 +1816,6 @@ xla_test( xla_test( name = "fusion_test", srcs = ["fusion_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal", @@ -1984,6 +1933,10 @@ xla_test( xla_test( name = "outfeed_in_nested_computation_test", srcs = ["outfeed_in_nested_computation_test.cc"], + blacklisted_backends = [ + # Outfeed ops are not supported on the interpreter backend. + "interpreter", + ], deps = [ "//tensorflow/compiler/xla/tests:local_client_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -2160,7 +2113,6 @@ xla_test( srcs = ["iota_test.cc"], shard_count = 30, tags = [ - "enable_for_xla_interpreter", # Require optimized builds, iota_test_cpu is very slow in fastbuild. "optonly", ], @@ -2188,3 +2140,18 @@ tf_cc_test( "@com_google_absl//absl/synchronization", ], ) + +xla_test( + name = "ptxas_bug_120501638", + srcs = ["ptxas_bug_120501638.cc"], + tags = [ + # Disabled in OSS until nvidia publicly releases a fixed ptxas. 
+ "no_oss", + ], + deps = [ + ":hlo_test_base", + ":xla_internal_test_main", # fixdeps: keep + "//tensorflow/compiler/xla:debug_options_flags", + "//tensorflow/compiler/xla:test", + ], +) diff --git a/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc b/tensorflow/compiler/xla/tests/all_reduce_test.cc similarity index 94% rename from tensorflow/compiler/xla/tests/cross_replica_sum_test.cc rename to tensorflow/compiler/xla/tests/all_reduce_test.cc index 410732c07b7b6d3ece33ab11f4778241dc53ca50..7e695f829e39831e2c8558cb07d0689e560bbafa 100644 --- a/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc +++ b/tensorflow/compiler/xla/tests/all_reduce_test.cc @@ -41,7 +41,7 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, OneOperand) { ENTRY test_computation { p = f32[3] parameter(0) - ROOT crs = f32[3] cross-replica-sum(p), to_apply=add + ROOT crs = f32[3] all-reduce(p), to_apply=add })"; auto module = ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie(); @@ -62,7 +62,7 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, MultipleOperands) { ENTRY test_computation { p0 = f32[3] parameter(0) p1 = f32[2] parameter(1) - ROOT crs = (f32[3], f32[2]) cross-replica-sum(p0, p1), to_apply=add + ROOT crs = (f32[3], f32[2]) all-reduce(p0, p1), to_apply=add })"; auto module = ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie(); @@ -88,7 +88,7 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, ConstantOperand) { ENTRY test_computation { p0 = f32[3] parameter(0) p1 = f32[2] constant({10, 20}) - ROOT crs = (f32[3], f32[2]) cross-replica-sum(p0, p1), to_apply=add + ROOT crs = (f32[3], f32[2]) all-reduce(p0, p1), to_apply=add })"; auto module = ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie(); diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index 0615f9425c1289d666641f4d581946b44b4895ce..7379fbcc22745f46f2a29732c4bda46f352d07e7 100644 --- 
a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -329,13 +329,13 @@ TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) { Literal b_literal = LiteralUtil::CreateR1({b_values}); std::unique_ptr b_data = client_->TransferToServer(b_literal).ConsumeValueOrDie(); - auto b_constant = Parameter(&builder, 1, a_literal.shape(), "b_param"); - auto b_param = ConstantR1(&builder, b_values); + auto b_param = Parameter(&builder, 1, a_literal.shape(), "b_param"); + auto b_constant = ConstantR1(&builder, b_values); - auto sum1 = Add(a_constant, b_constant); - auto sum2 = Add(a_constant, b_param); - auto sum3 = Add(a_param, b_constant); - auto sum4 = Add(a_param, b_param); + auto sum1 = Add(a_constant, b_param); + auto sum2 = Add(a_constant, b_constant); + auto sum3 = Add(a_param, b_param); + auto sum4 = Add(a_param, b_constant); auto sum = Add(sum1, sum2); sum = Add(sum, sum3); @@ -350,6 +350,44 @@ TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) { error_spec_); } +// TODO(b/119692968): This test runs OOM on the GPU and CPU backend. +XLA_TEST_F(ArrayElementwiseOpTest, + DISABLED_ON_GPU(DISABLED_ON_CPU(DeeplyNestedAddWithSlices))) { + XlaBuilder builder(TestName()); + std::vector values(30, 0.0); + auto a_literal = LiteralUtil::CreateR1(values); + auto a = Parameter(&builder, 0, a_literal.shape(), "x"); + auto b_literal = LiteralUtil::CreateR1(values); + auto b = Parameter(&builder, 1, b_literal.shape(), "x"); + + // Construct a sequence of diamond-shaped gadgets like this: + // + // add + // / \ + // slice slice + // \ / + // add + // + // Each 'left' slice removes the last element, each 'right' slice removes the + // first element. In this way, we index into the add with different + // multi-dimensional index arrays, which defeats the caching we use to avoid + // exponential compile time. 
+ std::function generate_recursive = + [&](int64 slice_size) -> XlaOp { + if (slice_size == values.size()) { + return Add(a, b); + } + XlaOp param = generate_recursive(slice_size + 1); + auto slice1 = Slice(param, {0}, {slice_size}, {1}); + auto slice2 = Slice(param, {1}, {slice_size + 1}, {1}); + return Add(slice1, slice2); + }; + generate_recursive(1); + auto a_data = client_->TransferToServer(a_literal).ConsumeValueOrDie(); + auto b_data = client_->TransferToServer(b_literal).ConsumeValueOrDie(); + ComputeAndCompareR1(&builder, {0.0}, {a_data.get(), b_data.get()}); +} + XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantF32s) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {-2.5f, 3.14f, 2.25f, -10.0f, 6.0f}); @@ -1405,6 +1443,27 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowNonIntegerF32s) { error_spec_); } +XLA_TEST_F(ArrayElementwiseOpTest, PowC64s) { + SetFastMathDisabled(true); + XlaBuilder builder(TestName()); + auto lhs = + ConstantR1(&builder, {-2.0f, -0.6f, -0.6f, 0.0f, 0.0f, 0.0f}); + auto rhs = + ConstantR1(&builder, {0.5f, 0.6f, -0.6f, 0.5f, 0.6f, 0.0f}); + Pow(lhs, rhs); + + ComputeAndCompareR1(&builder, + { + {0, 1.41421356}, + {-2.27443288e-01, 0.69999846}, + {-4.19847531e-01, -1.29215783}, + {0, 0}, + {0, 0}, + {1, 0}, + }, + {}, error_spec_); +} + XLA_TEST_F(ArrayElementwiseOpTest, PowZeroElementF32s) { XlaBuilder builder(TestName()); auto lhs = ConstantR1(&builder, {}); @@ -2009,6 +2068,19 @@ XLA_TEST_F(ArrayElementwiseOpTest, NonNanClampF32) { error_spec_); } +XLA_TEST_F(ArrayElementwiseOpTest, ClampF32) { + SetFastMathDisabled(true); + XlaBuilder builder(TestName()); + auto minimum = ConstantR1(&builder, {1.0f, -6.5f, 1.0f, 2.25f, NAN}); + auto argument = + ConstantR1(&builder, {2.0f, 10.0f, -5.0f, 1.0f, 10.0f}); + auto maximum = ConstantR1(&builder, {3.0f, 0.5f, 25.5f, NAN, 123.0f}); + Clamp(minimum, argument, maximum); + + ComputeAndCompareR1(&builder, {2.0f, 0.5f, 1.0f, NAN, NAN}, {}, + error_spec_); +} + 
XLA_TEST_F(ArrayElementwiseOpTest, ClampF32Scalar) { XlaBuilder builder(TestName()); auto minimum = ConstantR0(&builder, 0.0f); diff --git a/tensorflow/compiler/xla/tests/bfloat16_test.cc b/tensorflow/compiler/xla/tests/bfloat16_test.cc index e9728e636f0ee032416b2da17a3ea83c5bb18083..63e48117056dec4af603cbc85e478fcb15ad0cec 100644 --- a/tensorflow/compiler/xla/tests/bfloat16_test.cc +++ b/tensorflow/compiler/xla/tests/bfloat16_test.cc @@ -76,7 +76,9 @@ XLA_TEST_F(Bfloat16Test, NegateScalarF16) { error_spec_); } -XLA_TEST_F(Bfloat16Test, BatchNormTraining) { +// Disabled on interpreter since BatchNormExpander is not run by default on the +// interpreter backend. +XLA_TEST_F(Bfloat16Test, DISABLED_ON_INTERPRETER(BatchNormTraining)) { const int kFeatureIndex = 2; XlaBuilder builder(TestName()); @@ -110,7 +112,9 @@ XLA_TEST_F(Bfloat16Test, BatchNormTraining) { ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.01, 0.02)); } -XLA_TEST_F(Bfloat16Test, BatchNormGrad) { +// Disabled on interpreter since BatchNormExpander is not run by default on the +// interpreter backend. +XLA_TEST_F(Bfloat16Test, DISABLED_ON_INTERPRETER(BatchNormGrad)) { const int kFeatureIndex = 2; XlaBuilder builder(TestName()); diff --git a/tensorflow/compiler/xla/tests/build_defs.bzl b/tensorflow/compiler/xla/tests/build_defs.bzl index 05d4d04034bf50c8bb840e59b28a590fce048c19..c14d279ac560db33066ae4fc68b6290f7499bb39 100644 --- a/tensorflow/compiler/xla/tests/build_defs.bzl +++ b/tensorflow/compiler/xla/tests/build_defs.bzl @@ -34,6 +34,7 @@ def xla_test( xla_test_library_deps = [], backends = [], blacklisted_backends = [], + real_hardware_only = False, args = [], tags = [], copts = [], @@ -108,6 +109,10 @@ def xla_test( use for that target. **kwargs: Additional keyword arguments to pass to native.cc_test. """ + + # All of the backends in all_backends are real hardware.
+ _ignore = [real_hardware_only] + test_names = [] if not backends: backends = all_backends diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index 12c029983336cc9aed0fde4ce6881c9a00a9869e..0e99ede5d01fcfa88c54c9cbc5a6a85bf8f15ddf 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include #include #include "absl/memory/memory.h" @@ -40,8 +41,9 @@ constexpr char kInterpreter[] = "interpreter"; // Wrapper function that creates a nicer error message (than a bare // ValueOrDie()) if the platform we intend to test is not available. -Client* GetOrCreateLocalClientOrDie(const LocalClientOptions& client_options) { - StatusOr result = +LocalClient* GetOrCreateLocalClientOrDie( + const LocalClientOptions& client_options) { + StatusOr result = ClientLibrary::GetOrCreateLocalClient(client_options); TF_CHECK_OK(result.status()) << " could not create local client for testing"; return result.ValueOrDie(); @@ -74,6 +76,9 @@ ClientLibraryTestBase::ClientLibraryTestBase( // default. 
execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes( "constant_folding"); + + execution_options_.mutable_debug_options() + ->set_xla_hlo_evaluator_use_fast_path(true); } ClientLibraryTestBase::ClientLibraryTestBase(se::Platform* platform) @@ -88,6 +93,9 @@ ClientLibraryTestBase::ClientLibraryTestBase(se::Platform* platform) execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes( "constant_folding"); + + execution_options_.mutable_debug_options() + ->set_xla_hlo_evaluator_use_fast_path(true); } string ClientLibraryTestBase::TestName() const { @@ -184,7 +192,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts( verify_output(actual, ""); // Try with all output layouts. - std::vector minor_to_major(ShapeUtil::Rank(expected.shape())); + std::vector minor_to_major(expected.shape().rank()); std::iota(minor_to_major.begin(), minor_to_major.end(), 0); do { auto layout = ShapeUtil::MakeShapeWithLayout( @@ -217,7 +225,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( TF_ASSIGN_OR_RETURN(auto literal, client_->Transfer(*arguments[index], nullptr)); // Skip tuples because they don't have a rank. 
- if (ShapeUtil::IsTuple(literal.shape())) { + if (literal.shape().IsTuple()) { layout_strings.push_back( ShapeUtil::HumanStringWithLayout(literal.shape())); arguments_with_layout.push_back(arguments[index]); @@ -227,7 +235,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( return Status::OK(); } - std::vector minor_to_major(ShapeUtil::Rank(literal.shape())); + std::vector minor_to_major(literal.shape().rank()); std::iota(minor_to_major.begin(), minor_to_major.end(), 0); do { auto literal_relayout = @@ -273,9 +281,10 @@ StatusOr ClientLibraryTestBase::ComputeAndTransfer( if (!arguments_.empty()) { CHECK(arguments.empty()); for (const auto& argument : arguments_) { - owning_arguments.push_back( - client_->TransferToServer(MaybeConvertLiteralToBfloat16(argument)) - .ValueOrDie()); + TF_ASSIGN_OR_RETURN( + std::unique_ptr owned_argument, + client_->TransferToServer(MaybeConvertLiteralToBfloat16(argument))); + owning_arguments.push_back(std::move(owned_argument)); arguments.push_back(owning_arguments.back().get()); } } @@ -296,9 +305,10 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( if (!arguments_.empty()) { CHECK(arguments.empty()); for (const auto& argument : arguments_) { - owning_arguments.push_back( - client_->TransferToServer(MaybeConvertLiteralToBfloat16(argument)) - .ValueOrDie()); + TF_ASSIGN_OR_RETURN( + std::unique_ptr owned_argument, + client_->TransferToServer(MaybeConvertLiteralToBfloat16(argument))); + owning_arguments.push_back(std::move(owned_argument)); arguments.push_back(owning_arguments.back().get()); } } @@ -356,9 +366,10 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( if (!arguments_.empty()) { CHECK(arguments.empty()); for (const auto& argument : arguments_) { - owning_arguments.push_back( - client_->TransferToServer(MaybeConvertLiteralToBfloat16(argument)) - .ValueOrDie()); + TF_ASSIGN_OR_RETURN( + std::unique_ptr owned_argument, + 
client_->TransferToServer(MaybeConvertLiteralToBfloat16(argument))); + owning_arguments.push_back(std::move(owned_argument)); arguments.push_back(owning_arguments.back().get()); } } diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index 34148e5886d3806b19fc5bee90806c5678df345e..d700437ed355c144639f76d683055e211975fde9 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -76,7 +76,7 @@ class ClientLibraryTestBase : public ::testing::Test { void SetFastMathDisabled(bool disabled) { auto* opts = execution_options_.mutable_debug_options(); opts->set_xla_cpu_enable_fast_math(!disabled); - opts->set_xla_gpu_enable_fast_math(!disabled); + opts->set_xla_gpu_enable_fast_min_max(!disabled); } void SetSeed(uint64 seed) { execution_options_.set_seed(seed); } @@ -385,8 +385,8 @@ class ClientLibraryTestBase : public ::testing::Test { StatusOr> ComputeValueAndReference( XlaBuilder* builder, absl::Span arguments); - Client* client_; - Client* ref_client_; // To compute reference result. + LocalClient* client_; + LocalClient* ref_client_; // To compute reference result. 
ExecutionOptions execution_options_; private: @@ -431,7 +431,8 @@ void ClientLibraryTestBase::ComputeAndCompareR0( std::is_same::value || std::is_same::value || std::is_same::value || - std::is_same::value, + std::is_same::value || + std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); Literal expected_literal = LiteralUtil::CreateR0(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, @@ -455,7 +456,8 @@ void ClientLibraryTestBase::ComputeAndCompareR1( std::is_same::value || std::is_same::value || std::is_same::value || - std::is_same::value, + std::is_same::value || + std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); Literal expected_literal = LiteralUtil::CreateR1(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, @@ -480,7 +482,8 @@ void ClientLibraryTestBase::ComputeAndCompareR2( std::is_same::value || std::is_same::value || std::is_same::value || - std::is_same::value, + std::is_same::value || + std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); Literal expected_literal = LiteralUtil::CreateR2FromArray2D(expected); @@ -506,7 +509,8 @@ void ClientLibraryTestBase::ComputeAndCompareR3( std::is_same::value || std::is_same::value || std::is_same::value || - std::is_same::value, + std::is_same::value || + std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); Literal expected_literal = LiteralUtil::CreateR3FromArray3D(expected); @@ -532,7 +536,8 @@ void ClientLibraryTestBase::ComputeAndCompareR4( std::is_same::value || std::is_same::value || std::is_same::value || - std::is_same::value, + std::is_same::value || + std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); Literal expected_literal = LiteralUtil::CreateR4FromArray4D(expected); diff --git a/tensorflow/compiler/xla/tests/client_test.cc b/tensorflow/compiler/xla/tests/client_test.cc 
index 363dee74b2755a6bdc3c5a5164a85378581c21d2..247328b730f3af936d933f824da491b593b27c90 100644 --- a/tensorflow/compiler/xla/tests/client_test.cc +++ b/tensorflow/compiler/xla/tests/client_test.cc @@ -96,7 +96,7 @@ XLA_TEST_F(ClientTest, ExecuteWithTupleLayout) { LiteralTestUtil::ExpectR2Equal({{10, 20}, {30, 40}}, LiteralSlice(result, {1})); - EXPECT_TRUE(ShapeUtil::IsTuple(result.shape())); + EXPECT_TRUE(result.shape().IsTuple()); EXPECT_EQ(2, ShapeUtil::TupleElementCount(result.shape())); EXPECT_TRUE(ShapeUtil::Equal( @@ -109,7 +109,10 @@ XLA_TEST_F(ClientTest, ExecuteWithTupleLayout) { /*minor_to_major=*/{1, 0}))); } -XLA_TEST_F(ClientTest, DISABLED_ON_GPU(ExecuteParallel)) { +// Disabled for interpreter since ExecuteAsyncOnStream is not implemented on +// interpreter backend. +XLA_TEST_F(ClientTest, + DISABLED_ON_INTERPRETER(DISABLED_ON_GPU(ExecuteParallel))) { XlaComputation add_with_one_arg, mul_with_two_args, dot_with_one_arg; Shape shape = ShapeUtil::MakeShape(S32, {2, 2}); diff --git a/tensorflow/compiler/xla/tests/compute_constant_test.cc b/tensorflow/compiler/xla/tests/compute_constant_test.cc index 3b0414a6045a7c5f4f75948d8ccf2775c575626e..ef800b8ef624bf1020ff1e6857c13b0387482cd3 100644 --- a/tensorflow/compiler/xla/tests/compute_constant_test.cc +++ b/tensorflow/compiler/xla/tests/compute_constant_test.cc @@ -151,19 +151,35 @@ TEST_F(ComputeConstantTest, DirectParamMissing) { } } -TEST_F(ComputeConstantTest, IndirectParamMissing) { +TEST_F(ComputeConstantTest, GetDimensionSize) { for (ClientType client_type : client_types) { Client* client = ClientOrDie(platform_, client_type); XlaBuilder b(TestName()); - auto computation = - Add(ConstantR0(&b, 1.0f), - Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "param")); - EXPECT_FALSE(IsConstant(computation, &b)); + auto add = + Add(ConstantR1(&b, {1.0f}), ConstantR1(&b, {1.0f})); + auto get_dimension_size = GetDimensionSize(add, 0); + EXPECT_TRUE(IsConstant(get_dimension_size, &b)); + + 
TF_ASSERT_OK_AND_ASSIGN(auto value, ComputeConstantScalar( + client, get_dimension_size, &b)); + EXPECT_EQ(value, 1); + } +} - auto value = ComputeConstantScalar(client, computation, &b); - EXPECT_TRUE( - absl::StrContains(value.status().ToString(), "depends on a parameter")) - << value.status(); +TEST_F(ComputeConstantTest, MultipleGetDimensionSize) { + for (ClientType client_type : client_types) { + Client* client = ClientOrDie(platform_, client_type); + XlaBuilder b(TestName()); + auto add = + Add(ConstantR2(&b, {{1.0f}}), ConstantR2(&b, {{1.0f}})); + auto get_dimension_size = GetDimensionSize(add, 0); + auto get_dimension_size_2 = GetDimensionSize(add, 0); + auto add_2 = Add(get_dimension_size, get_dimension_size_2); + EXPECT_TRUE(IsConstant(add_2, &b)); + + TF_ASSERT_OK_AND_ASSIGN(auto value, + ComputeConstantScalar(client, add_2, &b)); + EXPECT_EQ(value, 2); } } diff --git a/tensorflow/compiler/xla/tests/constants_test.cc b/tensorflow/compiler/xla/tests/constants_test.cc index 72ff1e74a47c8584cb5336c86a1c978c4637a902..6530007871ced1d0bbffe2b44ccc8cf9bddd79e1 100644 --- a/tensorflow/compiler/xla/tests/constants_test.cc +++ b/tensorflow/compiler/xla/tests/constants_test.cc @@ -21,11 +21,14 @@ limitations under the License. 
#include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -178,5 +181,54 @@ TEST_F(ConstantsTest, Token) { TF_ASSERT_OK(Execute(&builder, {}).status()); } +TEST_F(ConstantsTest, FullLike) { + XlaBuilder b(TestName()); + auto val1 = Iota(&b, F32, 3); + auto val2 = FullLike(val1, 10); + val1 + val2; + ComputeAndCompareR1(&b, {10, 11, 12}, {}, error_spec_); +} + +TEST_F(ConstantsTest, IllegalFullLikeOnTuple) { + XlaBuilder b(TestName()); + auto tuple = Tuple(&b, {Iota(&b, F32, 3), Iota(&b, F32, 1)}); + FullLike(tuple, 10); // Illegal; can't do FullLike on a tuple. + EXPECT_FALSE(b.Build().ok()); +} + +TEST_F(ConstantsTest, FullLikeScalar) { + XlaBuilder b(TestName()); + auto scalar1 = ConstantR0WithType(&b, F32, 1); + auto scalar2 = FullLike(scalar1, 2); + scalar1 - scalar2; + ComputeAndCompareR0(&b, -1, {}, error_spec_); +} + +class ConstantsHloTest : public HloTestBase {}; + +// TODO(b/121147351): Fails on GPU. Not clear if this is expected behavior. 
+XLA_TEST_F(ConstantsHloTest, DISABLED_ON_GPU(BitcastOfConstant)) { + const char* testcase = R"( + HloModule module, is_scheduled=true + + func { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT mul = s32[] add(lhs, rhs) + } + + ENTRY test { + constant.0 = s32[1]{0} constant({0}) + parameter.0 = s32[] parameter(0) + constant-as-scalar = s32[] bitcast(constant.0) + ROOT result = s32[] call(parameter.0, constant-as-scalar), to_apply=func + } + )"; + auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie(); + auto param = LiteralUtil::CreateR0(1); + auto result = ExecuteNoHloPasses(std::move(module), {¶m}); + EXPECT_TRUE(LiteralTestUtil::Equal(param, result)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/conv_depthwise_backprop_filter_test.cc b/tensorflow/compiler/xla/tests/conv_depthwise_backprop_filter_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..96f4aedf8b996b152b77628252841348e732756f --- /dev/null +++ b/tensorflow/compiler/xla/tests/conv_depthwise_backprop_filter_test.cc @@ -0,0 +1,155 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/execution_options_util.h" +#include "tensorflow/compiler/xla/service/bfloat16_normalization.h" +#include "tensorflow/compiler/xla/service/despecializer.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" + +namespace xla { +namespace { + +string GetFloatDataType(bool use_bfloat16) { + return use_bfloat16 ? "bf16" : "f32"; +} + +struct BatchGroupedConvolution2DSpec { + int64 output_batch, window, window_dilation; + std::vector activation_dims; + std::vector kernel_dims; + std::vector output_dims; +}; + +class BatchGroupedConvolution2DTest + : public HloTestBase, + public ::testing::WithParamInterface< + ::testing::tuple> {}; + +static std::vector GetConv2DTestCases() { + std::vector config_set; + std::vector> config_options = { + {8, 5, 3, 2}, {4, 5, 5, 2}, {8, 7, 4, 128}, {16, 20, 20, 256}, + {256, 7, 5, 4}, {256, 6, 6, 4}, {256, 8, 8, 512}}; + + for (auto option : config_options) { + int64 feature = option[3]; + int64 activation_size = option[1]; + int64 kernel_size = option[2]; + int64 batch = option[0]; + + BatchGroupedConvolution2DSpec config; + config.window_dilation = 1; + config.output_batch = feature; + config.window = kernel_size; + + config.activation_dims = {batch, activation_size, activation_size, feature}; + + config.kernel_dims = {batch, kernel_size, kernel_size, feature}; + + int64 output_space_size = 3 + activation_size - kernel_size; + config.output_dims = {output_space_size, output_space_size, feature, 1}; + + config_set.push_back(config); + + // Add 
configurations for window dilation cases. + if (activation_size % 2 == 0 && activation_size == kernel_size) { + BatchGroupedConvolution2DSpec config; + config.window_dilation = 2; + config.output_batch = feature; + config.window = kernel_size / 2; + config.activation_dims = {batch, activation_size, activation_size, + feature}; + config.kernel_dims = {batch, kernel_size / 2, kernel_size / 2, feature}; + + int64 output_space_size = 5; + config.output_dims = {output_space_size, output_space_size, feature, 1}; + + config_set.push_back(config); + } + } + + return config_set; +} + +string BatchGroupedConvolution2DTestDataToString( + const ::testing::TestParamInfo< + ::testing::tuple>& data) { + const auto& spec = ::testing::get<0>(data.param); + const string data_type = GetFloatDataType(::testing::get<1>(data.param)); + string str = absl::StrCat( + "activation_dims_", absl::StrJoin(spec.activation_dims, "x"), + "_kernel_dims_", absl::StrJoin(spec.kernel_dims, "x"), "_output_dims_", + absl::StrJoin(spec.output_dims, "x"), data_type); + + // Test names are not allowed to contain the '-' character. 
+ absl::c_replace(str, '-', 'n'); + return str; +} + +string BuildHloTextBatchGroupedConvolution2D( + const BatchGroupedConvolution2DSpec& spec, bool use_bfloat16) { + const string data_type = GetFloatDataType(use_bfloat16); + return absl::StrFormat( + R"( + HloModule TensorFlowDepthwiseConv + + ENTRY main { + activation = %s[%s] parameter(0) + kernel = %s[%s] parameter(1) + ROOT conv = %s[%s] convolution(%s[%s] activation, %s[%s] kernel), + window={size=%dx%d pad=1_%dx1_%d rhs_dilate=%dx%d}, dim_labels=f01b_i01o->01fb, + batch_group_count=%d + } + )", + data_type, absl::StrJoin(spec.activation_dims, ","), data_type, + absl::StrJoin(spec.kernel_dims, ","), data_type, + absl::StrJoin(spec.output_dims, ","), data_type, + absl::StrJoin(spec.activation_dims, ","), data_type, + absl::StrJoin(spec.kernel_dims, ","), spec.window, spec.window, + spec.window_dilation, spec.window_dilation, spec.window_dilation, + spec.window_dilation, spec.output_batch); +} + +XLA_TEST_P(BatchGroupedConvolution2DTest, DoIt) { + const BatchGroupedConvolution2DSpec& spec = ::testing::get<0>(GetParam()); + bool use_bfloat16 = ::testing::get<1>(GetParam()); + const string hlo_text = + BuildHloTextBatchGroupedConvolution2D(spec, use_bfloat16); + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{0.01, 0.01}, + [](HloModule* module) -> Status { + BFloat16MixedPrecisionRemoval remover; + TF_RETURN_IF_ERROR(remover.Run(module).status()); + Despecializer despecializer; + return despecializer.Run(module).status(); + })); +} + +INSTANTIATE_TEST_CASE_P( + BatchGroupedConvolution2DTestWithRandomIndices, + BatchGroupedConvolution2DTest, + ::testing::Combine(::testing::ValuesIn(GetConv2DTestCases()), + ::testing::Bool()), + BatchGroupedConvolution2DTestDataToString); + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/conv_depthwise_test.cc b/tensorflow/compiler/xla/tests/conv_depthwise_test.cc index 
60ce576ceb20b89b59e72d821e63b0ccdee51b0b..627a17a0ca114085240dbaf28211bb3511cf0cab 100644 --- a/tensorflow/compiler/xla/tests/conv_depthwise_test.cc +++ b/tensorflow/compiler/xla/tests/conv_depthwise_test.cc @@ -50,9 +50,9 @@ class DepthwiseConvolution2DTest static std::vector GetConv2DTestCases() { std::vector config_set; std::vector> config_options = { - {128, 6, 3, 64}, {256, 5, 3, 256}, {256, 5, 2, 144}, {144, 5, 3, 64}, - {144, 5, 2, 256}, {8, 48, 17, 8}, {128, 20, 6, 64}, {128, 1, 2, 144}, - {256, 1, 2, 64}, {64, 14, 12, 172}, {16, 9, 4, 16}}; + {128, 6, 3, 64}, {256, 5, 3, 256}, {256, 5, 2, 144}, {144, 5, 3, 64}, + {144, 5, 2, 256}, {8, 48, 17, 8}, {128, 20, 6, 64}, {64, 14, 12, 172}, + {16, 9, 4, 16}, {128, 1, 2, 144}, {256, 1, 2, 64}}; for (auto option : config_options) { int64 feature = option[0]; @@ -136,7 +136,7 @@ string BuildHloTextDepthwiseConvolution2D( if (spec.activation_dims[1] == 1 && spec.kernel_dims[1] == 2) { return absl::StrFormat( R"( - HloModule TensorFlowDepthwiseConv, is_scheduled=true + HloModule TensorFlowDepthwiseConv ENTRY main { activation = %s[%s]{%s} parameter(0) @@ -161,7 +161,7 @@ string BuildHloTextDepthwiseConvolution2D( } else if (spec.stride == -1) { return absl::StrFormat( R"( - HloModule TensorFlowDepthwiseConv, is_scheduled=true + HloModule TensorFlowDepthwiseConv ENTRY main { activation = %s[%s]{%s} parameter(0) @@ -185,7 +185,7 @@ string BuildHloTextDepthwiseConvolution2D( } else { return absl::StrFormat( R"( - HloModule TensorFlowDepthwiseConv, is_scheduled=true + HloModule TensorFlowDepthwiseConv ENTRY main { activation = %s[%s]{%s} parameter(0) @@ -215,13 +215,13 @@ XLA_TEST_P(DepthwiseConvolution2DTest, DoIt) { const string hlo_text = BuildHloTextDepthwiseConvolution2D(spec, use_bfloat16); - EXPECT_TRUE(RunAndCompareNoHloPasses( - hlo_text, ErrorSpec{0.01, 0.01}, [](HloModule* module) -> Status { - BFloat16MixedPrecisionRemoval remover; - TF_RETURN_IF_ERROR(remover.Run(module).status()); - Despecializer 
despecializer; - return despecializer.Run(module).status(); - })); + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{0.01, 0.01}, + [](HloModule* module) -> Status { + BFloat16MixedPrecisionRemoval remover; + TF_RETURN_IF_ERROR(remover.Run(module).status()); + Despecializer despecializer; + return despecializer.Run(module).status(); + })); } INSTANTIATE_TEST_CASE_P( diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc index 4a58a1ed66c438d1dd9561f4eb029b38d8c6cbdd..9db9f2563b636c4f929585eb13a9c7f809833eda 100644 --- a/tensorflow/compiler/xla/tests/convolution_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_test.cc @@ -98,7 +98,7 @@ class ForwardPassConvolution_3x3x256_256_OutputZ_Iota : public ConvolutionTest { precision.add_operand_precision(PrecisionConfig::HIGHEST); precision.add_operand_precision(PrecisionConfig::DEFAULT); Conv(lhs, rhs, {1, 1}, Padding::kValid, /*feature_group_count=*/1, - &precision); + /*batch_group_count=*/1, &precision); ComputeAndCompare(&builder, {}, error_spec_); } @@ -467,8 +467,8 @@ XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) { // servers. The error message is missing the operator ++. template void iota_int_init_value(std::vector& values, int init_value) { - std::for_each(values.begin(), values.end(), - [&](T& value) { value = static_cast(init_value++); }); + absl::c_for_each(values, + [&](T& value) { value = static_cast(init_value++); }); } template diff --git a/tensorflow/compiler/xla/tests/copy_test.cc b/tensorflow/compiler/xla/tests/copy_test.cc index 3622f2c1e84639baed13059b21b20609d1347da6..df005a67097bb8aaf070c57d1c51acd1909fee12 100644 --- a/tensorflow/compiler/xla/tests/copy_test.cc +++ b/tensorflow/compiler/xla/tests/copy_test.cc @@ -133,7 +133,9 @@ XLA_TEST_F(CopyOpTest, CopyConstantR2DifferentLayouts) { // Reverse the minor-to-major order of the literal. 
Layout* literal_layout = literal.mutable_shape_do_not_use()->mutable_layout(); ASSERT_EQ(2, literal_layout->minor_to_major_size()); - literal_layout->mutable_minor_to_major()->SwapElements(0, 1); + // Swap the first and second elements. + *literal_layout->mutable_minor_to_major() = { + literal_layout->minor_to_major(1), literal_layout->minor_to_major(0)}; HloInstruction* constant = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); diff --git a/tensorflow/compiler/xla/tests/custom_call_test.cc b/tensorflow/compiler/xla/tests/custom_call_test.cc index 738b6442354b01364278e3e3c713aa2cdb5cf47d..cad43d1b5547d74701760fa623e50466fc15c263 100644 --- a/tensorflow/compiler/xla/tests/custom_call_test.cc +++ b/tensorflow/compiler/xla/tests/custom_call_test.cc @@ -54,11 +54,20 @@ void Add1ToValues(float* out, float** in) { out[2] = array[2] + 1; out[3] = array[3] + 1; } + +void F32TupleSwap(float** out, float** in) { + TF_ANNOTATE_MEMORY_IS_INITIALIZED(in[0], sizeof(float)); + TF_ANNOTATE_MEMORY_IS_INITIALIZED(in[1], sizeof(float)); + *out[0] = *in[1]; + *out[1] = *in[0]; +} + } // namespace REGISTER_CUSTOM_CALL_TARGET(R0F32Add2); REGISTER_CUSTOM_CALL_TARGET(R2F32ReduceSum); REGISTER_CUSTOM_CALL_TARGET(Add1ToValues); +REGISTER_CUSTOM_CALL_TARGET(F32TupleSwap); namespace xla { namespace { @@ -69,7 +78,7 @@ class CustomCallTest : public HloTestBase { Shape r2f32_ = ShapeUtil::MakeShape(F32, {2, 2}); }; -XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR0F32Add2)) { +XLA_TEST_F(CustomCallTest, CustomCallR0F32Add2) { auto module = CreateNewUnverifiedModule(); auto builder = HloComputation::Builder(TestName()); @@ -84,7 +93,7 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR0F32Add2)) { LiteralTestUtil::ExpectR0Near(44.0f, result, error_spec_); } -XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR2F32Reduce)) { +XLA_TEST_F(CustomCallTest, CustomCallR2F32Reduce) { auto module = CreateNewUnverifiedModule(); auto builder = 
HloComputation::Builder(TestName()); @@ -105,7 +114,7 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR2F32Reduce)) { LiteralTestUtil::ExpectR0Near(10.0f, result, error_spec_); } -XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(UsedInOtherComputations)) { +XLA_TEST_F(CustomCallTest, UsedInOtherComputations) { auto module = CreateNewUnverifiedModule(); auto b = HloComputation::Builder(TestName()); @@ -129,7 +138,7 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(UsedInOtherComputations)) { Array3D{{{2, 3}, {4, 5}}, {{3, 4}, {5, 6}}}, result); } -XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(InputAndOutputLayoutDiffer)) { +XLA_TEST_F(CustomCallTest, InputAndOutputLayoutDiffer) { auto module = CreateNewUnverifiedModule(); auto b = HloComputation::Builder(TestName()); @@ -151,7 +160,7 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(InputAndOutputLayoutDiffer)) { LiteralTestUtil::ExpectR2Equal({{2.f, 4.f}, {3.f, 5.f}}, result); } -XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(LayoutConstrained)) { +XLA_TEST_F(CustomCallTest, LayoutConstrained) { // The argument and result of the computation are set to different layouts, // but the custom call is layout constrained to a fixed operand and result // layout, so the correct result should be produced. 
@@ -176,6 +185,26 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(LayoutConstrained)) { LiteralTestUtil::ExpectR2Equal({{2.f, 3.f}, {4.f, 5.f}}, result); } +XLA_TEST_F(CustomCallTest, TupleOutput) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT %custom-call = (f32[], f32[]) custom-call(f32[] %p0, f32[] %p1), custom_call_target="F32TupleSwap", operand_layout_constraints={f32[], f32[]} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + + Literal arg0 = LiteralUtil::CreateR0(7.f); + Literal arg1 = LiteralUtil::CreateR0(42.f); + + Literal expected = LiteralUtil::MakeTuple({&arg1, &arg0}); + Literal result = ExecuteAndTransfer(std::move(module), {&arg0, &arg1}); + EXPECT_EQ(result, expected); +} + class CustomCallClientAPITest : public ClientLibraryTestBase {}; // When using the client API, CustomCall targets can't begin with '$' -- these diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index 25091b8d5d5498edf3ce86efe225cd0e2fd8ff6b..f740f4815810727890583405b2244fceaec0bd3f 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -19,18 +19,19 @@ limitations under the License. 
#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" namespace xla { namespace { @@ -919,8 +920,9 @@ XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstRHSClassicMM) { XlaBuilder builder(TestName()); auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array); auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array); - auto start_constant = ConstantR1(&builder, {1, 0}); - auto dynamic_slice = DynamicSlice(lhs_constant, start_constant, {1, 6}); + auto one = ConstantR0(&builder, 1); + auto zero = ConstantR0(&builder, 0); + auto dynamic_slice = DynamicSlice(lhs_constant, {one, zero}, {1, 6}); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); @@ -946,8 +948,9 @@ XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSClassicMM) { XlaBuilder builder(TestName()); auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array); auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array); - auto start_constant = ConstantR1(&builder, {0, 1}); - auto dynamic_slice = DynamicSlice(rhs_constant, start_constant, {6, 1}); + auto zero = ConstantR0(&builder, 0); + auto one = 
ConstantR0(&builder, 1); + auto dynamic_slice = DynamicSlice(rhs_constant, {zero, one}, {6, 1}); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); @@ -975,8 +978,9 @@ XLA_TEST_F(DotOperationTest, XlaBuilder builder(TestName()); auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array); auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array); - auto start_constant = ConstantR1(&builder, {0, 1}); - auto dynamic_slice = DynamicSlice(lhs_constant, start_constant, {6, 1}); + auto zero = ConstantR0(&builder, 0); + auto one = ConstantR0(&builder, 1); + auto dynamic_slice = DynamicSlice(lhs_constant, {zero, one}, {6, 1}); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(0); @@ -1002,8 +1006,9 @@ XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSReverseMM) { XlaBuilder builder(TestName()); auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array); auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array); - auto start_constant = ConstantR1(&builder, {1, 0}); - auto dynamic_slice = DynamicSlice(rhs_constant, start_constant, {1, 6}); + auto zero = ConstantR0(&builder, 0); + auto one = ConstantR0(&builder, 1); + auto dynamic_slice = DynamicSlice(rhs_constant, {one, zero}, {1, 6}); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(0); @@ -1034,8 +1039,9 @@ XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstRHSRows) { XlaBuilder builder(TestName()); auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array); auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array); - auto start_constant = ConstantR1(&builder, {0, 1}); - auto dynamic_slice = DynamicSlice(lhs_constant, start_constant, {6, 1}); + auto zero = ConstantR0(&builder, 0); + auto one = ConstantR0(&builder, 1); + auto dynamic_slice = DynamicSlice(lhs_constant, {zero, one}, {6, 1}); DotDimensionNumbers dot_dnums; 
dot_dnums.add_lhs_contracting_dimensions(0); @@ -1066,8 +1072,9 @@ XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSRows) { XlaBuilder builder(TestName()); auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array); auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array); - auto start_constant = ConstantR1(&builder, {0, 1}); - auto dynamic_slice = DynamicSlice(rhs_constant, start_constant, {6, 1}); + auto zero = ConstantR0(&builder, 0); + auto one = ConstantR0(&builder, 1); + auto dynamic_slice = DynamicSlice(rhs_constant, {zero, one}, {6, 1}); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(0); @@ -1090,8 +1097,9 @@ XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstRHSCols) { XlaBuilder builder(TestName()); auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array); auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array); - auto start_constant = ConstantR1(&builder, {1, 0}); - auto dynamic_slice = DynamicSlice(lhs_constant, start_constant, {1, 6}); + auto zero = ConstantR0(&builder, 0); + auto one = ConstantR0(&builder, 1); + auto dynamic_slice = DynamicSlice(lhs_constant, {one, zero}, {1, 6}); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); @@ -1114,8 +1122,9 @@ XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSCols) { XlaBuilder builder(TestName()); auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array); auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array); - auto start_constant = ConstantR1(&builder, {1, 0}); - auto dynamic_slice = DynamicSlice(rhs_constant, start_constant, {1, 6}); + auto zero = ConstantR0(&builder, 0); + auto one = ConstantR0(&builder, 1); + auto dynamic_slice = DynamicSlice(rhs_constant, {one, zero}, {1, 6}); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); @@ -1148,5 +1157,105 @@ XLA_TEST_F(DotOperationTest, 
DotRank2AndRank2NonDefaultContractionDims) { ComputeAndCompareR2(&builder, expected, {}, error_spec_); } + +using EinsumParamType = + std::tuple, std::vector, string>; +class EinsumTest : public DotOperationTest, + public ::testing::WithParamInterface {}; +XLA_TEST_P(EinsumTest, SimpleEinsumTest) { + XlaBuilder builder(TestName()); + auto x = AddParam( + MakeFakeLiteral(ShapeUtil::MakeShape(F32, std::get<0>(GetParam()))) + .ValueOrDie(), + &builder); + auto y = AddParam( + MakeFakeLiteral(ShapeUtil::MakeShape(F32, std::get<1>(GetParam()))) + .ValueOrDie(), + &builder); + Einsum(x, y, std::get<2>(GetParam())); + ComputeAndCompare(&builder, {}, ErrorSpec{1e-3, 1e-3}); +} + +std::vector GetEinsumTestCases() { + using v = std::vector; + using p = EinsumParamType; + std::vector

test_cases = { + p{v{5, 6}, v{6, 7}, "mk,kn->mn"}, + p{v{5, 6}, v{6, 7}, "mk,kn->nm"}, + p{v{5, 6, 11}, v{6, 11, 7}, "mkB,kBn->nmB"}, + p{v{31, 55, 11}, v{55, 11, 29}, "mkB,kBn->nmB"}, + p{v{31, 55, 11}, v{55, 11, 29}, "mkB,kBn->Bnm"}, + p{v{8, 55, 11, 3}, v{55, 11, 3, 29}, "mkBC,kBCn->BCnm"}, + p{v{5, 6}, v{6, 7}, "ab,cd->dcba"}, + p{v{6}, v{6, 7}, "b,bc->c"}, + }; + return test_cases; +} + +INSTANTIATE_TEST_CASE_P(Einsum, EinsumTest, + ::testing::ValuesIn(GetEinsumTestCases())); + +class DotOperationTextTest : public HloTestBase {}; + +XLA_TEST_F(DotOperationTextTest, DotReorderedDotDims) { + absl::string_view hlo_string = + R"( +HloModule ComplexDotMultipleNonContracting + +ENTRY %test { + %lhs = f32[7,17,10,13]{3,2,1,0} parameter(0) + %rhs = f32[7,9,10,13,6]{4,3,2,1,0} parameter(1) + ROOT %dot = f32[10,7,17,9,6]{4,3,2,1,0} dot(%lhs, %rhs), lhs_batch_dims={2,0}, rhs_batch_dims={2,0}, lhs_contracting_dims={3}, rhs_contracting_dims={3} +} +)"; + + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-3, 1e-3})); +} + +XLA_TEST_F(DotOperationTextTest, DotReorderedDotDimsAndMultipleContracting) { + absl::string_view hlo_string = + R"( +HloModule ComplexDotMultipleNonContracting + +ENTRY %test { + %lhs = f32[7,5,17,10,13]{4,3,2,1,0} parameter(0) + %rhs = f32[7,9,10,13,6,5]{5,4,3,2,1,0} parameter(1) + ROOT %dot = f32[10,7,17,9,6]{4,3,2,1,0} dot(%lhs, %rhs), lhs_batch_dims={3,0}, rhs_batch_dims={2,0}, lhs_contracting_dims={1,4}, rhs_contracting_dims={5,3} +} +)"; + + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-3, 1e-3})); +} + +XLA_TEST_F(DotOperationTextTest, DotWithNoDnums) { + absl::string_view hlo_string = + R"( +HloModule DotWithNoDnums + +ENTRY %test { + %lhs = f32[2,3]{1,0} parameter(0) + %rhs = f32[4,5]{1,0} parameter(1) + ROOT %dot = f32[2,3,4,5]{3,2,1,0} dot(%lhs, %rhs) +} +)"; + + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-3, 1e-3})); +} + +XLA_TEST_F(DotOperationTextTest, Einsum) { + absl::string_view hlo_string = + R"( +HloModule Einsum + 
+ENTRY %test { + %lhs = f32[8,64,96]{2,1,0} parameter(0) + %rhs = f32[96,32,4]{2,1,0} parameter(1) + ROOT %dot = f32[8,64,32,4]{3,2,1,0} dot(%lhs, %rhs), lhs_contracting_dims={2}, rhs_contracting_dims={0} +} +)"; + + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{4e-3, 4e-3})); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc index 7501c6d957e7afe99b8c530e5f0d575f818367da..82e2db36143b2552472fedae701f32389a9be108 100644 --- a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc +++ b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc @@ -135,11 +135,11 @@ class DynamicSliceTest : public ClientLibraryTestBase { XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. XlaOp starts; - std::unique_ptr start_data = CreateR1Parameter( - slice_starts, 0, "slice_starts", &builder, &starts); + std::unique_ptr start_data = CreateR0Parameter( + slice_starts[0], 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. auto input = ConstantLiteral(&builder, input_values); - DynamicSlice(input, starts, slice_sizes); + DynamicSlice(input, absl::Span({starts}), slice_sizes); // Run computation and compare against expected values. ComputeAndCompareLiteral(&builder, expected_values, {start_data.get()}); } @@ -160,14 +160,23 @@ class DynamicSliceTest : public ClientLibraryTestBase { XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. - XlaOp starts; - std::unique_ptr start_data = CreateR1Parameter( - slice_starts, 0, "slice_starts", &builder, &starts); + std::vector starts(2); + std::vector> start_data(2); + for (int i = 0; i < 2; ++i) { + start_data[i] = CreateR0Parameter( + slice_starts[i], i, "slice_starts", &builder, &starts[i]); + } + // Build dynamic slice computation. 
auto input = ConstantLiteral(&builder, input_values); DynamicSlice(input, starts, slice_sizes); // Run computation and compare against expected values. - ComputeAndCompareLiteral(&builder, expected_values, {start_data.get()}); + std::vector argument_ptrs; + absl::c_transform(start_data, std::back_inserter(argument_ptrs), + [](const std::unique_ptr& argument) { + return argument.get(); + }); + ComputeAndCompareLiteral(&builder, expected_values, argument_ptrs); } template @@ -186,14 +195,22 @@ class DynamicSliceTest : public ClientLibraryTestBase { XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. - XlaOp starts; - std::unique_ptr start_data = CreateR1Parameter( - slice_starts, 0, "slice_starts", &builder, &starts); + std::vector starts(3); + std::vector> start_data(3); + for (int i = 0; i < 3; ++i) { + start_data[i] = CreateR0Parameter( + slice_starts[i], i, "slice_starts", &builder, &starts[i]); + } // Build dynamic slice computation. auto input = ConstantLiteral(&builder, input_values); DynamicSlice(input, starts, slice_sizes); // Run computation and compare against expected values. - ComputeAndCompareLiteral(&builder, expected_values, {start_data.get()}); + std::vector argument_ptrs; + absl::c_transform(start_data, std::back_inserter(argument_ptrs), + [](const std::unique_ptr& argument) { + return argument.get(); + }); + ComputeAndCompareLiteral(&builder, expected_values, argument_ptrs); } }; @@ -372,16 +389,12 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { .ValueOrDie()); XlaBuilder builder(TestName()); - // Initialize and transfer dynamic slice start indices parameter. - XlaOp starts; - std::unique_ptr start_data = CreateR1Parameter( - slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. 
auto input = ConstantLiteral(&builder, input_value); auto update = ConstantLiteral(&builder, update_value); - DynamicUpdateSlice(input, update, starts); + DynamicUpdateSlice(input, update, absl::Span({})); // Run computation and compare against expected values. - ComputeAndCompareLiteral(&builder, expected_value, {start_data.get()}); + ComputeAndCompareLiteral(&builder, expected_value, {}); } template @@ -405,12 +418,12 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. XlaOp starts; - std::unique_ptr start_data = CreateR1Parameter( - slice_starts, 0, "slice_starts", &builder, &starts); + std::unique_ptr start_data = CreateR0Parameter( + slice_starts[0], 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. auto input = ConstantLiteral(&builder, input_values); auto update = ConstantLiteral(&builder, update_values); - DynamicUpdateSlice(input, update, starts); + DynamicUpdateSlice(input, update, absl::Span({starts})); // Run computation and compare against expected values. ComputeAndCompareLiteral(&builder, expected_values, {start_data.get()}); } @@ -435,15 +448,23 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. - XlaOp starts; - std::unique_ptr start_data = CreateR1Parameter( - slice_starts, 0, "slice_starts", &builder, &starts); + std::vector starts(2); + std::vector> start_data(2); + for (int i = 0; i < 2; ++i) { + start_data[i] = CreateR0Parameter( + slice_starts[i], i, "slice_starts", &builder, &starts[i]); + } // Build dynamic slice computation. auto input = ConstantLiteral(&builder, input_values); auto update = ConstantLiteral(&builder, update_values); DynamicUpdateSlice(input, update, starts); // Run computation and compare against expected values. 
- ComputeAndCompareLiteral(&builder, expected_values, {start_data.get()}); + std::vector argument_ptrs; + absl::c_transform(start_data, std::back_inserter(argument_ptrs), + [](const std::unique_ptr& argument) { + return argument.get(); + }); + ComputeAndCompareLiteral(&builder, expected_values, argument_ptrs); } template @@ -466,15 +487,24 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. - XlaOp starts; - std::unique_ptr start_data = CreateR1Parameter( - slice_starts, 0, "slice_starts", &builder, &starts); + std::vector starts(3); + std::vector> start_data(3); + for (int i = 0; i < 3; ++i) { + start_data[i] = CreateR0Parameter( + slice_starts[i], i, "slice_starts", &builder, &starts[i]); + } + // Build dynamic slice computation. auto input = ConstantLiteral(&builder, input_values); auto update = ConstantLiteral(&builder, update_values); DynamicUpdateSlice(input, update, starts); // Run computation and compare against expected values. - ComputeAndCompareLiteral(&builder, expected_values, {start_data.get()}); + std::vector argument_ptrs; + absl::c_transform(start_data, std::back_inserter(argument_ptrs), + [](const std::unique_ptr& argument) { + return argument.get(); + }); + ComputeAndCompareLiteral(&builder, expected_values, argument_ptrs); } template @@ -518,8 +548,9 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { XlaOp update; std::unique_ptr update_data = CreateR3Parameter( update_values, 1, "update_values", &builder, &update); - auto starts = ConstantR1(&builder, {index, 0, 0}); - DynamicUpdateSlice(input, update, starts); + auto constant_index = ConstantR0(&builder, index); + auto zero = ConstantR0(&builder, 0); + DynamicUpdateSlice(input, update, {constant_index, zero, zero}); // Run computation and compare against expected values. 
ComputeAndCompareR3(&builder, expected_values, @@ -720,46 +751,55 @@ void BM_DynamicSlice(int num_iters) { {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}}); auto input = ConstantLiteral(&builder, input_literal); + auto stream = + client->mutable_backend()->BorrowStream(device_ordinal).ValueOrDie(); + // Create dynamic slice start indices as a parameter: shape [4] - auto start_indices_shape = ShapeUtil::MakeShape(S32, {4}); - auto start_indices = - Parameter(&builder, 0, start_indices_shape, "start_indices"); + auto start_indices_shape = ShapeUtil::MakeShape(S32, {}); + std::vector start_indices(4); + std::vector shaped_buffers; + std::vector host_shapes(4); + for (int i = 0; i < 4; ++i) { + start_indices[i] = + Parameter(&builder, i, start_indices_shape, "start_indices"); + auto start_index_literal = LiteralUtil::CreateR0(i + 1); + // Initialize and transfer parameter buffer. + shaped_buffers.emplace_back( + client->backend() + .transfer_manager() + ->AllocateScopedShapedBuffer(start_indices_shape, &allocator, + /*device_ordinal=*/0) + .ConsumeValueOrDie()); + host_shapes[i] = &shaped_buffers[i].on_host_shape(); + ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice( + stream.get(), start_index_literal, shaped_buffers[i])); + } + // Add DynamicSlice op to the computatation. DynamicSlice(input, start_indices, {1, 1, 1, 1}); auto computation = builder.Build().ConsumeValueOrDie(); - // Initialize and transfer parameter buffer. 
- auto buffer = client->backend() - .transfer_manager() - ->AllocateScopedShapedBuffer( - start_indices_shape, &allocator, /*device_ordinal=*/0) - .ConsumeValueOrDie(); - - auto start_indices_literal = LiteralUtil::CreateR1({0, 1, 2, 3}); - auto stream = - client->mutable_backend()->BorrowStream(device_ordinal).ValueOrDie(); - ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice( - stream.get(), start_indices_literal, buffer)); - std::unique_ptr executable = - client - ->Compile(computation, {&buffer.on_host_shape()}, - ExecutableBuildOptions()) + client->Compile(computation, host_shapes, ExecutableBuildOptions()) .ConsumeValueOrDie(); // Run some warm-up executions. ExecutableRunOptions options; options.set_allocator(&allocator); const int kWarmups = 2; + std::vector shaped_buffer_ptrs; + absl::c_transform(shaped_buffers, std::back_inserter(shaped_buffer_ptrs), + [](const ScopedShapedBuffer& buffer) { return &buffer; }); + for (int i = 0; i < kWarmups; ++i) { - auto result = executable->Run({&buffer}, options); + auto result = executable->Run(shaped_buffer_ptrs, options); ASSERT_TRUE(result.ok()); } // Run benchmark. tensorflow::testing::StartTiming(); for (int i = 0; i < num_iters; ++i) { - auto result = executable->Run({&buffer}, options); + auto result = executable->Run(shaped_buffer_ptrs, options); ASSERT_TRUE(result.ok()); } } diff --git a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc index c84973e17b234c24c84f02a369ce0185f5772cca..b961e6102692cb3b90976d621c62cb4cf18a9b6b 100644 --- a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc +++ b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc @@ -13,7 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include #include "absl/base/casts.h" +#include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -21,66 +23,166 @@ limitations under the License. namespace xla { namespace { + class ExhaustiveF32ElementwiseOpTest : public ClientLibraryTestBase, public ::testing::WithParamInterface> { protected: - ErrorSpec error_spec_{0.0001, 0.0001, /*relaxed_nans=*/true}; + ErrorSpec error_spec_{0.0001, 0.0001}; + + bool IsClose(float expected, float actual) { + float abs_err = std::abs(expected - actual); + float rel_err = abs_err / std::abs(expected); + return abs_err < error_spec_.abs || rel_err < error_spec_.rel || + (std::isnan(expected) && std::isnan(actual)) || + (std::isinf(expected) && std::isinf(actual) && + (expected > 0) == (actual > 0)); + } template void ExhaustivelyTestF32Op(EnqueueOpTy enqueue_op, float (*evaluate_op)(float), std::pair known_incorrect_range) { + SetFastMathDisabled(true); + int64 begin, end; std::tie(begin, end) = GetParam(); int64 input_size = end - begin; + + if (begin >= known_incorrect_range.first && + end <= known_incorrect_range.second) { + LOG(INFO) << absl::StreamFormat( + "Skipping this shard, as the range under test, [%d, %d), falls " + "entirely within the known-incorrect range [%d, %d).", + begin, end, known_incorrect_range.first, + known_incorrect_range.second); + return; + } + LOG(INFO) << "Checking range [" << begin << ", " << end << ")"; XlaBuilder builder(TestName()); - Literal input_literal = - LiteralUtil::CreateFromDimensions(F32, {input_size}); - for (int64 i = begin; i < end; i++) { + auto ith_input_elem = [&](int64 i) -> float { + i += begin; + // If the operation is known to be buggy on a specific input clamp that + // input to 0 under the assumption that the op 
is at least correct on 0. if (i >= known_incorrect_range.first && i < known_incorrect_range.second) { - // If the operation is known to be buggy on a specific input clamp that - // input to 0 under the assumption that the op is at least correct on 0. - input_literal.Set({i - begin}, 0.0f); - } else { - input_literal.Set({i - begin}, absl::bit_cast(i)); + return 0; } - } - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, - client_->TransferToServer(input_literal)); + return absl::bit_cast(i); + }; + Literal input_literal = + LiteralUtil::CreateFromDimensions(F32, {input_size}); + absl::Span input_arr = input_literal.data(); + for (int64 i = 0; i < input_size; i++) { + input_arr[i] = ith_input_elem(i); + } auto input = Parameter(&builder, 0, input_literal.shape(), "input"); enqueue_op(&builder, input); + TF_ASSERT_OK_AND_ASSIGN(XlaComputation comp, builder.Build()); + + // Build and run the computation using the LocalClient API, rather than the + // plain Client API, which is used by ClientLibraryTestBase. This is + // because the plain Client API results does more memcpys to/from Literals, + // and that's slow given that we're touching a lot of data here. + // + // Copy debug options from ClientLibraryTestBase. In particular, we're + // interested in disabling constant folding. 
+ ExecutableBuildOptions build_opts; + *build_opts.mutable_debug_options() = *mutable_debug_options(); + TF_ASSERT_OK_AND_ASSIGN( + auto executable, + client_->Compile(comp, {&input_literal.shape()}, build_opts)); + + TF_ASSERT_OK_AND_ASSIGN( + ScopedShapedBuffer input_data, + client_->LiteralToShapedBuffer(input_literal, /*device_ordinal=*/0)); + + ExecutableRunOptions run_opts; + run_opts.set_allocator(client_->backend().memory_allocator()); + run_opts.set_intra_op_thread_pool( + client_->backend().eigen_intra_op_thread_pool_device()); + TF_ASSERT_OK_AND_ASSIGN(ScopedShapedBuffer result, + executable->Run({&input_data}, run_opts)); + + TF_ASSERT_OK_AND_ASSIGN(Literal result_literal, + client_->ShapedBufferToLiteral(result)); + + // We essentially reimplement LiteralTestUtil::Near here because + // a) this streamlined implementation is much faster, and + // b) we can print out better error messages (namely, we can print out + // which floating-point value input failed, while LiteralTestUtil::Near + // can only print out the input index that failed). + // c) we need special handling of certain inputs. For example, we say that + // a denormal input has multiple correct outputs (namely, f(x) and f(0)) + // and just needs to be close to one of them. + absl::Span result_arr = result_literal.data(); + ASSERT_EQ(result_arr.size(), input_arr.size()); + int64 mismatches = 0; + // Hoisting this out of the loop is a nice speedup on shards that have many + // denormals. 
+ const float expected_at_zero = evaluate_op(0); + for (int64 i = 0; i < input_arr.size(); ++i) { + float input = ith_input_elem(i); + float actual = result_arr[i]; + float expected = evaluate_op(input); + if (IsClose(expected, actual)) { + continue; + } - std::vector expected_result; - expected_result.reserve(input_size); - for (int64 i = 0; i < input_size; i++) { - expected_result.push_back(evaluate_op(input_literal.Get({i}))); - } + constexpr int64 kMaxMismatchesPrinted = 1000; + if (std::fpclassify(input) == FP_SUBNORMAL) { + // For denormal inputs, we accept answers that are close to either + // - evaluate_op(input) OR + // - evaluate_op(0). + if (IsClose(expected_at_zero, actual)) { + continue; + } + ++mismatches; + if (mismatches < kMaxMismatchesPrinted || VLOG_IS_ON(2)) { + // Use %0.9g because that's guaranteed to print an f32 to full + // precision. + LOG(ERROR) << absl::StreamFormat( + "Mismatch on denormal value %0.9g (0x%08x). Expected either " + "%0.9g (0x%08x) (evaluated at true value) or %0.9g (0x%08x) " + "(evaluated at zero), but got %0.9g (0x%08x).", + input, absl::bit_cast(input), // + expected, absl::bit_cast(expected), // + expected_at_zero, absl::bit_cast(expected_at_zero), + actual, absl::bit_cast(actual)); + } + } else { + mismatches++; + if (mismatches < kMaxMismatchesPrinted || VLOG_IS_ON(2)) { + LOG(ERROR) << absl::StreamFormat( + "Mismatch on %0.9g (0x%08x). 
Expected %0.9g (0x%08x), but got " + "%0.9g (0x%08x).", + input, absl::bit_cast(input), // + expected, absl::bit_cast(expected), // + actual, absl::bit_cast(actual)); + } + } - ComputeAndCompareR1(&builder, expected_result, {input_data.get()}, - error_spec_); + if (mismatches == kMaxMismatchesPrinted && !VLOG_IS_ON(2)) { + LOG(ERROR) << "Not printing any more mismatches; pass " + "--vmodule=exhaustive_f32_elementwise_op_test=2 to see " + "all of them."; + } + } + EXPECT_EQ(mismatches, 0); } }; XLA_TEST_P(ExhaustiveF32ElementwiseOpTest, LogF32) { -#ifdef XLA_TEST_BACKEND_CPU - // TODO(b/73141998): The vectorized Log implementation gives results outside - // our error spec in this range (these numbers are bitwise representations of - // floats expressed as a zero extended int64). - std::pair known_incorrect_range = {1, 8388608}; -#else - std::pair known_incorrect_range = {0, 0}; +#if !defined(XLA_TEST_BACKEND_CPU) && !defined(XLA_TEST_BACKEND_GPU) + error_spec_ = ErrorSpec{0.001, 0.001}; #endif - ExhaustivelyTestF32Op( [](XlaBuilder* builder, const XlaOp& input) { Log(input); }, std::log, - known_incorrect_range); + /*known_incorrect_range=*/{0, 0}); } XLA_TEST_P(ExhaustiveF32ElementwiseOpTest, ExpF32) { @@ -105,6 +207,18 @@ XLA_TEST_P(ExhaustiveF32ElementwiseOpTest, TanhF32) { /*known_incorrect_range=*/{0, 0}); } +XLA_TEST_P(ExhaustiveF32ElementwiseOpTest, ErfF32) { + ExhaustivelyTestF32Op( + [](XlaBuilder* builder, const XlaOp& input) { Erf(input); }, std::erf, + /*known_incorrect_range=*/{0, 0}); +} + +XLA_TEST_P(ExhaustiveF32ElementwiseOpTest, ErfcF32) { + ExhaustivelyTestF32Op( + [](XlaBuilder* builder, const XlaOp& input) { Erfc(input); }, std::erfc, + /*known_incorrect_range=*/{0, 0}); +} + std::vector> CreateExhaustiveParameters() { // We break up the 2^32-element space into small'ish chunks to keep peak // memory usage low. 
diff --git a/tensorflow/compiler/xla/tests/filecheck.cc b/tensorflow/compiler/xla/tests/filecheck.cc index dcb469087e0064d17ce3b04fdeaf0b6136069a55..1b0bebe2d03a9a153cd0c80329ed0c49c91333a3 100644 --- a/tensorflow/compiler/xla/tests/filecheck.cc +++ b/tensorflow/compiler/xla/tests/filecheck.cc @@ -48,7 +48,7 @@ StatusOr RunFileCheck(const string& input, const string& pattern) { tensorflow::SubProcess file_check_process; file_check_process.SetProgram(file_check_path, - {file_check_path, pattern_path}); + {file_check_path, "-v", pattern_path}); file_check_process.SetChannelAction(tensorflow::CHAN_STDIN, tensorflow::ACTION_PIPE); file_check_process.SetChannelAction(tensorflow::CHAN_STDERR, @@ -71,9 +71,7 @@ StatusOr RunFileCheck(const string& input, const string& pattern) { LOG(WARNING) << "NOTE: FileCheck binary does not exist!"; } - LOG(WARNING) << "FileCheck error: " << standard_error; - LOG(WARNING) << "FileCheck input was:"; - XLA_LOG_LINES(tensorflow::WARNING, input); + LOG(WARNING) << "FileCheck error:\n" << standard_error; LOG(WARNING) << "FileCheck pattern was:"; XLA_LOG_LINES(tensorflow::WARNING, pattern); } else if (!standard_error.empty()) { diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc index d1fddf9d6b494a822610e41307fa103dc90bdef3..2178c9b3f3d39ac034c59585c6836d2bc59162c1 100644 --- a/tensorflow/compiler/xla/tests/fusion_test.cc +++ b/tensorflow/compiler/xla/tests/fusion_test.cc @@ -523,10 +523,10 @@ XLA_TEST_F(FusionTest, DynamicSliceNegate) { auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({1, 2, 3, 4}))); auto const1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({1}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); auto dynamic_slice2 = builder.AddInstruction(HloInstruction::CreateDynamicSlice( - ShapeUtil::MakeShape(S32, {2}), const0, const1, {2})); + ShapeUtil::MakeShape(S32, {2}), const0, {const1}, 
{2})); auto negate3 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(S32, {2}), HloOpcode::kNegate, dynamic_slice2)); hlo_module->AddEntryComputation(builder.Build()) diff --git a/tensorflow/compiler/xla/tests/gather_operation_test.cc b/tensorflow/compiler/xla/tests/gather_operation_test.cc index daa89398a697af9149797d621c3bdca80a00aedd..d65b67a535d43553a3a94f76482ad4618f9b8aab 100644 --- a/tensorflow/compiler/xla/tests/gather_operation_test.cc +++ b/tensorflow/compiler/xla/tests/gather_operation_test.cc @@ -600,7 +600,9 @@ ENTRY main { class GatherClientLibraryTest : public ClientLibraryTestBase {}; -XLA_TEST_F(GatherClientLibraryTest, DISABLED_ON_GPU(Basic)) { +// Disabled on interpreter since ExectuteAsyncOnStream is not supported. +XLA_TEST_F(GatherClientLibraryTest, + DISABLED_ON_INTERPRETER(DISABLED_ON_GPU(Basic))) { // We create this HLO, but using the XlaBuilder API. // // ENTRY main { diff --git a/tensorflow/compiler/xla/tests/grouped_convolution_test.cc b/tensorflow/compiler/xla/tests/grouped_convolution_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..8f7049910e70c4e591636a47c1b6ba72cf2c234f --- /dev/null +++ b/tensorflow/compiler/xla/tests/grouped_convolution_test.cc @@ -0,0 +1,245 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/execution_options_util.h" +#include "tensorflow/compiler/xla/service/bfloat16_normalization.h" +#include "tensorflow/compiler/xla/service/despecializer.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" + +namespace xla { +namespace { + +string GetFloatDataType(bool use_bfloat16) { + return use_bfloat16 ? "bf16" : "f32"; +} + +struct GroupedConvolution2DSpec { + int64 input_feature, output_feature, window, stride, pad, lhs_dilate; + int64 group_size, group_count; + std::vector activation_dims; + std::vector activation_layout; + std::vector kernel_dims; + std::vector kernel_layout; + std::vector output_dims; + std::vector output_layout; +}; + +class GroupedConvolution2DTest + : public HloTestBase, + public ::testing::WithParamInterface< + ::testing::tuple> {}; + +static std::vector GetConv2DTestCases() { + std::vector config_set; + // Add to this set if you want a new test configuration. + // Rule : the penultimate number must be divisible by the last number. 
+ std::vector> config_options = {{8, 2, 2, 1, 1024, 128}, + {512, 3, 3, 144, 1024, 16}, + {256, 3, 3, 129, 512, 64}, + {64, 1, 2, 127, 32, 8}, + {256, 3, 3, 256, 1024, 4}}; + + for (auto option : config_options) { + int64 output_feature = option[0]; + int64 activation_size = option[1]; + int64 kernel_size = option[2]; + int64 batch = option[3]; + int64 input_feature = option[4]; + int64 group_size = option[5]; + + std::vector kernel_layout = {3, 2, 1, 0}; + GroupedConvolution2DSpec config; + config.group_size = group_size; + config.group_count = input_feature / group_size; + config.output_feature = output_feature; + config.window = kernel_size; + + config.activation_dims = {batch, activation_size, activation_size, + input_feature}; + config.activation_layout = {3, 0, 2, 1}; + + config.kernel_dims = {kernel_size, kernel_size, group_size, output_feature}; + config.kernel_layout = {3, 2, 1, 0}; + + if (activation_size == 1 && kernel_size == 2) { + // Test for outer dim. + config.output_dims = {batch, activation_size + kernel_size - 1, + activation_size + kernel_size, output_feature}; + } else if (output_feature == 256) { + // Restrict dilation-based tests only to one feature configuration. + config.stride = activation_size - 1; + config.pad = 0; + config.lhs_dilate = output_feature / 32; + config.output_dims = {batch, output_feature / 32, + activation_size - kernel_size + 1, output_feature}; + } else { + config.stride = config.pad = config.lhs_dilate = -1; + config.output_dims = {batch, activation_size - kernel_size + 1, + activation_size - kernel_size + 1, output_feature}; + } + + // Try this layout for all kernel shapes. + config.output_layout = {3, 0, 2, 1}; + config_set.push_back(config); + + // Try other layouts only for certain kernel shapes. 
+ if (kernel_size % 2 == 0) { + config.activation_layout = {0, 3, 2, 1}; + config_set.push_back(config); + + config.output_layout = {0, 3, 2, 1}; + config_set.push_back(config); + + config.activation_layout = {3, 0, 2, 1}; + config_set.push_back(config); + } + } + + return config_set; +} + +string GroupedConvolution2DTestDataToString( + const ::testing::TestParamInfo< + ::testing::tuple>& data) { + const auto& spec = ::testing::get<0>(data.param); + const string data_type = GetFloatDataType(::testing::get<1>(data.param)); + string str = absl::StrCat( + "activation_dims_", absl::StrJoin(spec.activation_dims, "x"), + "_activation_layout_", absl::StrJoin(spec.activation_layout, "_"), + "_kernel_dims_", absl::StrJoin(spec.kernel_dims, "x"), "_kernel_layout_", + absl::StrJoin(spec.kernel_layout, "_"), "_output_dims_", + absl::StrJoin(spec.output_dims, "x"), "_output_layout_", + absl::StrJoin(spec.output_layout, "_"), data_type); + // -1 indicates non-existence. + if (spec.stride != -1) { + absl::StrAppend(&str, "_lhs_dilation_", spec.lhs_dilate, "x1"); + } + + // Test names are not allowed to contain the '-' character. + absl::c_replace(str, '-', 'n'); + return str; +} + +string BuildHloTextGroupedConvolution2D(const GroupedConvolution2DSpec& spec, + bool use_bfloat16) { + const string data_type = GetFloatDataType(use_bfloat16); + if (spec.activation_dims[1] == 1 && spec.kernel_dims[1] == 2) { + // Check for outer dim. 
+ return absl::StrFormat( + R"( + HloModule TensorFlowDepthwiseConv + + ENTRY main { + activation = %s[%s]{%s} parameter(0) + kernel = %s[%s]{%s} parameter(1) + ROOT conv = %s[%s]{%s} convolution(%s[%s]{%s} activation, %s[%s]{%s} kernel), + window={size=%dx%d pad=1_1x%d_%d rhs_dilate=1x%d}, dim_labels=b01f_01io->b01f, + feature_group_count=%d + } + )", + data_type, absl::StrJoin(spec.activation_dims, ","), + absl::StrJoin(spec.activation_layout, ","), data_type, + absl::StrJoin(spec.kernel_dims, ","), + absl::StrJoin(spec.kernel_layout, ","), data_type, + absl::StrJoin(spec.output_dims, ","), + absl::StrJoin(spec.output_layout, ","), data_type, + absl::StrJoin(spec.activation_dims, ","), + absl::StrJoin(spec.activation_layout, ","), data_type, + absl::StrJoin(spec.kernel_dims, ","), + absl::StrJoin(spec.kernel_layout, ","), spec.window, spec.window, + spec.window, spec.window, spec.window, spec.group_count); + + } else if (spec.stride == -1) { + // Check for basic, non-dilated cases. + return absl::StrFormat( + R"( + HloModule TensorFlowDepthwiseConv + + ENTRY main { + activation = %s[%s]{%s} parameter(0) + kernel = %s[%s]{%s} parameter(1) + ROOT conv = %s[%s]{%s} convolution(%s[%s]{%s} activation, %s[%s]{%s} kernel), + window={size=%dx%d}, dim_labels=b01f_01io->b01f, + feature_group_count=%d + } + )", + data_type, absl::StrJoin(spec.activation_dims, ","), + absl::StrJoin(spec.activation_layout, ","), data_type, + absl::StrJoin(spec.kernel_dims, ","), + absl::StrJoin(spec.kernel_layout, ","), data_type, + absl::StrJoin(spec.output_dims, ","), + absl::StrJoin(spec.output_layout, ","), data_type, + absl::StrJoin(spec.activation_dims, ","), + absl::StrJoin(spec.activation_layout, ","), data_type, + absl::StrJoin(spec.kernel_dims, ","), + absl::StrJoin(spec.kernel_layout, ","), spec.window, spec.window, + spec.group_count); + } else { + // Check for base dilations. 
+ return absl::StrFormat( + R"( + HloModule TensorFlowDepthwiseConv + + ENTRY main { + activation = %s[%s]{%s} parameter(0) + kernel = %s[%s]{%s} parameter(1) + ROOT conv = %s[%s]{%s} convolution(%s[%s]{%s} activation, %s[%s]{%s} kernel), + window={size=%dx%d stride=%dx1 pad=%d_%dx0_0 lhs_dilate=%dx1}, + dim_labels=b01f_01io->b01f, feature_group_count=%d + } + )", + data_type, absl::StrJoin(spec.activation_dims, ","), + absl::StrJoin(spec.activation_layout, ","), data_type, + absl::StrJoin(spec.kernel_dims, ","), + absl::StrJoin(spec.kernel_layout, ","), data_type, + absl::StrJoin(spec.output_dims, ","), + absl::StrJoin(spec.output_layout, ","), data_type, + absl::StrJoin(spec.activation_dims, ","), + absl::StrJoin(spec.activation_layout, ","), data_type, + absl::StrJoin(spec.kernel_dims, ","), + absl::StrJoin(spec.kernel_layout, ","), spec.window, spec.window, + spec.stride, 0, 0, spec.lhs_dilate, spec.group_count); + } +} + +XLA_TEST_P(GroupedConvolution2DTest, DoIt) { + const GroupedConvolution2DSpec& spec = ::testing::get<0>(GetParam()); + bool use_bfloat16 = ::testing::get<1>(GetParam()); + const string hlo_text = BuildHloTextGroupedConvolution2D(spec, use_bfloat16); + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{0.01, 0.01}, + [](HloModule* module) -> Status { + BFloat16MixedPrecisionRemoval remover; + TF_RETURN_IF_ERROR(remover.Run(module).status()); + Despecializer despecializer; + return despecializer.Run(module).status(); + })); +} + +INSTANTIATE_TEST_CASE_P( + GroupedConvolution2DTestWithRandomIndices, GroupedConvolution2DTest, + ::testing::Combine(::testing::ValuesIn(GetConv2DTestCases()), + ::testing::Bool()), + GroupedConvolution2DTestDataToString); + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index 989a7c705a8254f99e5cc0e97dfde5942f146964..66f72ba8d20b8ef1f436da4425b2bb6518ee9a94 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ 
b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -139,7 +139,8 @@ std::unique_ptr HloTestBase::CreateNewVerifiedModule( const string& name) { return absl::make_unique( name, GetModuleConfigForTest(), verifier_layout_sensitive_, - allow_mixed_precision_in_hlo_verifier_); + allow_mixed_precision_in_hlo_verifier_, + backend().compiler()->ShapeSizeBytesFunction()); } StatusOr> @@ -147,7 +148,8 @@ HloTestBase::ParseAndReturnVerifiedModule(absl::string_view hlo_text, const HloModuleConfig& config) { auto module = absl::make_unique( TestName(), config, verifier_layout_sensitive_, - allow_mixed_precision_in_hlo_verifier_); + allow_mixed_precision_in_hlo_verifier_, + backend().compiler()->ShapeSizeBytesFunction()); TF_RETURN_IF_ERROR(ParseHloString(hlo_text, module.get())); TF_RETURN_IF_ERROR(module->Verify()); return std::move(module); @@ -181,6 +183,7 @@ DebugOptions HloTestBase::GetDebugOptionsForTest() { // TODO(b/38354253): Change tests to use Parameters instead of Constants. debug_options.add_xla_disable_hlo_passes("constant_folding"); debug_options.set_xla_gpu_max_kernel_unroll_factor(1); + debug_options.set_xla_hlo_evaluator_use_fast_path(true); return debug_options; } diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index 1d1e7f437296a7493ef7da07039fcf6d273f35bc..69a4f96288c7285010e9adbdc33f1b394f58d8d2 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -46,10 +46,12 @@ class VerifiedHloModule : public HloModule { public: VerifiedHloModule(const string& name, const HloModuleConfig& config, bool verifier_layout_sensitive, - bool allow_mixed_precision_in_hlo_verifier) + bool allow_mixed_precision_in_hlo_verifier, + std::function shape_size_function) : HloModule(name, config), - verifier_(verifier_layout_sensitive, - allow_mixed_precision_in_hlo_verifier) {} + verifier_( + verifier_layout_sensitive, allow_mixed_precision_in_hlo_verifier, + 
/*instruction_can_change_layout_func=*/{}, shape_size_function) {} ~VerifiedHloModule() override { VerifyOrAddFailure("in destructor"); } diff --git a/tensorflow/compiler/xla/tests/iota_test.cc b/tensorflow/compiler/xla/tests/iota_test.cc index 65205f53ddc582ae477d67705f161fef1e31b857..37b2c635eebe57590e1ba73c62f015ccf399b548 100644 --- a/tensorflow/compiler/xla/tests/iota_test.cc +++ b/tensorflow/compiler/xla/tests/iota_test.cc @@ -80,7 +80,7 @@ TEST_P(IotaR2Test, DoIt) { } INSTANTIATE_TEST_CASE_P(IotaR2TestInstantiation, IotaR2Test, - ::testing::Combine(::testing::Values(F32, S32), + ::testing::Combine(::testing::Values(F32, S32, BF16), ::testing::Range(/*start=*/10, /*end=*/1001, /*step=*/10), diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc index 554eb24d44168caa7d7252015e3d99f2d567df9b..a2fd6070731943f15c773265f428b16f520d02ee 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util.cc @@ -86,7 +86,7 @@ void OnMiscompare(const LiteralSlice& expected, const LiteralSlice& actual, /* static */ ::testing::AssertionResult LiteralTestUtil::Near( const LiteralSlice& expected, const LiteralSlice& actual, - const ErrorSpec& error_spec, bool detailed_message) { + const ErrorSpec& error_spec, absl::optional detailed_message) { return StatusToAssertion(literal_comparison::Near( expected, actual, error_spec, detailed_message, &OnMiscompare)); } @@ -97,7 +97,8 @@ void OnMiscompare(const LiteralSlice& expected, const LiteralSlice& actual, if (error.has_value()) { VLOG(1) << "Expects near"; return StatusToAssertion(literal_comparison::Near( - expected, actual, *error, /*detailed_message=*/false, &OnMiscompare)); + expected, actual, *error, /*detailed_message=*/absl::nullopt, + &OnMiscompare)); } VLOG(1) << "Expects equal"; return StatusToAssertion(literal_comparison::Equal(expected, actual)); diff --git 
a/tensorflow/compiler/xla/tests/literal_test_util.h b/tensorflow/compiler/xla/tests/literal_test_util.h index 43cca91f64b2c0fbfde5054a361cf0f95302c23d..d7cf9bed98a3eb7479b6deb6838dc388a0869360 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.h +++ b/tensorflow/compiler/xla/tests/literal_test_util.h @@ -93,7 +93,7 @@ class LiteralTestUtil { static ::testing::AssertionResult Near( const LiteralSlice& expected, const LiteralSlice& actual, const ErrorSpec& error_spec, - bool detailed_message = false) TF_MUST_USE_RESULT; + absl::optional detailed_message = absl::nullopt) TF_MUST_USE_RESULT; // Asserts the given literal are within the given error bound of the given // expected values. Only supported for floating point values. diff --git a/tensorflow/compiler/xla/tests/literal_test_util_test.cc b/tensorflow/compiler/xla/tests/literal_test_util_test.cc index b6f9b8156b51144e4f74d285b1e4111d098f13c2..ea9b3037cf482e41238413179888f125822d161c 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util_test.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util_test.cc @@ -89,11 +89,11 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) { Literal literal = Literal::CreateFromProto(literal_proto).ConsumeValueOrDie(); if (result.find("expected") != string::npos) { - EXPECT_EQ("2", literal.ToString()); + EXPECT_EQ("f32[] 2", literal.ToString()); } else if (result.find("actual") != string::npos) { - EXPECT_EQ("4", literal.ToString()); + EXPECT_EQ("f32[] 4", literal.ToString()); } else if (result.find("mismatches") != string::npos) { - EXPECT_EQ("true", literal.ToString()); + EXPECT_EQ("pred[] true", literal.ToString()); } else { FAIL() << "unknown file in temporary directory: " << result; } @@ -105,9 +105,9 @@ TEST(LiteralTestUtilTest, NotEqualHasValuesInMessage) { auto actual = LiteralUtil::CreateR1({4, 5, 6}); ::testing::AssertionResult result = LiteralTestUtil::Equal(expected, actual); EXPECT_THAT(result.message(), - 
::testing::HasSubstr("Expected literal:\n{1, 2, 3}")); + ::testing::HasSubstr("Expected literal:\ns32[3] {1, 2, 3}")); EXPECT_THAT(result.message(), - ::testing::HasSubstr("Actual literal:\n{4, 5, 6}")); + ::testing::HasSubstr("Actual literal:\ns32[3] {4, 5, 6}")); } TEST(LiteralTestUtilTest, NearComparatorR1) { diff --git a/tensorflow/compiler/xla/tests/local_client_execute_test.cc b/tensorflow/compiler/xla/tests/local_client_execute_test.cc index a99b43f4690b3063f76e2cda1e58c9b4ba9a1df4..96527886b718bc1ea4ce8cc2d7dbeb2e3ef1d1eb 100644 --- a/tensorflow/compiler/xla/tests/local_client_execute_test.cc +++ b/tensorflow/compiler/xla/tests/local_client_execute_test.cc @@ -205,7 +205,7 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResult) { ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&x_array, &y_array}); - EXPECT_TRUE(ShapeUtil::IsTuple(result.on_host_shape())); + EXPECT_TRUE(result.on_host_shape().IsTuple()); EXPECT_EQ(3, ShapeUtil::TupleElementCount(result.on_host_shape())); Literal result_literal = ShapedBufferToLiteral(result); @@ -233,7 +233,7 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) { ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&x_array, &y_array}); - EXPECT_TRUE(ShapeUtil::IsTuple(result.on_host_shape())); + EXPECT_TRUE(result.on_host_shape().IsTuple()); EXPECT_EQ(2, ShapeUtil::TupleElementCount(result.on_host_shape())); Literal result_literal = ShapedBufferToLiteral(result); @@ -311,7 +311,7 @@ XLA_TEST_F(LocalClientExecuteTest, TupleArguments) { ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&x_buffer, &y_buffer}); - EXPECT_TRUE(ShapeUtil::IsTuple(result.on_host_shape())); + EXPECT_TRUE(result.on_host_shape().IsTuple()); EXPECT_EQ(2, ShapeUtil::TupleElementCount(result.on_host_shape())); Literal result_literal = ShapedBufferToLiteral(result); @@ -842,7 +842,8 @@ XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion64bit) { LiteralUtil::CreateR0(123456789000LL)})); } 
-XLA_TEST_F(LocalClientExecuteTest, InfeedTest) { +// Disabled on interpreter backend since infeed HLO is unsupported. +XLA_TEST_F(LocalClientExecuteTest, DISABLED_ON_INTERPRETER(InfeedTest)) { XlaBuilder builder(TestName()); const Shape shape = ShapeUtil::MakeShape(F32, {3}); auto in = Infeed(&builder, shape); @@ -867,7 +868,8 @@ XLA_TEST_F(LocalClientExecuteTest, InfeedTest) { LiteralTestUtil::ExpectR1Equal({-4.0, 125.0, 45.0}, result); } -XLA_TEST_F(LocalClientExecuteTest, InfeedOutfeedTest) { +// Disabled on interpreter backend since infeed/outfeed HLOs are unsupported. +XLA_TEST_F(LocalClientExecuteTest, DISABLED_ON_INTERPRETER(InfeedOutfeedTest)) { XlaBuilder builder(TestName()); const Shape shape = ShapeUtil::MakeShape(F32, {3}); auto in = Infeed(&builder, shape); diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc index 3f5135438fc59bea98527b1be30ee49339edd455..1fd9cb055c0bebc0f31496eb82f53a7b7a6cbfba 100644 --- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc +++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc @@ -208,9 +208,7 @@ XLA_TEST_F(MultiOutputFusionTest, FusionNodeIsRoot) { ROOT fusion = (s32[]) fusion(x), kind=kLoop, calls=fused_computation } )"; - auto module = - HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) - .ValueOrDie(); + auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie(); auto param = LiteralUtil::MakeTupleOwned( LiteralUtil::MakeTupleOwned( LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0(42)), @@ -241,9 +239,7 @@ XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFusion) { const = f32[4] constant({0, 0, 0, 0}) ROOT select = f32[4] select(gte0, gte1, const) })"; - auto module = - HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) - .ValueOrDie(); + auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie(); auto param = LiteralUtil::CreateR1({1.0, 2.0, 3.0, -1.0}); Literal result 
= ExecuteNoHloPasses(std::move(module), {&param}); LiteralTestUtil::ExpectR1Equal({0.0, 4.0, 9.0, 1.0}, result); @@ -273,9 +269,7 @@ XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFeedingMap) { p1 = f32[3] parameter(0) ROOT map = f32[3] map(p1), to_apply=map_computation })"; - auto module = - HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) - .ValueOrDie(); + auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie(); auto param = LiteralUtil::CreateR1({1.0, 2.0, 3.0}); Literal result = ExecuteNoHloPasses(std::move(module), {&param}); LiteralTestUtil::ExpectR1Equal({0.0, 4.0, 9.0}, result); @@ -315,9 +309,7 @@ XLA_TEST_F(MultiOutputFusionTest, ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p), kind=kInput, calls=fused_reduce })"); - auto module = - HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) - .ValueOrDie(); + auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie(); auto param = LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); Literal result = ExecuteNoHloPasses(std::move(module), {&param}); @@ -346,9 +338,7 @@ XLA_TEST_F(MultiOutputFusionTest, ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p), kind=kInput, calls=fused_reduce })"); - auto module = - HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) - .ValueOrDie(); + auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie(); auto param = LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); Literal result = ExecuteNoHloPasses(std::move(module), {&param}); @@ -378,9 +368,7 @@ XLA_TEST_F(MultiOutputFusionTest, ROOT fusion = (f32[2]{0}, f32[2]{0}, f32[2]{0}) fusion(p), kind=kInput, calls=fused_reduce })"); - auto module = - HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) - .ValueOrDie(); + auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie(); auto param = LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); Literal result = ExecuteNoHloPasses(std::move(module), {&param}); @@
-410,9 +398,7 @@ XLA_TEST_F(MultiOutputFusionTest, ROOT fusion = (f32[2,2,2]{2,1,0}, f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p), kind=kInput, calls=fused_reduce })"); - auto module = - HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) - .ValueOrDie(); + auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie(); auto param = LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); Literal result = ExecuteNoHloPasses(std::move(module), {&param}); @@ -443,9 +429,7 @@ XLA_TEST_F(MultiOutputFusionTest, ROOT fusion = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}, f32[2,2]{1,0}) fusion(p), kind=kInput, calls=fused_reduce })"); - auto module = - HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) - .ValueOrDie(); + auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie(); auto param = LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); Literal result = ExecuteNoHloPasses(std::move(module), {&param}); @@ -478,9 +462,7 @@ XLA_TEST_F(MultiOutputFusionTest, ROOT fusion = (f32[2]{0}, f32[2,2,2]{2,1,0}, f32[2,2,2]{2,1,0}) fusion(p), kind=kInput, calls=fused_reduce })"); - auto module = - HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) - .ValueOrDie(); + auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie(); auto param = LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); Literal result = ExecuteNoHloPasses(std::move(module), {&param}); @@ -513,9 +495,7 @@ XLA_TEST_F(MultiOutputFusionTest, ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p, i, j), kind=kInput, calls=fused_reduce })"); - auto module = - HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) - .ValueOrDie(); + auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie(); auto param = LiteralUtil::CreateR3({{{0, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); auto init1 = LiteralUtil::CreateR0(5); @@ -549,9 +529,7 @@ XLA_TEST_F(MultiOutputFusionTest, ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}, f16[2,2,2]{2,1,0}) fusion(p),
kind=kInput, calls=fused_reduce })"); - auto module = - HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) - .ValueOrDie(); + auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie(); auto param = LiteralUtil::CreateR3( {{{Eigen::half(1), Eigen::half(2)}, {Eigen::half(3), Eigen::half(4)}}, {{Eigen::half(5), Eigen::half(6)}, {Eigen::half(7), Eigen::half(8)}}}); diff --git a/tensorflow/compiler/xla/tests/prng_test.cc b/tensorflow/compiler/xla/tests/prng_test.cc index 8f2c26f0eea9c7a3b33cd77e5977924c1659535a..e49bcf26bd6e50f8fb36c86f217907b5d4901eae 100644 --- a/tensorflow/compiler/xla/tests/prng_test.cc +++ b/tensorflow/compiler/xla/tests/prng_test.cc @@ -80,7 +80,9 @@ XLA_TEST_F(PrngTest, LargeU01) { UniformTest(0, 1, {0x100, 0x100}); } XLA_TEST_F(PrngTest, TwelveValuesU524) { UniformTest(5, 24, {12}); } // TODO(b/71543667): Fix Rng ops on LLVM backends. -XLA_TEST_F(PrngTest, DISABLED_ON_GPU(DISABLED_ON_CPU(ScalarBF16Tests))) { +// TODO(b/122047800): Interpreter does not support BF16 for RNG ops. +XLA_TEST_F(PrngTest, DISABLED_ON_INTERPRETER( + DISABLED_ON_GPU(DISABLED_ON_CPU(ScalarBF16Tests)))) { for (int64 seed = 0; seed < 100; ++seed) { // The largest negative number smaller than zero in bf16 that's not // denormalized. @@ -103,7 +105,9 @@ XLA_TEST_F(PrngTest, DISABLED_ON_GPU(DISABLED_ON_CPU(ScalarBF16Tests))) { } // TODO(b/71543667): Fix Rng ops on LLVM backends. -XLA_TEST_F(PrngTest, DISABLED_ON_GPU(DISABLED_ON_CPU(ScalarBF16CountTests))) { +// TODO(b/122047800): Interpreter does not support BF16 for RNG ops. +XLA_TEST_F(PrngTest, DISABLED_ON_INTERPRETER(DISABLED_ON_GPU( + DISABLED_ON_CPU(ScalarBF16CountTests)))) { // There are 3 BF16 values in the range of [32.25, 33): 32.25, 32.5, 32.75, // they should get similar counts. 
bfloat16 low = static_cast(32.25); @@ -276,6 +280,39 @@ XLA_TEST_F(PrngTest, PassInGlobalRngSeed) { EXPECT_FALSE(LiteralTestUtil::Equal(result5, result6)); } +// This test verifies that the two RNG instructions with the same parameters in +// the same HloComputation produces different values. +XLA_TEST_F(PrngTest, DifferentValuesForIdenticalRngNodesInSameComputation) { + // Build a U[0,1) computation. + auto build_computation = [this]() { + XlaBuilder builder(TestName()); + auto a = RngUniform(ConstantR0(&builder, 0), + ConstantR0(&builder, 100), + ShapeUtil::MakeShape(S32, {10})); + auto b = RngUniform(ConstantR0(&builder, 0), + ConstantR0(&builder, 100), + ShapeUtil::MakeShape(S32, {10})); + Tuple(&builder, {a, b}); + return builder.Build(); + }; + + ExecutionOptions execution_options = execution_options_; + execution_options.set_seed(42); + + Literal result_tuple; + { + TF_ASSERT_OK_AND_ASSIGN(auto computation, build_computation()); + TF_ASSERT_OK_AND_ASSIGN( + result_tuple, client_->ExecuteAndTransfer(computation, /*arguments=*/{}, + &execution_options)); + } + + auto results = result_tuple.DecomposeTuple(); + ASSERT_EQ(results.size(), 2); + + EXPECT_FALSE(LiteralTestUtil::Equal(results[0], results[1])); +} + XLA_TEST_F(PrngTest, TenValuesN01) { XlaBuilder builder(TestName()); RngNormal(ConstantR0(&builder, 0), ConstantR0(&builder, 1), diff --git a/tensorflow/compiler/xla/tests/ptxas_bug_120501638.cc b/tensorflow/compiler/xla/tests/ptxas_bug_120501638.cc new file mode 100644 index 0000000000000000000000000000000000000000..0e5d7db97e88936e7336ed02a5c7a1171254b0cf --- /dev/null +++ b/tensorflow/compiler/xla/tests/ptxas_bug_120501638.cc @@ -0,0 +1,82 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/debug_options_flags.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" + +namespace xla { +namespace { + +class PtxasBugTest : public HloTestBase {}; + +// Checks for a bug in ptxas, tracked as Google bug 120501638, and nvidia bug +// 2459377. We never received an explanation of what exactly was going wrong +// here in ptxas. Known-bad in ptxas 10.0.145, known-good in ptxas 10.0.249. 
+TEST_F(PtxasBugTest, DoIt) { + const char* const kModuleStr = R"( +HloModule test + +add_F32.14 { + lhs.15 = f32[] parameter(0) + rhs.16 = f32[] parameter(1) + ROOT add.17 = f32[] add(lhs.15, rhs.16) +} + +ENTRY testcase { + arg0.1 = f32[2,5,2]{2,1,0} parameter(0) + reshape.2 = f32[2,5,2]{2,1,0} reshape(arg0.1) + constant.3 = f32[] constant(0) + pad.4 = f32[2,6,2]{2,1,0} pad(reshape.2, constant.3), padding=0_0x0_1x0_0 + reshape.5 = f32[2,3,2,2]{3,2,1,0} reshape(pad.4) + transpose.6 = f32[2,2,3,2]{3,0,2,1} transpose(reshape.5), dimensions={2,0,1,3} + reshape.7 = f32[4,3,2]{2,1,0} reshape(transpose.6) + reshape.8 = f32[4,1,3,2]{3,2,1,0} reshape(reshape.7) + transpose.9 = f32[4,2,1,3]{1,3,2,0} transpose(reshape.8), dimensions={0,3,1,2} + convert.10 = f32[4,2,1,3]{1,3,2,0} convert(transpose.9) + constant.12 = f32[] constant(0) + pad.13 = f32[4,2,1,3]{3,2,1,0} pad(convert.10, constant.12), padding=0_0x0_0x0_0x0_0 + constant.11 = f32[] constant(0) + reduce-window.18 = f32[4,2,1,3]{3,2,1,0} reduce-window(pad.13, constant.11), + window={size=1x1x1x1}, to_apply=add_F32.14 + constant.19 = f32[] constant(1) + broadcast.20 = f32[4,2,1,3]{3,2,1,0} broadcast(constant.19), dimensions={} + divide.21 = f32[4,2,1,3]{3,2,1,0} divide(reduce-window.18, broadcast.20) + convert.22 = f32[4,2,1,3]{3,2,1,0} convert(divide.21) + transpose.23 = f32[4,1,3,2]{2,1,3,0} transpose(convert.22), dimensions={0,2,3,1} + reshape.24 = f32[4,3,2]{2,1,0} reshape(transpose.23) + reshape.25 = f32[2,2,3,2]{3,2,1,0} reshape(reshape.24) + transpose.26 = f32[2,3,2,2]{3,1,0,2} transpose(reshape.25), dimensions={1,2,0,3} + reshape.27 = f32[2,6,2]{2,1,0} reshape(transpose.26) + slice.28 = f32[2,5,2]{2,1,0} slice(reshape.27), slice={[0:2], [0:5], [0:2]} + reshape.29 = f32[2,5,2]{2,1,0} reshape(slice.28) + tuple.30 = (f32[2,5,2]{2,1,0}) tuple(reshape.29) + ROOT get-tuple-element.31 = f32[2,5,2]{2,1,0} get-tuple-element(tuple.30), index=0 +})"; + + // Create a module with the true-default flags, not the 
default-for-testing + // flags. In particular, true-default flags enable unrolling, whereas for + // testing we disable unrolling, and this bug doesn't trigger without + // unrolling. + HloModuleConfig config; + config.set_debug_options(DefaultDebugOptionsIgnoringFlags()); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr, config)); + EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{0.01, 0.01})); +} + +} // anonymous namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc index 22fe4a2670e2e0e1fedc45036a1ceec19f44e42e..30e2d24184a5d399e5e058a9c4a382f57e82866f 100644 --- a/tensorflow/compiler/xla/tests/reduce_window_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc @@ -607,7 +607,10 @@ class R4ReduceWindowTest : public ReduceWindowTestBase, Array4D input(param.base_bounds[0], param.base_bounds[1], param.base_bounds[2], param.base_bounds[3]); - input.FillRandom(0.1f, 0.1f); + // Choose a prime iota length so that each window sees a unique set of + // values. (Technically, the requirement is that the iota length is + // relatively prime to all of the dimensions involved in the reduce-window.) + input.FillRepeatedIota(0, 137); Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout(param.layout)); XlaOp parameter; @@ -623,9 +626,9 @@ class R4ReduceWindowTest : public ReduceWindowTestBase, CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b); CHECK(param.reducer == kAdd || param.reducer == kMax); auto reducer = param.reducer; - if (use_bfloat16() && Product(param.window_bounds) > 128) { - // To avoid numerical issues, force the reducer to be kMax for large bf16 - // windows. + if (use_bfloat16()) { + // To avoid numerical issues, force the reducer to be kMax for bf16 + // inputs. 
reducer = kMax; } @@ -949,16 +952,16 @@ struct R3ReduceWindowTestData { /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{95, 202, 251}, /*window_bounds=*/{95, 202, 251}, /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kMax}, {/*base_bounds=*/{999, 57, 3}, /*window_bounds=*/{999, 57, 3}, /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0}, /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{178, 302, 64}, /*window_bounds=*/{178, 302, 64}, /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kMax}, {/*base_bounds=*/{63, 261, 257}, /*window_bounds=*/{63, 261, 257}, /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kMax}, {/*base_bounds=*/{10003, 10, 5}, /*window_bounds=*/{9999, 7, 3}, /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0}, /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, @@ -1001,17 +1004,19 @@ TEST_P(R3ReduceWindowTest, DoIt) { const float kInitValue = 0.0f; Array3D input(param.base_bounds[0], param.base_bounds[1], param.base_bounds[2]); - input.FillRandom(0.1f, 0.1f); + // Choose a prime iota length so that each window sees a unique set of values. + // (Technically, the requirement is that the iota length is relatively prime + // to all of the dimensions involved in the reduce-window.) + input.FillRepeatedIota(0, 137); Literal input_literal = LiteralUtil::CreateR3FromArray3DWithLayout( input, LayoutUtil::MakeLayout(param.layout)); auto reducer = param.reducer; if (use_bfloat16()) { input_literal = LiteralUtil::ConvertF32ToBF16(input_literal); - if (Product(param.window_bounds) > 128) { - // To avoid numerical issues, force the reducer to be kMax for large bf16 - // windows. 
- reducer = kMax; - } + + // To avoid numerical issues, force the reducer to be kMax for bf16 + // inputs. + reducer = kMax; } XlaOp parameter = Parameter(&b, 0, input_literal.shape(), "input"); @@ -1527,6 +1532,25 @@ ENTRY %reduce-window (parameter.0: s32[81,8], parameter.1: s32[]) -> s32[82,8] { EXPECT_TRUE(RunAndCompare(hlo_string, absl::nullopt)); } +XLA_TEST_F(HloTestBase, ReduceWindowS64) { + const string hlo_string = R"( +HloModule reduce-window + +%identity.pad_to_reduce_window (param0: s64[], param1: s64[]) -> s64[] { + %param0 = s64[] parameter(0) + ROOT %param1 = s64[] parameter(1) +} + +ENTRY %reduce-window (parameter.0: s64[81,8], parameter.1: s64[]) -> s64[82,8] { + %parameter.0 = s64[81,8]{1,0} parameter(0) + %parameter.1 = s64[] parameter(1) + ROOT %reduce-window = s64[82,8]{1,0} reduce-window(s64[81,8]{1,0} %parameter.0, s64[] %parameter.1), window={size=1x1 pad=0_1x0_0}, to_apply=%identity.pad_to_reduce_window +} + +)"; + EXPECT_TRUE(RunAndCompare(hlo_string, absl::nullopt)); +} + XLA_TEST_F(HloTestBase, ReduceWindowF16) { const string hlo_string = R"( HloModule reduce-window diff --git a/tensorflow/compiler/xla/tests/test_macros.h b/tensorflow/compiler/xla/tests/test_macros.h index 7ca99a91635e85cd0888e59ecde31e47fec21844..80a6868485c9162d1cb0de24f0adf3f1c1d2503a 100644 --- a/tensorflow/compiler/xla/tests/test_macros.h +++ b/tensorflow/compiler/xla/tests/test_macros.h @@ -79,30 +79,28 @@ string PrependDisabledIfIndicated(const string& test_case_name, // heuristic to decide whether the test case should be disabled, and we // determine whether the test case should be disabled by resolving the (test // case name, test name) in a manifest file. 
-#define XLA_GTEST_TEST_(test_case_name, test_name, parent_class, parent_id) \ - class GTEST_TEST_CLASS_NAME_(test_case_name, test_name) \ - : public parent_class { \ - public: \ - GTEST_TEST_CLASS_NAME_(test_case_name, test_name)() {} \ - \ - private: \ - virtual void TestBody(); \ - static ::testing::TestInfo* const test_info_ GTEST_ATTRIBUTE_UNUSED_; \ - GTEST_DISALLOW_COPY_AND_ASSIGN_(GTEST_TEST_CLASS_NAME_(test_case_name, \ - test_name)); \ - }; \ - \ - ::testing::TestInfo* const GTEST_TEST_CLASS_NAME_(test_case_name, \ - test_name)::test_info_ = \ - ::testing::internal::MakeAndRegisterTestInfo( \ - #test_case_name, \ - ::xla::PrependDisabledIfIndicated(#test_case_name, #test_name) \ - .c_str(), \ - nullptr, nullptr, \ - ::testing::internal::CodeLocation(__FILE__, __LINE__), (parent_id), \ - parent_class::SetUpTestCase, parent_class::TearDownTestCase, \ - new ::testing::internal::TestFactoryImpl); \ +#define XLA_GTEST_TEST_(test_case_name, test_name, parent_class) \ + class GTEST_TEST_CLASS_NAME_(test_case_name, test_name) \ + : public parent_class { \ + public: \ + GTEST_TEST_CLASS_NAME_(test_case_name, test_name)() {} \ + \ + private: \ + virtual void TestBody(); \ + static ::testing::TestInfo* const test_info_ GTEST_ATTRIBUTE_UNUSED_; \ + GTEST_DISALLOW_COPY_AND_ASSIGN_(GTEST_TEST_CLASS_NAME_(test_case_name, \ + test_name)); \ + }; \ + \ + ::testing::TestInfo* const GTEST_TEST_CLASS_NAME_(test_case_name, \ + test_name)::test_info_ = \ + ::testing::RegisterTest( \ + #test_case_name, \ + ::xla::PrependDisabledIfIndicated(#test_case_name, #test_name) \ + .c_str(), \ + nullptr, nullptr, __FILE__, __LINE__, []() -> parent_class* { \ + return new GTEST_TEST_CLASS_NAME_(test_case_name, test_name)(); \ + }); \ void GTEST_TEST_CLASS_NAME_(test_case_name, test_name)::TestBody() // This is identical to the TEST_F macro from "gtest", but it potentially @@ -111,9 +109,8 @@ string PrependDisabledIfIndicated(const string& test_case_name, // Per usual, you can see what 
tests are available via --gunit_list_tests and // choose to run tests that have been disabled via the manifest via // --gunit_also_run_disabled_tests. -#define XLA_TEST_F(test_fixture, test_name) \ - XLA_GTEST_TEST_(test_fixture, test_name, test_fixture, \ - ::testing::internal::GetTypeId()) +#define XLA_TEST_F(test_fixture, test_name) \ + XLA_GTEST_TEST_(test_fixture, test_name, test_fixture) // Likewise, this is identical to the TEST_P macro from "gtest", but // potentially disables the test based on the DISABLED_MANIFEST file. diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index eafa48ed7b8cf2bd67fe767ad36082661dbbd66e..95c89b0ba6f29c453abab88e29bca13ee006455a 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -19,6 +19,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" @@ -168,7 +169,7 @@ void PopulateWithRandomIntegralData(Literal* literal, std::minstd_rand0* engine, StatusOr MakeFakeLiteralInternal(const Shape& shape, std::minstd_rand0* engine, bool no_duplicates) { - if (ShapeUtil::IsTuple(shape)) { + if (shape.IsTuple()) { std::vector elements; for (const Shape& element_shape : shape.tuple_shapes()) { TF_ASSIGN_OR_RETURN( @@ -274,16 +275,9 @@ bool NeedsInitValue(const HloUse& use) { // Generate random values that are constrained to the input_shape minus the // output_shape so as not to produce wrapping slices, for instance. 
-Literal MakeRandomIndex(absl::Span index_space, - std::minstd_rand0* engine) { - std::vector start_indices(index_space.size()); - if (engine != nullptr) { - for (int i = 0; i < index_space.size(); ++i) { - std::uniform_int_distribution generator(0, index_space[i]); - start_indices[i] = generator(*engine); - } - } - return LiteralUtil::CreateR1(start_indices); +Literal MakeRandomIndex(int64 index_bound, std::minstd_rand0* engine) { + std::uniform_int_distribution generator(0, index_bound); + return LiteralUtil::CreateR0(generator(*engine)); } // Use dataflow analysis on each parameter to see if there are uses that would @@ -300,8 +294,8 @@ std::vector FindConstrainedUses( HloInstruction* instruction = use.instruction; const HloOpcode opcode = instruction->opcode(); const int64 op_num = use.operand_number; - if ((opcode == HloOpcode::kDynamicSlice && op_num == 1) || - (opcode == HloOpcode::kDynamicUpdateSlice && op_num == 2)) { + if ((opcode == HloOpcode::kDynamicSlice && op_num >= 1) || + (opcode == HloOpcode::kDynamicUpdateSlice && op_num >= 2)) { constrained_uses.push_back(instruction); } else if (opcode == HloOpcode::kFusion) { const HloInstruction* const to_analyze = @@ -336,7 +330,7 @@ std::vector FindConstrainedUses( StatusOr CreateLiteralForConstrainedUses( const absl::Span constrained_uses, const HloInstruction& param, std::minstd_rand0* engine) { - std::vector index_space; + int64 index_bound = INT64_MAX; bool no_duplicates = false; bool needs_constant = false; ConstantType constant_type = ConstantType::kUnknown; @@ -348,19 +342,16 @@ StatusOr CreateLiteralForConstrainedUses( const Shape& slice_shape = use->opcode() == HloOpcode::kDynamicSlice ? 
use->shape() : use->operand(1)->shape(); - const int64 rank = ShapeUtil::Rank(indexed_shape); - if (!index_space.empty()) { - TF_RET_CHECK(rank == index_space.size()); - for (int64 i = 0; i < rank; ++i) { - index_space[i] = std::min( - index_space[i], ShapeUtil::GetDimension(indexed_shape, i) - - ShapeUtil::GetDimension(slice_shape, i)); - } - } else { - index_space.resize(rank); - for (int64 i = 0; i < rank; ++i) { - index_space[i] = ShapeUtil::GetDimension(indexed_shape, i) - - ShapeUtil::GetDimension(slice_shape, i); + const int64 first_index = + Cast(use)->first_index_operand_number(); + for (int64 operand = first_index; operand < use->operand_count(); + ++operand) { + if (use->operand(operand) == &param) { + index_bound = std::min( + index_bound, + ShapeUtil::GetDimension(indexed_shape, operand - first_index) - + ShapeUtil::GetDimension(slice_shape, + operand - first_index)); } } break; @@ -388,13 +379,14 @@ StatusOr CreateLiteralForConstrainedUses( } int constraint_count = 0; constraint_count += no_duplicates ? 1 : 0; - constraint_count += !index_space.empty() ? 1 : 0; + constraint_count += (index_bound != INT64_MAX) ? 1 : 0; constraint_count += needs_constant ?
1 : 0; if (constraint_count > 1) { return Unimplemented("Conflicting operand generation constraints."); } - if (!index_space.empty()) { - return MakeRandomIndex(index_space, engine); + if (index_bound != INT64_MAX) { + return MakeRandomIndex(index_bound, engine) + .Reshape(param.shape().dimensions()); } else if (needs_constant) { switch (constant_type) { case ConstantType::kZero: @@ -459,8 +451,8 @@ Status VerifyHloModule(HloModule* const module, bool layout_sensitive, std::unique_ptr CreateCanonicalDot(const Shape& shape, HloInstruction* lhs, HloInstruction* rhs) { - CHECK_EQ(ShapeUtil::Rank(lhs->shape()), 2); - CHECK_EQ(ShapeUtil::Rank(rhs->shape()), 2); + CHECK_EQ(lhs->shape().rank(), 2); + CHECK_EQ(rhs->shape().rank(), 2); PrecisionConfig precision_config; precision_config.mutable_operand_precision()->Resize( 2, PrecisionConfig::DEFAULT); diff --git a/tensorflow/compiler/xla/tests/test_utils_test.cc b/tensorflow/compiler/xla/tests/test_utils_test.cc index e8f5d7a9a79ebddea3cb989dbe8eab90b630d5e7..591d6c19228a313f530cdae18f4be37e7b517601 100644 --- a/tensorflow/compiler/xla/tests/test_utils_test.cc +++ b/tensorflow/compiler/xla/tests/test_utils_test.cc @@ -61,11 +61,11 @@ XLA_TEST_F(TestUtilsTest, Token) { R"(HloModule outfeed_module ENTRY InfeedToOutfeed { - token = token[] parameter(0) - infeed = ((u32[3]{0}, pred[]), token[]) infeed(token) + token0 = token[] parameter(0) + infeed = ((u32[3]{0}, pred[]), token[]) infeed(token0) infeed.data = (u32[3]{0}, pred[]) get-tuple-element(infeed), index=0 - outfeed = token[] outfeed(infeed.data, token) - ROOT infeed.1 = ((u32[3]{0}, pred[]), token[]) infeed(token) + outfeed = token[] outfeed(infeed.data, token0) + ROOT infeed.1 = ((u32[3]{0}, pred[]), token[]) infeed(token0) infeed.1.data = (u32[3]{0}, pred[]) get-tuple-element(infeed.1), index=0 infeed.1.token = token[] get-tuple-element(infeed.1), index=1 outfeed.1 = token[] outfeed(infeed.1.data, infeed.1.token) @@ -79,25 +79,26 @@ XLA_TEST_F(TestUtilsTest, 
MultipleIndexSpacesForDynamicSlices) { R"(HloModule index_space_module ENTRY IndexSpace { - index_param = s32[3]{0} parameter(0) - array_param.1 = f32[123,4,789]{0,1,2} parameter(1) - array_param.2 = f32[3,3000,5]{0,1,2} parameter(2) - dynamic-slice.1 = f32[1,2,3] dynamic-slice(array_param.1, index_param), dynamic_slice_sizes={1,2,3} - ROOT dynamic-slice.2 = f32[3,2,2] dynamic-slice(array_param.2, index_param), dynamic_slice_sizes={3,2,2} + index_param.0 = s32[] parameter(0) + index_param.1 = s32[] parameter(1) + index_param.2 = s32[] parameter(2) + array_param.1 = f32[123,4,789]{0,1,2} parameter(3) + array_param.2 = f32[3,3000,5]{0,1,2} parameter(4) + dynamic-slice.1 = f32[1,2,3] dynamic-slice(array_param.1, index_param.0, index_param.1, index_param.2), dynamic_slice_sizes={1,2,3} + ROOT dynamic-slice.2 = f32[3,2,2] dynamic-slice(array_param.2, index_param.0, index_param.1, index_param.2), dynamic_slice_sizes={3,2,2} })") .ValueOrDie(); TF_ASSERT_OK_AND_ASSIGN(std::vector args, MakeFakeArguments(module.get())); - ASSERT_EQ(args.size(), 3); - const Literal& index_arg = args[0]; + ASSERT_EQ(args.size(), 5); - EXPECT_EQ(index_arg.Get({0}), 0); + EXPECT_EQ(args[0].Get({}), 0); - EXPECT_GE(index_arg.Get({1}), 0); - EXPECT_LE(index_arg.Get({1}), 2); + EXPECT_GE(args[1].Get({}), 0); + EXPECT_LE(args[1].Get({}), 2); - EXPECT_GE(index_arg.Get({2}), 0); - EXPECT_LE(index_arg.Get({2}), 3); + EXPECT_GE(args[2].Get({}), 0); + EXPECT_LE(args[2].Get({}), 3); } XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicUpdateSlices) { @@ -105,28 +106,29 @@ XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicUpdateSlices) { R"(HloModule index_space_module ENTRY IndexSpace { - index_param = s32[3]{0} parameter(0) - array_param.1 = f32[123,4,789]{0,1,2} parameter(1) - array_param.2 = f32[3,3000,5]{0,1,2} parameter(2) - update_param.1 = f32[1,2,3]{0,1,2} parameter(3) - update_param.2 = f32[3,2,2]{0,1,2} parameter(4) - - dynamic-update-slice.1 = f32[123,4,789]
dynamic-update-slice(array_param.1, update_param.1, index_param) - ROOT dynamic-update-slice.2 = f32[3,3000,5] dynamic-update-slice(array_param.2, update_param.2, index_param) + index_param.0 = s32[] parameter(0) + index_param.1 = s32[] parameter(1) + index_param.2 = s32[] parameter(2) + array_param.1 = f32[123,4,789]{0,1,2} parameter(3) + array_param.2 = f32[3,3000,5]{0,1,2} parameter(4) + update_param.1 = f32[1,2,3]{0,1,2} parameter(5) + update_param.2 = f32[3,2,2]{0,1,2} parameter(6) + + dynamic-update-slice.1 = f32[123,4,789] dynamic-update-slice(array_param.1, update_param.1, index_param.0, index_param.1, index_param.2) + ROOT dynamic-update-slice.2 = f32[3,3000,5] dynamic-update-slice(array_param.2, update_param.2, index_param.0, index_param.1, index_param.2) })") .ValueOrDie(); TF_ASSERT_OK_AND_ASSIGN(std::vector args, MakeFakeArguments(module.get())); - ASSERT_EQ(args.size(), 5); - const Literal& index_arg = args[0]; + ASSERT_EQ(args.size(), 7); - EXPECT_EQ(index_arg.Get({0}), 0); + EXPECT_EQ(args[0].Get({}), 0); - EXPECT_GE(index_arg.Get({1}), 0); - EXPECT_LE(index_arg.Get({1}), 2); + EXPECT_GE(args[1].Get({}), 0); + EXPECT_LE(args[1].Get({}), 2); - EXPECT_GE(index_arg.Get({2}), 0); - EXPECT_LE(index_arg.Get({2}), 3); + EXPECT_GE(args[2].Get({}), 0); + EXPECT_LE(args[2].Get({}), 3); } XLA_TEST_F(TestUtilsTest, NoDuplicatesFloats) { @@ -198,5 +200,33 @@ ENTRY %sort.
(parameter.0: bf16[2,1452], parameter.1: s32[2,1452]) -> (bf16[2,14 } } +XLA_TEST_F(TestUtilsTest, MakeFakeArgumentsR0InputToDynamicSlice) { + auto module = ParseHloString(R"( +HloModule Test + +ENTRY %module (parameter.0: s32[], parameter.1: f32[20,20]) -> f32[] { + %parameter.1 = f32[20,20]{1,0} parameter(1) + %constant.1 = s32[1]{0} constant({0}) + %parameter.0 = s32[] parameter(0) + %bitcast.3 = s32[1]{0} bitcast(s32[] %parameter.0) + %concatenate.1 = s32[2]{0} concatenate(s32[1]{0} %constant.1, s32[1]{0} %bitcast.3), dimensions={0} + %dynamic-slice.2 = f32[20,1]{1,0} dynamic-slice(f32[20,20]{1,0} %parameter.1, s32[2]{0} %concatenate.1), dynamic_slice_sizes={20,1} + %bitcast.4 = f32[20]{0} bitcast(f32[20,1]{1,0} %dynamic-slice.2) + %dynamic-slice.3 = f32[1]{0} dynamic-slice(f32[20]{0} %bitcast.4, s32[1]{0} %bitcast.3), dynamic_slice_sizes={1} + ROOT %bitcast.5 = f32[] bitcast(f32[1]{0} %dynamic-slice.3) +} +)") + .ValueOrDie(); + + TF_ASSERT_OK_AND_ASSIGN(std::vector args, + MakeFakeArguments(module.get())); + ASSERT_EQ(args.size(), 2); + EXPECT_TRUE(ShapeUtil::Equal(args[0].shape(), ShapeUtil::MakeShape(S32, {}))) + << ShapeUtil::HumanString(args[0].shape()); + EXPECT_TRUE( + ShapeUtil::Equal(args[1].shape(), ShapeUtil::MakeShape(F32, {20, 20}))) + << ShapeUtil::HumanString(args[1].shape()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/token_hlo_test.cc b/tensorflow/compiler/xla/tests/token_hlo_test.cc index 601c6b06938fef1f1ae809b33209ae59b24c70a2..b77cf38ed8e29973985406015c0a3936916ad5e6 100644 --- a/tensorflow/compiler/xla/tests/token_hlo_test.cc +++ b/tensorflow/compiler/xla/tests/token_hlo_test.cc @@ -214,8 +214,8 @@ ENTRY %AddDependency (p0: f32[], p1: f32[]) -> f32[] { %forty_two = f32[] constant(42.0) %add = f32[] add(f32[] %p0, f32[] %forty_two) - %token = token[] after-all(f32[] %add) - %p1_after_token = f32[] add-dependency(f32[] %p1, token[] %token) + %token0 = token[] after-all(f32[] %add) + %p1_after_token = 
f32[] add-dependency(f32[] %p1, token[] %token0) %neg = f32[] negate(f32[] %p1_after_token) ROOT %product = f32[] multiply(f32[] %add, f32[] %neg) } @@ -236,8 +236,8 @@ HloModule AddDependencyOfConstant, is_scheduled=true ENTRY %AddDependency (p0: f32[]) -> f32[] { %p0 = f32[] parameter(0) %forty_two = f32[] constant(42.0) - %token = token[] after-all(f32[] %p0) - %forty_two_after_token = f32[] add-dependency(f32[] %forty_two, token[] %token) + %token0 = token[] after-all(f32[] %p0) + %forty_two_after_token = f32[] add-dependency(f32[] %forty_two, token[] %token0) ROOT %product = f32[] multiply(f32[] %p0, f32[] %forty_two_after_token) } )"; @@ -255,8 +255,8 @@ HloModule AddDependencyAsRoot, is_scheduled=true ENTRY %AddDependency (p: f32[3]) -> f32[3] { %p = f32[3] parameter(0) %neg = f32[3] negate(f32[3] %p) - %token = token[] after-all() - ROOT %add_dep = f32[3] add-dependency(f32[3] %neg, token[] %token) + %token0 = token[] after-all() + ROOT %add_dep = f32[3] add-dependency(f32[3] %neg, token[] %token0) } )"; TF_ASSERT_OK_AND_ASSIGN( @@ -274,9 +274,9 @@ ENTRY %TupleShapedAddDependency (p0: f32[3], p1: f32[3]) -> f32[3] { %p0 = f32[3] parameter(0) %p1 = f32[3] parameter(1) %forty_two = f32[] constant(42.0) - %token = token[] after-all() - %tuple = (f32[3], token[], f32[3], f32[]) tuple(f32[3] %p0, token[] %token, f32[3] %p1, f32[] %forty_two) - %add_dep = (f32[3], token[], f32[3], f32[]) add-dependency((f32[3], token[], f32[3], f32[]) %tuple, token[] %token) + %token0 = token[] after-all() + %tuple = (f32[3], token[], f32[3], f32[]) tuple(f32[3] %p0, token[] %token0, f32[3] %p1, f32[] %forty_two) + %add_dep = (f32[3], token[], f32[3], f32[]) add-dependency((f32[3], token[], f32[3], f32[]) %tuple, token[] %token0) %elem0 = f32[3] get-tuple-element((f32[3], token[], f32[3], f32[]) %add_dep), index=0 %elem2 = f32[3] get-tuple-element((f32[3], token[], f32[3], f32[]) %add_dep), index=2 ROOT %diff = f32[3] subtract(f32[3] %elem0, f32[3] %elem2) diff --git 
a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc index 27ce243e9bd4afbdcc1fdc5b6873d4968086e459..cdf2c34fcc3cc005e84626c39c8ab301a9040529 100644 --- a/tensorflow/compiler/xla/tests/tuple_test.cc +++ b/tensorflow/compiler/xla/tests/tuple_test.cc @@ -176,8 +176,9 @@ XLA_TEST_F(TupleTest, AddTupleElements) { {2.f, 4.f, 6.f}, // row 0 {5.f, 7.f, 9.f}, // row 1 }); - ASSERT_TRUE(ShapeUtil::ShapeIs(vector_shape, F32, {3})); - ASSERT_TRUE(ShapeUtil::ShapeIs(matrix_shape, F32, {/*y=*/2, /*x=*/3})); + ASSERT_TRUE(ShapeUtil::Equal(vector_shape, ShapeUtil::MakeShape(F32, {3}))); + ASSERT_TRUE(ShapeUtil::Equal(matrix_shape, + ShapeUtil::MakeShape(F32, {/*y=*/2, /*x=*/3}))); ComputeAndCompareR2(&builder, expected, {}, error_spec_); } @@ -512,8 +513,7 @@ XLA_TEST_F(TupleTest, ComplexTuples) { class TupleHloTest : public HloTestBase {}; -// Disabled on the interpreter because bitcast doesn't exist on the interpreter. -XLA_TEST_F(TupleHloTest, DISABLED_ON_INTERPRETER(BitcastAfterGTE)) { +XLA_TEST_F(TupleHloTest, BitcastAfterGTE) { const char* testcase = R"( HloModule m, is_scheduled=true @@ -525,9 +525,7 @@ XLA_TEST_F(TupleHloTest, DISABLED_ON_INTERPRETER(BitcastAfterGTE)) { ROOT tuple.4 = (f32[1,3]{1,0}) tuple(copy) } )"; - auto module = - HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) - .ValueOrDie(); + auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie(); auto param = LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1({1, 2, 3})); auto result = ExecuteNoHloPasses(std::move(module), {¶m}); @@ -555,13 +553,11 @@ XLA_TEST_F(TupleHloTest, s = (f32[2],f32[2]) tuple-select(cond, tup0, tup1) gte = f32[2] get-tuple-element(s), index=0 tuple = (f32[2]) tuple(gte) - token = token[] after-all() - ROOT outfeed = token[] outfeed(tuple, token) + token0 = token[] after-all() + ROOT outfeed = token[] outfeed(tuple, token0) } )"; - auto module = - HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) - 
.ValueOrDie(); + auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie(); auto param0 = LiteralUtil::CreateR1({1, 2}); auto param1 = LiteralUtil::CreateR1({2, 3}); auto param4 = LiteralUtil::CreateR0(false); diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc index 6d5f276e82087cedc356691b0ff08df24cec8d20..85212fa56d71088156d2f3edda17f71cdab56da2 100644 --- a/tensorflow/compiler/xla/tests/while_test.cc +++ b/tensorflow/compiler/xla/tests/while_test.cc @@ -861,7 +861,7 @@ XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) { // Update. auto update = ConvertElementType(Broadcast(out0, {2}), F32); // Starts = iteration * 2; - auto starts = Reshape(Mul(iteration, ConstantR0(&builder, 2)), {1}); + auto starts = Mul(iteration, ConstantR0(&builder, 2)); // UpdateSlice. auto out1 = DynamicUpdateSlice(input, update, starts); @@ -901,7 +901,7 @@ XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) { // Per backend the values generated can be different as the different backends // use different random number generators. // TODO(b/32240857): Extend test to verify outputs. -XLA_TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithPrngScalarResult)) { +XLA_TEST_F(WhileTest, WhileWithPrngScalarResult) { auto v6s32 = ShapeUtil::MakeShape(S32, {6}); // Create a computation for the condition: repeat for count iterations. @@ -1146,7 +1146,7 @@ XLA_TEST_F(WhileTest, NestedWhileWithScalarResult) { // while (f(result).get<0>()) { // result = result + 1; // } -XLA_TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithCallInsideCondition)) { +XLA_TEST_F(WhileTest, WhileWithCallInsideCondition) { auto result_shape = ShapeUtil::MakeShape(S32, {}); // Create a computation for the condition: repeat for 5 iterations. 
@@ -1299,9 +1299,9 @@ void BM_WhileLoop(int num_iters) { auto one = ConstantR0(&builder, 1.0); auto update = Broadcast(one, {1, 1024, 1024}); // Starts = iteration * 2; - auto starts = ConstantR1(&builder, {0, 0, 0}); + auto zero = ConstantR0(&builder, 0); // UpdateSlice. - auto out1 = DynamicUpdateSlice(input, update, starts); + auto out1 = DynamicUpdateSlice(input, update, {zero, zero, zero}); Tuple(&builder, {out0, out1}); body = builder.Build().ConsumeValueOrDie(); } diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc index e57d072a0632b492b8b6e34439f4e80332b843b6..c7337e8caae8f2ee25f4b25dc22439e08d2ecc25 100644 --- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc +++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc @@ -174,9 +174,8 @@ void ExecuteAndFetchProfile(string* profile_output, LocalClient* client, exec_run_options.set_allocator(backend->memory_allocator()); exec_run_options.set_intra_op_thread_pool( backend->eigen_intra_op_thread_pool_device()); - ServiceExecutableRunOptions run_options( - exec_run_options, /*borrow_stream=*/nullptr, - backend->eigen_intra_op_thread_pool()); + ServiceExecutableRunOptions run_options(exec_run_options, + /*borrow_stream=*/nullptr); std::vector args = {&lhs_arg, &rhs_arg}; TF_ASSERT_OK_AND_ASSIGN( auto execution_result, @@ -225,14 +224,17 @@ XLA_TEST_F(HloProfileTest, ProfileSingleComputation) { line_no++; // Skip 'Execution profile for ....' 
+ ASSERT_LT(line_no, profile_output_lines.size()); TF_ASSERT_OK(ParseOneProfileOutputLine(profile_output_lines[line_no++], /*expect_hlo=*/false, &parsed_profile_lines)); + ASSERT_LT(line_no, profile_output_lines.size()); TF_ASSERT_OK(ParseOneProfileOutputLine(profile_output_lines[line_no++], /*expect_hlo=*/true, &parsed_profile_lines)); + ASSERT_LT(line_no, profile_output_lines.size()); TF_ASSERT_OK(ParseOneProfileOutputLine(profile_output_lines[line_no++], /*expect_hlo=*/true, &parsed_profile_lines)); diff --git a/tensorflow/compiler/xla/text_literal_reader.cc b/tensorflow/compiler/xla/text_literal_reader.cc index cdde88c1359416d423685f330e9cbdf77948034f..c78ec522aa5f13556c6d4602267544694887f250 100644 --- a/tensorflow/compiler/xla/text_literal_reader.cc +++ b/tensorflow/compiler/xla/text_literal_reader.cc @@ -27,6 +27,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/strings/strip.h" #include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -66,7 +67,7 @@ StatusOr TextLiteralReader::ReadAllLines() { } absl::StripAsciiWhitespace(&shape_string); - TF_ASSIGN_OR_RETURN(Shape shape, ShapeUtil::ParseShapeString(shape_string)); + TF_ASSIGN_OR_RETURN(Shape shape, ParseShape(shape_string)); if (shape.element_type() != F32) { return Unimplemented( "unsupported element type for text literal reading: %s", diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD index 8926bbed2b54fceaaf0e6e991f0e881d35731ef4..52fee4770ab940741723514d742e998b25765f24 100644 --- a/tensorflow/compiler/xla/tools/BUILD +++ b/tensorflow/compiler/xla/tools/BUILD @@ -14,7 +14,7 @@ filegroup( visibility = ["//tensorflow/compiler/xla:internal"], ) -load("//tensorflow:tensorflow.bzl", "tf_cc_binary") +load("//tensorflow:tensorflow.bzl", 
"tf_cc_binary", "tf_cc_test") tf_cc_binary( name = "hex_floats_to_packed_literal", @@ -29,33 +29,6 @@ tf_cc_binary( ], ) -cc_library( - name = "dumped_computation_to_graphviz_library", - srcs = ["dumped_computation_to_graphviz.cc"], - deps = [ - "//tensorflow/compiler/xla:debug_options_flags", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client", - "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/service", - "//tensorflow/compiler/xla/service:hlo_proto", - "//tensorflow/core:lib", - "@com_google_absl//absl/types:span", - ], -) - -tf_cc_binary( - name = "dumped_computation_to_graphviz", - deps = [ - ":dumped_computation_to_graphviz_library", - "//tensorflow/compiler/xla/service:interpreter_plugin", - ], -) - tf_cc_binary( name = "show_signature", srcs = ["show_signature.cc"], @@ -95,6 +68,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service/gpu:infeed_manager", + "//tensorflow/compiler/xla/service/gpu:outfeed_manager", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", @@ -234,3 +208,56 @@ tf_cc_binary( "//tensorflow/core:lib", ], ) + +tf_cc_test( + name = "hlo_extractor_test", + srcs = ["hlo_extractor_test.cc"], + deps = [ + ":hlo_extractor", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:test", + ], +) + +cc_library( + name = "hlo_extractor", + srcs = ["hlo_extractor.cc"], + hdrs = ["hlo_extractor.h"], + deps = [ + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_verifier", + "@com_google_absl//absl/container:flat_hash_map", + 
"@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/memory", + ], +) + +tf_cc_binary( + name = "interactive_graphviz", + srcs = ["interactive_graphviz.cc"], + deps = [ + ":hlo_extractor", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/service:compiler", + "//tensorflow/compiler/xla/service:cpu_plugin", + "//tensorflow/compiler/xla/service:gpu_plugin", + "//tensorflow/compiler/xla/service:hlo_graph_dumper", + "//tensorflow/compiler/xla/service:hlo_proto", + "//tensorflow/compiler/xla/service:hlo_runner", + "//tensorflow/compiler/xla/service:local_service", + "//tensorflow/compiler/xla/service:platform_util", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", + ], +) + +sh_test( + name = "interactive_graphviz_build_only_test", + srcs = ["interactive_graphviz_test.sh"], + data = [":interactive_graphviz"], +) diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc deleted file mode 100644 index b623556468fb4a5d96be614b6c067d5a1df51a6f..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc +++ /dev/null @@ -1,84 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -// Usage: dumped_computation_to_graphviz some_binary_snapshot_proto* -// -// Dumps a graphviz URL for a snapshot computation to the command line. -// -// some_binary_snapshot_proto is obtained by serializing the HloSnapshot from -// ServiceInterface::SnapshotComputation to disk. -// -// The GraphViz URL is placed into the log stderr, whereas computation -// statistics are printed on stdout (implementation note: getting computation -// statistics is how we trigger compilation to split out a GraphViz URL). - -#include -#include -#include - -#include "absl/types/span.h" -#include "tensorflow/compiler/xla/client/client.h" -#include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_computation.h" -#include "tensorflow/compiler/xla/debug_options_flags.h" -#include "tensorflow/compiler/xla/service/hlo.pb.h" -#include "tensorflow/compiler/xla/service/service.h" -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/init_main.h" -#include "tensorflow/core/platform/logging.h" - -namespace xla { -namespace tools { - -void RealMain(absl::Span args) { - Client* client = ClientLibrary::LocalClientOrDie(); - for (char* arg : args) { - HloSnapshot module; - TF_CHECK_OK( - tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, &module)); - XlaComputation computation = - client->LoadSnapshot(module).ConsumeValueOrDie(); - DebugOptions debug_options = GetDebugOptionsFromFlags(); - debug_options.set_xla_generate_hlo_graph(".*"); - ComputationStats stats = - client->GetComputationStats(computation, debug_options) - .ConsumeValueOrDie(); - fprintf(stdout, ">>> %s :: %s\n", arg, stats.DebugString().c_str()); - } -} - -} // namespace 
tools -} // namespace xla - -int main(int argc, char** argv) { - std::vector flag_list; - xla::AppendDebugOptionsFlags(&flag_list); - xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); - const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); - if (!parse_result) { - LOG(ERROR) << "\n" << usage; - return 2; - } - tensorflow::port::InitMain(argv[0], &argc, &argv); - - absl::Span args(argv, argc); - args.remove_prefix(1); // Pop off the binary name, argv[0] - xla::tools::RealMain(args); - return 0; -} diff --git a/tensorflow/compiler/xla/tools/hlo_extractor.cc b/tensorflow/compiler/xla/tools/hlo_extractor.cc new file mode 100644 index 0000000000000000000000000000000000000000..f3ce5f99b0c2a8e9ae5446f4bedc34b678c95b96 --- /dev/null +++ b/tensorflow/compiler/xla/tools/hlo_extractor.cc @@ -0,0 +1,159 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/tools/hlo_extractor.h" + +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/memory/memory.h" +#include "tensorflow/compiler/xla/service/hlo_clone_context.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_verifier.h" +#include "tensorflow/compiler/xla/status.h" + +namespace xla { +namespace { + +// Visitor that build a new HLO module with an entry computation and a root that +// is provided to the visit function. Only HLOs that are reachable from the new +// root instruction are included in the new module. +// +// The constructor allows specifying a set of boundary HLOs to prune the HLO +// graph. HLOs at the boundary are replaced with parameters. Can be nullptr +// which means no boundary, i.e. no HLOs are replaced with parameters. +class ExtractionVisitor : public ConstDfsHloVisitorWithDefault { + public: + explicit ExtractionVisitor( + const HloModule& old_module, + absl::flat_hash_set* boundary) + : old_module_(old_module), + module_(absl::make_unique("extracted", config_)), + clone_context_(module_.get()), + builder_("entry_computation"), + boundary_(boundary) {} + + Status HandleParameter(const HloInstruction* parameter) override { + // Entry parameters need renumbering. + auto new_parameter = HloInstruction::CreateParameter( + parameter_number_++, parameter->shape(), parameter->name()); + clone_context_.MapInstruction(parameter, new_parameter.get()); + builder_.AddInstruction(std::move(new_parameter)); + return Status::OK(); + } + + Status DefaultAction(const HloInstruction* hlo) override { + // Replace instructions at the boundary with parameters, but leave constants + // untouched. 
+ if (boundary_ != nullptr && boundary_->count(hlo) > 0) { + auto new_parameter = HloInstruction::CreateParameter( + parameter_number_, hlo->shape(), hlo->name()); + parameter_number_++; + clone_context_.MapInstruction(hlo, new_parameter.get()); + builder_.AddInstruction(std::move(new_parameter)); + return Status::OK(); + } + std::vector new_operands; + for (auto operand : hlo->operands()) { + new_operands.push_back(clone_context_.GetInstruction(operand)); + } + auto instruction = + hlo->CloneWithNewOperands(hlo->shape(), new_operands, &clone_context_); + builder_.AddInstruction(std::move(instruction)); + return Status::OK(); + } + + Status FinishVisit(const HloInstruction* /*root*/) override { + module_->AddEntryComputation(builder_.Build()); + // Rename HLOs so that their name matches the original. By default, + // HLOs get new unique names when adding a new entry computation to + // a module. + for (auto computation : old_module_.MakeComputationPostOrder()) { + for (auto old_instruction : computation->MakeInstructionPostOrder()) { + if (auto new_instruction = + clone_context_.FindInstruction(old_instruction)) { + new_instruction->SetAndSanitizeName(old_instruction->name()); + } + } + } + return Status::OK(); + } + + HloModule* module() { return module_.get(); } + + std::unique_ptr ConsumeModule() { return std::move(module_); } + + private: + const HloModule& old_module_; + HloModuleConfig config_; + std::unique_ptr module_; + HloCloneContext clone_context_; + HloComputation::Builder builder_; + absl::flat_hash_set* boundary_; + int64 parameter_number_ = 0; +}; + +void ComputeBoundary(const HloInstruction* root, int64 limit, + absl::flat_hash_set* boundary) { + std::deque worklist; + absl::flat_hash_map visited; + worklist.push_back(root); + visited.emplace(root, 0); + while (!worklist.empty()) { + auto hlo = worklist.front(); + worklist.pop_front(); + int64 hops = visited[hlo]; + if (hops > limit) { + boundary->insert(hlo); + continue; + } + for (const 
HloInstruction* operand : hlo->operands()) { + if (visited.count(operand)) { + continue; + } + worklist.push_back(operand); + visited.emplace(operand, hops + 1); + } + } +} + +} // namespace + +std::unique_ptr ExtractModule(HloInstruction* instruction, + int64 height) { + absl::flat_hash_set boundary; + if (height != -1) { + ComputeBoundary(instruction, height, &boundary); + } + ExtractionVisitor visitor(*instruction->GetModule(), &boundary); + CHECK(instruction->Accept(&visitor).ok()); + + // The first pass may leave unused parameter instructions. Do another + // extraction pass to remove unused parameters. This is done because + // HloComputation does not allow removing parameters after the computation has + // been built. + ExtractionVisitor cleanup_visitor(*visitor.module(), /*boundary=*/nullptr); + TF_CHECK_OK(visitor.module()->entry_computation()->root_instruction()->Accept( + &cleanup_visitor)); + + HloVerifier verifier(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/true); + TF_CHECK_OK(verifier.Run(cleanup_visitor.module()).status()); + return cleanup_visitor.ConsumeModule(); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/tools/hlo_extractor.h b/tensorflow/compiler/xla/tools/hlo_extractor.h new file mode 100644 index 0000000000000000000000000000000000000000..bc13dc7e438fe0e64312746150af02df805e746a --- /dev/null +++ b/tensorflow/compiler/xla/tools/hlo_extractor.h @@ -0,0 +1,36 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_TOOLS_HLO_EXTRACTOR_H_ +#define TENSORFLOW_COMPILER_XLA_TOOLS_HLO_EXTRACTOR_H_ + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" + +namespace xla { + +// Creates a new HLO module rooted with an entry computation rooted at the given +// instruction. +// +// By default (height == -1), the new computation includes all transitive +// operands of `root`. If you specify a different height, the new computation +// will include all instructions <= `height` hops away from `root`. +// Instructions at the boundary are replaced by parameters. +std::unique_ptr ExtractModule(HloInstruction* instruction, + int64 height = -1); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_TOOLS_HLO_EXTRACTOR_H_ diff --git a/tensorflow/compiler/xla/tools/hlo_extractor_test.cc b/tensorflow/compiler/xla/tools/hlo_extractor_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..4beb099b330cadf4540944979f38681bae07103c --- /dev/null +++ b/tensorflow/compiler/xla/tools/hlo_extractor_test.cc @@ -0,0 +1,139 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/tools/hlo_extractor.h" + +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" + +namespace xla { +namespace { + +namespace op = testing::opcode_matchers; + +using HloExtractorTest = HloTestBase; + +TEST_F(HloExtractorTest, ExtractTopLevel) { + const string& hlo_string = R"( +HloModule test + +ENTRY %entry { + param.0 = f32[4]{0} parameter(0) + negate = f32[4]{0} negate(f32[4]{0} param.0) + ROOT exp = f32[4]{0} exponential(f32[4]{0} negate) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, + ParseAndReturnVerifiedModule(hlo_string)); + + { + auto extracted_module = + ExtractModule(FindInstruction(hlo_module.get(), "exp")); + EXPECT_THAT(extracted_module->entry_computation()->root_instruction(), + op::Exp(op::Negate(op::Parameter(0)))); + } + + { + auto extracted_module = + ExtractModule(FindInstruction(hlo_module.get(), "exp"), /*height=*/0); + EXPECT_THAT(extracted_module->entry_computation()->root_instruction(), + op::Exp(op::Parameter(0))); + } + + { + auto extracted_module = ExtractModule( + FindInstruction(hlo_module.get(), "negate"), /*height=*/0); + EXPECT_THAT(extracted_module->entry_computation()->root_instruction(), + op::Negate(op::Parameter(0))); + } +} + +TEST_F(HloExtractorTest, ExtractDag) { + const string& hlo_string = R"( +HloModule test + +ENTRY %entry { + param.0 = f32[4]{0} parameter(0) + tanh = f32[4]{0} tanh(f32[4]{0} param.0) + negate = f32[4]{0} negate(f32[4]{0} tanh) + exp = f32[4]{0} exponential(f32[4]{0} negate) + ROOT add = f32[4]{0} add(f32[4]{0} negate, f32[4]{0} exp) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, + ParseAndReturnVerifiedModule(hlo_string)); + + { + auto extracted_module = + ExtractModule(FindInstruction(hlo_module.get(), "exp")); + EXPECT_THAT(extracted_module->entry_computation()->root_instruction(), + 
op::Exp(op::Negate(op::Tanh(op::Parameter(0))))); + } + + { + auto extracted_module = + ExtractModule(FindInstruction(hlo_module.get(), "add"), /*height=*/0); + EXPECT_THAT(extracted_module->entry_computation()->root_instruction(), + op::Add(op::Parameter(0), op::Parameter(1))); + } + { + auto extracted_module = + ExtractModule(FindInstruction(hlo_module.get(), "add"), /*height=*/1); + EXPECT_THAT(extracted_module->entry_computation()->root_instruction(), + op::Add(op::Negate(op::Parameter(0)), + op::Exp(op::Negate(op::Parameter(0))))); + } + { + auto extracted_module = + ExtractModule(FindInstruction(hlo_module.get(), "add"), /*height=*/2); + EXPECT_THAT(extracted_module->entry_computation()->root_instruction(), + op::Add(op::Negate(op::Tanh(op::Parameter(0))), + op::Exp(op::Negate(op::Tanh(op::Parameter(0)))))); + } +} + +TEST_F(HloExtractorTest, ExtractWithConstant) { + const string& hlo_string = R"( +HloModule test + +ENTRY %entry { + p = f32[4]{0} parameter(0) + tanh = f32[4]{0} tanh(p) + c = f32[4]{0} constant({1, 2, 3, 4}) + ROOT add = f32[4]{0} add(tanh, c) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, + ParseAndReturnVerifiedModule(hlo_string)); + + { + auto extracted_module = + ExtractModule(FindInstruction(hlo_module.get(), "add"), /*height=*/0); + EXPECT_THAT(extracted_module->entry_computation()->root_instruction(), + op::Add(op::Parameter(0), op::Parameter(1))); + } + { + auto extracted_module = + ExtractModule(FindInstruction(hlo_module.get(), "add"), /*height=*/1); + EXPECT_THAT(extracted_module->entry_computation()->root_instruction(), + op::Add(op::Tanh(op::Parameter(0)), op::Constant())); + } +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tools/interactive_graphviz.cc b/tensorflow/compiler/xla/tools/interactive_graphviz.cc new file mode 100644 index 0000000000000000000000000000000000000000..ac865707f8697e0b94173a2a33e7be52a9564867 --- /dev/null +++ 
b/tensorflow/compiler/xla/tools/interactive_graphviz.cc @@ -0,0 +1,652 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// A tool for interactively exploring graphviz dumps of HLO graphs. +// +// Input can be a binary HloSnapshot proto, a binary HloProto proto, or a +// textual HLO string. +// +// Generated visualization is opened in a new default browser window using +// /usr/bin/sensible-browser. 
+
+#include <stdio.h>
+#include <unistd.h>
+
+#include <iostream>
+#include <memory>
+#include <string>
+#include <unordered_set>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/strings/match.h"
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/str_split.h"
+#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/service/compiler.h"
+#include "tensorflow/compiler/xla/service/hlo.pb.h"
+#include "tensorflow/compiler/xla/service/hlo_runner.h"
+#include "tensorflow/compiler/xla/service/local_service.h"
+#include "tensorflow/compiler/xla/service/platform_util.h"
+#include "tensorflow/compiler/xla/tools/hlo_extractor.h"
+#include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/subprocess.h"
+#include "tensorflow/core/util/command_line_flags.h"
+#if defined(PLATFORM_GOOGLE)
+#include "util/readline/readline.h"
+#endif
+
+namespace xla {
+namespace tools {
+namespace {
+
+// Shows `prompt` and reads one line of user input into `*line`.
+//
+// Returns true if a line was read and false once no more input is available
+// (for example after ^D), which is what terminates the command loop.
+bool ReadLine(const char* prompt, string* line) {
+#if defined(PLATFORM_GOOGLE)
+  return util::ReadLine(prompt, line);
+#else
+  std::cout << prompt;
+  // Use the result of getline() itself instead of std::cin.good(): when the
+  // last line of input is not newline-terminated, getline() still extracts it
+  // but sets eofbit, so good() would report failure and the final command
+  // would be silently dropped.
+  return static_cast<bool>(std::getline(std::cin, *line));
+#endif
+}
+
+// Command-line opts to this tool.  See main() for descriptions of these
+// fields.
+struct Options {
+  string hlo_snapshot;
+  string hlo_proto;
+  string hlo_text;
+  string platform;
+  string browser;
+};
+
+// Long-form description of the tool, shown together with the flag listing.
+const char* const kUsage = R"(
+This tool lets you load an XLA dump and then interactively explore its graphical
+representation.
+
+Most models are too large to visualize in their entirety using graphviz, but
+it's still useful to be able to look at the nodes "near" a particular node of
+interest.
+
+If you pass --platform, this tool will compile the HloModule for the given
+platform.  This means that if you acquired your proto from a binary running at a
+particular CL, the HLO graph it ran isn't necessarily the same as the one shown
+here, unless this program was built at the same CL (and our compiler is
+deterministic :).
+
+Be patient when starting this program if you give it a large input; it has to
+compile the whole thing.
+
+Usage:
+
+  interactive_graphviz -- \
+    --{hlo_snapshot,hlo_proto,hlo_text}=path/to/binary_proto
+    --platform={CUDA,CPU,...}
+)";
+
+// Unless an explicit width is specified, we will render a neighborhood of
+// kDefaultWidth nodes around the requested instruction.
+constexpr int64 kDefaultWidth = 2;
+
+// When printing all paths between two nodes, we print out only this many nodes
+// by default, truncating the graph if there are more nodes than this in the
+// all-paths set.
+constexpr int64 kDefaultMaxNumNodesInAllPaths = 100;
+
+using absl::EqualsIgnoreCase;
+
+// A global control for whether backend configuration display is enabled.
+bool show_backend_config = true;
+
+// Returns the first instruction in `module` whose name matches `node_name`
+// case-insensitively, or nullptr if there is none.  A leading "%" on
+// `node_name` is optional: it is stripped from the query, and instruction
+// names are compared both with and without it.
+HloInstruction* FindInstruction(const HloModule& module, string node_name) {
+  if (absl::StartsWith(node_name, "%")) {
+    node_name.erase(node_name.begin());
+  }
+  const string percent_name = absl::StrCat("%", node_name);
+  for (const auto& computation : module.computations()) {
+    auto instrs = computation->instructions();
+    auto it = absl::c_find_if(instrs, [&](const HloInstruction* instr) {
+      // Try with and without "%" at the beginning of the node name.
+      return EqualsIgnoreCase(instr->name(), node_name) ||
+             EqualsIgnoreCase(instr->name(), percent_name);
+    });
+    if (it != instrs.end()) {
+      return *it;
+    }
+  }
+  return nullptr;
+}
+
+// Returns the computation in `module` whose name matches `comp_name`
+// case-insensitively, or nullptr if there is none.
+HloComputation* FindComputation(const HloModule& module,
+                                const string& comp_name) {
+  for (auto* computation : module.computations()) {
+    if (EqualsIgnoreCase(computation->name(), comp_name)) {
+      return computation;
+    }
+  }
+  return nullptr;
+}
+
+// Print a help message describing the various available commands.
+void DoHelpCommand() {
+  // The placeholders below mirror the arguments each Do*Command() parses.
+  std::cout << R"(Commands:
+  <instruction> [<width>]
+    Renders a neighborhood of <width> nodes around <instruction>.  If <width>
+    is not provided, the default value is )"
+            << kDefaultWidth << R"(.
+  allpaths <instruction> <instruction> [<n>]
+    Renders a subset of all paths from one instruction to the other.  Either
+    order of nodes is accepted.  Shows the <n> nodes in the all-paths set on the
+    shortest paths; default is )"
+            << kDefaultMaxNumNodesInAllPaths << R"(.
+  <computation>
+    Renders all nodes in <computation>.
+  backend_config [on|off]
+    Controls whether backend operation configuration information is printed.
+  list [name|op_name|op_type] <pattern>
+    Lists all instructions whose name, metadata op_name, or metadata op_type
+    contains <pattern> as a substring.
+  list computations
+    Lists all computations in the module.
+  info <instruction>
+  info <computation>
+    Prints information about <instruction> or <computation>.
+  extract <instruction> [<height>]
+    Creates a new HLO module with <instruction> as entry computation root. If
+    <height> is specified, the new computation contains nodes up to <height>
+    nodes above the root.
+  help
+    Prints this usage information.)"
+            << std::endl;
+}
+
+// Turns display of backend configuration on or off, then reports the current
+// setting.
+//
+// `tokens` is the split command line: tokens[0] is "backend_config" and the
+// optional tokens[1] must be "on" or "off".  With no argument the setting is
+// left unchanged and only reported; any other argument list prints an error
+// and also leaves the setting unchanged.
+void DoBackendConfigCommand(const std::vector<string>& tokens) {
+  if (tokens.size() == 2 && tokens[1] == "on") {
+    show_backend_config = true;
+  } else if (tokens.size() == 2 && tokens[1] == "off") {
+    show_backend_config = false;
+  } else if (tokens.size() != 1) {
+    std::cerr << "(Illegal backend_config value.  Use either 'on' or 'off'.)"
+              << std::endl;
+  }
+  std::cout << "Backend configuration display "
+            << (show_backend_config ? "ON" : "OFF") << std::endl;
+}
+
+// List all computations in the module.
+void DoListComputationsCommand(const HloModule& module,
+                               const std::vector<string>& tokens) {
+  if (tokens.size() > 2) {
+    std::cout << R"(Illegal syntax; "list computations" takes no arguments.)"
+              << std::endl;
+    return;
+  }
+  if (module.entry_computation() != nullptr) {
+    std::cout << "Entry computation:" << std::endl;
+    std::cout << "  " << module.entry_computation()->name() << std::endl
+              << std::endl;
+  }
+  std::cout << "Subcomputations:" << std::endl;
+  for (const auto& computation : module.computations()) {
+    if (computation == module.entry_computation()) {
+      continue;
+    }
+    std::cout << "  " << computation->name() << std::endl;
+  }
+}
+
+// List all instructions matching a pattern.
+//
+// Accepts "list <pattern>" (match against the instruction name) or
+// "list name|op_name|op_type <pattern>".  Matching is a plain substring test.
+void DoListCommand(const HloModule& module, const std::vector<string>& tokens) {
+  string pattern = "";
+  string type = "name";
+  if (tokens.size() == 2) {
+    pattern = tokens[1];
+  } else if (tokens.size() == 3) {
+    type = tokens[1];
+    pattern = tokens[2];
+  }
+  const bool known_type =
+      type == "name" || type == "op_name" || type == "op_type";
+  // Reject a wrong argument count and an unknown query type alike; the latter
+  // would otherwise print an empty result list with no hint as to why.
+  if (tokens.size() < 2 || tokens.size() > 3 || !known_type) {
+    std::cout << "Illegal list query syntax. Use "
+              << R"("list [name|op_name|op_type] pattern".)" << std::endl;
+    return;
+  }
+
+  std::cout << "Query results:" << std::endl;
+  for (const auto& computation : module.computations()) {
+    for (const auto& instr : computation->instructions()) {
+      if ((type == "name" && instr->name().find(pattern) != string::npos) ||
+          (type == "op_name" &&
+           instr->metadata().op_name().find(pattern) != string::npos) ||
+          (type == "op_type" &&
+           instr->metadata().op_type().find(pattern) != string::npos)) {
+        std::cout << "  " << instr->name();
+        std::cout << ", op_name '" << instr->metadata().op_name() << "'";
+        std::cout << ", op_type '" << instr->metadata().op_type() << "'";
+        std::cout << std::endl;
+      }
+    }
+  }
+}
+
+// Print info about an instruction or computation.
+//
+// Expects exactly "info <name>".  If <name> matches both a computation and an
+// instruction, the computation is described.
+void DoInfoCommand(const HloModule& module, const std::vector<string>& tokens) {
+  if (tokens.size() != 2) {
+    std::cerr << "Illegal info query syntax. Use "
+              << R"("info name".)" << std::endl;
+    return;
+  }
+  string node_name = tokens[1];
+
+  const HloInstruction* instr = FindInstruction(module, node_name);
+  const HloComputation* comp = FindComputation(module, node_name);
+  if (!instr && !comp) {
+    std::cerr << "Couldn't find HloInstruction or HloComputation named "
+              << node_name << std::endl;
+    return;
+  }
+
+  if (comp != nullptr) {
+    std::cout << "HloComputation " << comp->name() << std::endl;
+    if (comp->IsFusionComputation()) {
+      std::cout << "  Fusion instruction: " << comp->FusionInstruction()->name()
+                << std::endl;
+    }
+    std::cout << "  Parameters:" << std::endl;
+    for (const auto& param : comp->parameter_instructions()) {
+      std::cout << "    " << param->name() << " ("
+                << ShapeUtil::HumanStringWithLayout(param->shape()) << ")"
+                << std::endl;
+    }
+    HloInstruction* root = comp->root_instruction();
+    std::cout << "  Root instruction: " << root->name() << " ("
+              << ShapeUtil::HumanStringWithLayout(root->shape()) << ")"
+              << std::endl;
+
+    auto embedded_computations = comp->MakeEmbeddedComputationsList();
+    std::cout << "  " << embedded_computations.size() << " embedded computation"
+              << (embedded_computations.size() != 1 ? "s" : "")
+              << (!embedded_computations.empty() ? ":" : ".") << std::endl;
+    for (const HloComputation* c : embedded_computations) {
+      std::cout << "    " << c->name() << std::endl;
+    }
+
+    // Find which computations reference comp as an embedded computation.
+    std::vector<const HloComputation*> users;
+    for (const HloComputation* c : module.computations()) {
+      if (absl::c_linear_search(c->MakeEmbeddedComputationsList(), comp)) {
+        users.push_back(c);
+      }
+    }
+    // Terminate the header line so the first user is not glued onto it.
+    std::cout << "  Used by " << users.size() << " computation"
+              << (users.size() != 1 ? "s" : "") << (!users.empty() ? ":" : ".")
+              << std::endl;
+    for (const HloComputation* c : users) {
+      std::cout << "    " << c->name() << std::endl;
+    }
+  } else {
+    std::cout << "HloInstruction " << instr->name() << std::endl;
+    std::cout << "  Parent computation: " << instr->parent()->name()
+              << std::endl;
+    std::cout << "  Opcode: " << HloOpcodeString(instr->opcode()) << std::endl;
+    std::cout << "  Shape: " << ShapeUtil::HumanStringWithLayout(instr->shape())
+              << std::endl;
+    std::cout << "  Metadata:" << std::endl;
+    if (!instr->metadata().op_name().empty()) {
+      std::cout << "    Name: " << instr->metadata().op_name() << std::endl;
+    }
+    if (!instr->metadata().op_type().empty()) {
+      std::cout << "    Type: " << instr->metadata().op_type() << std::endl;
+    }
+    if (!instr->raw_backend_config_string().empty()) {
+      std::cout << "  Backend configuration: "
+                << instr->raw_backend_config_string() << std::endl;
+    }
+    if (instr->opcode() == HloOpcode::kFusion) {
+      std::cout << "  Fusion kind: " << xla::ToString(instr->fusion_kind())
+                << std::endl;
+      std::cout << "  Fusion computation: "
+                << instr->fused_instructions_computation()->name() << std::endl;
+      std::cout << "  Fused computation root: "
+                << instr->fused_expression_root()->name() << std::endl;
+    }
+    std::cout << "  Operands:" << std::endl;
+    for (HloInstruction* operand : instr->operands()) {
+      std::cout << "    " << operand->name() << " ("
+                << ShapeUtil::HumanStringWithLayout(operand->shape()) << ")"
+                << std::endl;
+    }
+    std::cout << "  Users:" << std::endl;
+    for (HloInstruction* user : instr->users()) {
+      std::cout << "    " << user->name() << std::endl;
+    }
+    if (instr->parent()->root_instruction() == instr) {
+      std::cout << "  Root instruction of " << instr->parent()->name()
+                << std::endl;
+    }
+  }
+}
+
+// Prints a new HLO module rooted at the named instruction.
+//
+// Expects "extract <instruction> [<height>]".  When <height> is omitted, -1 is
+// passed to ExtractModule().
+void DoExtractCommand(const HloModule& module,
+                      absl::Span<const string> tokens) {
+  // The lower bound matters: a bare "extract" must not index tokens[1].
+  if (tokens.size() < 2 || tokens.size() > 3) {
+    std::cerr << R"(Illegal input. Enter e.g. "extract %fusion.1 2")"
+              << std::endl;
+    return;
+  }
+
+  // Find the node with the given name.
+  string node_name = tokens[1];
+  HloInstruction* instr = FindInstruction(module, node_name);
+  if (!instr) {
+    std::cerr << "Couldn't find HloInstruction named " << node_name << "."
+              << std::endl;
+    return;
+  }
+
+  int64 height = -1;
+  if (tokens.size() == 3) {
+    if (!absl::SimpleAtoi(tokens[2], &height)) {
+      std::cerr << "Can't parse '" << tokens[2] << "' as an integer."
+                << std::endl;
+      return;
+    }
+  }
+
+  auto extracted_module = ExtractModule(instr, height);
+  std::cout << extracted_module->ToString(
+                   HloPrintOptions::ShortParsable().set_print_backend_config(
+                       show_backend_config))
+            << std::endl;
+}
+
+// Checks if there is a use-def path from `from` to `to`.
+//
+// Depth-first walk over users().  Nodes are marked when they are pushed, so
+// each instruction enters the work list at most once.
+bool ExistsPathFromTo(const HloInstruction* from, const HloInstruction* to) {
+  std::unordered_set<const HloInstruction*> visited = {from};
+  std::vector<const HloInstruction*> to_visit = {from};
+  while (!to_visit.empty()) {
+    const HloInstruction* n = to_visit.back();
+    to_visit.pop_back();
+    if (n == to) {
+      return true;
+    }
+    for (const HloInstruction* user : n->users()) {
+      if (visited.insert(user).second) {
+        to_visit.push_back(user);
+      }
+    }
+  }
+  return false;
+}
+
+// Prints `handle` (the string returned by the graph dumper) and, if it looks
+// like a URL, also tries to open it in the configured browser.
+void DisplayGraphHandle(const Options& opts, const string& handle) {
+  std::cout << handle << std::endl;
+
+  // If it is a url, try to open it up in the user's browser too.
+  if (absl::StartsWithIgnoreCase(handle, "http://") ||
+      absl::StartsWithIgnoreCase(handle, "https://") ||
+      absl::StartsWithIgnoreCase(handle, "file://")) {
+    const char* browser_bin = opts.browser.empty() ? "/usr/bin/sensible-browser"
+                                                   : opts.browser.c_str();
+    tensorflow::SubProcess p;
+    p.SetProgram(browser_bin, {browser_bin, handle});
+    // The URL has already been printed, so a launch failure is not fatal; just
+    // tell the user why no window appeared.
+    if (!p.Start()) {
+      std::cerr << "Failed to launch browser " << browser_bin << "."
+                << std::endl;
+    }
+  } else if (handle.empty()) {
+    std::cerr << "Unable to render graph, perhaps due to graphviz server "
+                 "timeout.  Run with --logtostderr to see."
+              << std::endl;
+  } else {
+    std::cerr << "\nExpected a URL, but got strange graph result (dumped "
+                 "above).  If this isn't what you expected, maybe file a bug?"
+              << std::endl;
+  }
+}
+
+// Renders the paths between two instructions.
+//
+// Expects "allpaths <instruction> <instruction> [<n>]"; the two instructions
+// may be given in either order.
+void DoAllPathsCommand(const Options& opts, const HloModule& module,
+                       const std::vector<string>& tokens) {
+  // Both node names are mandatory; without the lower bound a short command
+  // would index past the end of `tokens`.
+  if (tokens.size() < 3 || tokens.size() > 4) {
+    std::cerr << R"(Illegal input.  Enter e.g. "allpaths %add.4 %subtract.2" or
+"allpaths add.4 subtract.2 42".)"
+              << std::endl;
+    return;
+  }
+
+  int64 max_nodes = kDefaultMaxNumNodesInAllPaths;
+  if (tokens.size() == 4 && !absl::SimpleAtoi(tokens[3], &max_nodes)) {
+    std::cerr << "Can't parse '" << tokens[3] << "' as an integer."
+              << std::endl;
+    return;
+  }
+
+  const HloInstruction* n1 = FindInstruction(module, tokens[1]);
+  if (!n1) {
+    std::cerr << "Couldn't find HloInstruction named " << tokens[1]
+              << std::endl;
+    return;
+  }
+  const HloInstruction* n2 = FindInstruction(module, tokens[2]);
+  if (!n2) {
+    std::cerr << "Couldn't find HloInstruction named " << tokens[2]
+              << std::endl;
+    return;
+  }
+
+  // Is there a path from n1 to n2, or vice versa?
+  const HloInstruction* from;
+  const HloInstruction* to;
+  if (ExistsPathFromTo(n1, n2)) {
+    from = n1;
+    to = n2;
+  } else if (ExistsPathFromTo(n2, n1)) {
+    from = n2;
+    to = n1;
+  } else {
+    std::cerr << "No path from/to " << tokens[1] << " to/from " << tokens[2]
+              << std::endl;
+    return;
+  }
+  DisplayGraphHandle(opts, hlo_graph_dumper::DumpAllPathsFromTo(
+      *from, *to, max_nodes, /*show_backend_config=*/show_backend_config));
+}
+
+// Plot a given instruction neighborhood or computation with graphviz.
+//
+// Expects "<instruction> [<width>]" or "<computation>"; this is the fallback
+// for any input that is not a named command.
+void DoPlotCommand(const Options& opts, const HloModule& module,
+                   const std::vector<string>& tokens) {
+  if (tokens.empty() || tokens.size() > 2) {
+    std::cerr << R"(Illegal input.  Enter e.g. "%fusion.1 42" or "%fusion.1".)"
+              << std::endl;
+    return;
+  }
+
+  string node_name = tokens[0];
+
+  // Find the node with the given name.
+  const HloInstruction* instr = FindInstruction(module, node_name);
+  const HloComputation* comp = FindComputation(module, node_name);
+  if (!instr && !comp) {
+    std::cerr << "Couldn't find HloInstruction or HloComputation named "
+              << node_name << "." << std::endl;
+    return;
+  }
+
+  uint64 graph_width = kDefaultWidth;
+  if (tokens.size() == 2) {
+    if (comp) {
+      std::cerr << "Can only use graph-size parameter with instructions, but "
+                << node_name << " is a computation." << std::endl;
+      return;
+    }
+    if (!absl::SimpleAtoi(tokens[1], &graph_width)) {
+      std::cerr << "Can't parse '" << tokens[1] << "' as an integer."
+                << std::endl;
+      return;
+    }
+  }
+
+  // Generate the graph and print the resulting string, which should be a
+  // graphviz url.
+  if (comp) {
+    DisplayGraphHandle(opts, hlo_graph_dumper::DumpGraph(
+        *comp, "", comp->parent()->config().debug_options(), nullptr,
+        /*show_backend_config=*/show_backend_config));
+  } else {
+    DisplayGraphHandle(opts, hlo_graph_dumper::DumpNeighborhoodAround(
+        *instr, graph_width, /*show_backend_config=*/show_backend_config));
+  }
+}
+
+// Run the main event loop, reading user commands and processing them.
+void InteractiveDumpGraphs(const Options& opts, const HloModule& module) {
+  // This is an interactive tool, but some may use `extract` in non-tty
+  // environment anyway. Give them a clean hlo dump.
+  if (isatty(fileno(stdin))) {
+    std::cout << "\n\nLoaded module " << module.name() << "." << std::endl;
+    DoHelpCommand();
+  }
+  for (string line; ReadLine("\ncommand: ", &line);) {
+    // Skip empty pieces so that repeated or stray spaces do not produce empty
+    // tokens that shift the arguments of every command.
+    std::vector<string> tokens =
+        absl::StrSplit(line, ' ', absl::SkipEmpty());
+    if (tokens.empty()) {
+      std::cout << R"(Enter e.g. "fusion.1 3" or "add.8".)" << std::endl
+                << R"(Enter "help" for help; ^D, "quit", or "exit" to exit.)"
+                << std::endl;
+      continue;
+    }
+    if (tokens[0] == "quit" || tokens[0] == "exit") {
+      break;
+    } else if (tokens[0] == "help") {
+      DoHelpCommand();
+    } else if (tokens[0] == "backend_config") {
+      DoBackendConfigCommand(tokens);
+    } else if (tokens[0] == "list") {
+      if (tokens.size() > 1 && tokens[1] == "computations") {
+        DoListComputationsCommand(module, tokens);
+      } else {
+        DoListCommand(module, tokens);
+      }
+    } else if (tokens[0] == "info") {
+      DoInfoCommand(module, tokens);
+    } else if (tokens[0] == "extract") {
+      DoExtractCommand(module, tokens);
+    } else if (tokens[0] == "allpaths") {
+      DoAllPathsCommand(opts, module, tokens);
+    } else {
+      DoPlotCommand(opts, module, tokens);
+    }
+  }
+}
+
+// Dies unless exactly one of --hlo_proto, --hlo_snapshot and --hlo_text is
+// set.
+void CheckFlags(const Options& opts) {
+  std::vector<string> nonempty_proto_flags;
+  if (!opts.hlo_proto.empty()) {
+    nonempty_proto_flags.push_back("--hlo_proto");
+  }
+  if (!opts.hlo_snapshot.empty()) {
+    nonempty_proto_flags.push_back("--hlo_snapshot");
+  }
+  if (!opts.hlo_text.empty()) {
+    nonempty_proto_flags.push_back("--hlo_text");
+  }
+  if (nonempty_proto_flags.empty()) {
+    // Name the candidates explicitly: the list of flags that were set is empty
+    // here, so joining it would print nothing useful.
+    LOG(FATAL) << "Need one of the following options: "
+               << "--hlo_proto, --hlo_snapshot, --hlo_text";
+  }
+  if (nonempty_proto_flags.size() > 1) {
+    LOG(FATAL) << "Can only specify one of "
+               << absl::StrJoin(nonempty_proto_flags, ", ");
+  }
+}
+
+// Loads the module named by the flags, optionally compiles it for
+// --platform, and then enters the interactive command loop.
+void RealMain(const Options& opts) {
+  if (!isatty(fileno(stdin))) {
+    LOG(ERROR) << "\n\n*****************************************\n"
+               << "This is an interactive tool, but stdin is not a tty.\n"
+               << "*****************************************\n\n";
+  }
+
+  CheckFlags(opts);
+
+  std::unique_ptr<HloModule> module;
+  if (!opts.hlo_snapshot.empty()) {
+    HloSnapshot snapshot;
+    TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(),
+                                            opts.hlo_snapshot, &snapshot))
+        << "Can't open, read, or parse HloSnapshot proto at "
+        << opts.hlo_snapshot;
+    auto config =
+        HloModule::CreateModuleConfigFromProto(snapshot.hlo().hlo_module(),
+                                               xla::GetDebugOptionsFromFlags())
+            .ValueOrDie();
+    module = HloModule::CreateFromProto(snapshot.hlo().hlo_module(), config)
+                 .ValueOrDie();
+  } else if (!opts.hlo_proto.empty()) {
+    module = HloRunner::ReadModuleFromBinaryProtoFile(
+                 opts.hlo_proto, xla::GetDebugOptionsFromFlags())
+                 .ValueOrDie();
+  } else if (!opts.hlo_text.empty()) {
+    module = HloRunner::ReadModuleFromHloTextFile(
+                 opts.hlo_text, xla::GetDebugOptionsFromFlags())
+                 .ValueOrDie();
+  }
+
+  // If a platform was specified, compile the module for that platform.
+  if (!opts.platform.empty()) {
+    se::Platform* platform =
+        PlatformUtil::GetPlatform(opts.platform).ValueOrDie();
+    LOG(INFO) << "Compiling module for " << platform->Name();
+
+    se::StreamExecutor* executor =
+        platform->ExecutorForDevice(/*ordinal=*/0).ValueOrDie();
+    auto compiler = Compiler::GetForPlatform(platform).ValueOrDie();
+    module = compiler
+                 ->RunHloPasses(std::move(module), executor,
+                                /*device_allocator=*/nullptr)
+                 .ValueOrDie();
+    auto executable = compiler
+                          ->RunBackend(std::move(module), executor,
+                                       /*device_allocator=*/nullptr)
+                          .ValueOrDie();
+    InteractiveDumpGraphs(opts, executable->module());
+  } else {
+    InteractiveDumpGraphs(opts, *module);
+  }
+}
+
+}  // namespace
+}  // namespace tools
+}  // namespace xla
+
+// Parses the command-line flags into Options and hands over to RealMain().
+int main(int argc, char** argv) {
+  xla::tools::Options opts;
+  opts.browser = "/usr/bin/sensible-browser";
+  bool need_help = false;
+  const std::vector<tensorflow::Flag> flag_list = {
+      tensorflow::Flag("hlo_snapshot", &opts.hlo_snapshot,
+                       "HloSnapshot proto to interactively dump to graphviz"),
+      tensorflow::Flag("hlo_proto", &opts.hlo_proto,
+                       "XLA hlo proto to interactively dump to graphviz"),
+      tensorflow::Flag("hlo_text", &opts.hlo_text,
+                       "XLA hlo text to interactively dump to graphviz"),
+      tensorflow::Flag("platform", &opts.platform,
+                       "Platform to compile for: CPU, CUDA, etc"),
+      tensorflow::Flag("browser", &opts.browser,
+                       "Path to web browser used to display produced graphs."),
+      tensorflow::Flag("help", &need_help,
+                       "Prints this help message"),
+  };
+  xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+  bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
+  tensorflow::port::InitMain(argv[0], &argc, &argv);
+  if (argc != 1 || !parse_ok || need_help) {
+    // Show the long-form description as well; it was otherwise never printed.
+    LOG(QFATAL) << xla::tools::kUsage << "\n" << usage;
+  }
+  xla::tools::RealMain(opts);
+  return 0;
+}
diff --git a/tensorflow/compiler/xla/tools/interactive_graphviz_test.sh b/tensorflow/compiler/xla/tools/interactive_graphviz_test.sh
new file mode 100755
index 
0000000000000000000000000000000000000000..b3e43aa7da062547fb5f187b885e997fc44bbb65 --- /dev/null +++ b/tensorflow/compiler/xla/tools/interactive_graphviz_test.sh @@ -0,0 +1,19 @@ +#! /bin/bash +# /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ==============================================================================*/ + +# This is a placeholder for a compile-only test for intractive_graphviz tool. + +exit 0 diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc index ff2c3399928c0e6339304323c4f93e212933a340..21217c23f6561a509cb3e30bf3dc841f8dc5db87 100644 --- a/tensorflow/compiler/xla/tools/replay_computation.cc +++ b/tensorflow/compiler/xla/tools/replay_computation.cc @@ -51,6 +51,7 @@ limitations under the License. #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/gpu/infeed_manager.h" +#include "tensorflow/compiler/xla/service/gpu/outfeed_manager.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -73,14 +74,24 @@ namespace { // fields. 
struct Options { string fake_infeed_shape; - bool generate_fake_infeed = false; + string fake_outfeed_shape; + + // generate_fake_infeed == true is a safe default: If the model has 0 or 1 + // infeeds, then it will work like normal. If the model has more than one + // infeed, it will be an error, but that wouldn't have worked anyway if you + // hadn't passed generate_fake_infeed. + // + // Same for generate_fake_outfeed. + bool generate_fake_infeed = true; + bool generate_fake_outfeed = true; + bool use_fake_data = false; bool print_result = true; int num_runs = 1; }; -std::unique_ptr CompileExecutable(const HloSnapshot& module, - LocalClient* client) { +StatusOr> CompileExecutable( + const HloSnapshot& module, LocalClient* client) { XlaComputation computation(module.hlo().hlo_module()); std::vector argument_layouts; argument_layouts.reserve( @@ -91,9 +102,85 @@ std::unique_ptr CompileExecutable(const HloSnapshot& module, argument_layouts.push_back(Shape(param)); argument_layout_ptrs.push_back(&argument_layouts.back()); } - return client - ->Compile(computation, argument_layout_ptrs, ExecutableBuildOptions()) - .ValueOrDie(); + return client->Compile(computation, argument_layout_ptrs, + ExecutableBuildOptions()); +} + +absl::optional GetXfeedShape(bool is_infeed, + const HloModuleProto& module, + const Options& opts) { + std::vector xfeed_instrs; + for (const auto& comp : module.computations()) { + for (const auto& instruction : comp.instructions()) { + if (instruction.opcode() == HloOpcodeString(is_infeed + ? 
HloOpcode::kInfeed + : HloOpcode::kOutfeed)) { + xfeed_instrs.push_back(instruction); + } + } + } + + auto log_xfeed_instrs = [&] { + for (const auto& infeed : xfeed_instrs) { + LOG(ERROR) << " " << ShapeUtil::HumanString(Shape(infeed.shape())) << " " + << infeed.name(); + } + }; + + auto find_instruction_from_id_or_die = [&](int64 id) { + for (const auto& comp : module.computations()) { + for (const auto& instruction : comp.instructions()) { + if (instruction.id() == id) { + return instruction; + } + } + } + LOG(FATAL) << "No instruction with id " << id; + }; + + absl::optional xfeed_shape; + string xfeed_name = is_infeed ? "infeed" : "outfeed"; + string fake_xfeed_shape = + is_infeed ? opts.fake_infeed_shape : opts.fake_outfeed_shape; + bool generate_fake_xfeed = + is_infeed ? opts.generate_fake_infeed : opts.generate_fake_outfeed; + if (!fake_xfeed_shape.empty()) { + xfeed_shape = std::move(ParseShape(fake_xfeed_shape)).ValueOrDie(); + } else if (generate_fake_xfeed) { + CHECK_LT(xfeed_instrs.size(), 2) + << "--generate_fake_" << xfeed_name + << " only works if the model has 0 or 1 " << xfeed_name << " ops."; + if (xfeed_instrs.empty()) { + LOG(INFO) << "Not generating fake " << xfeed_name + << " shape; model has no " << xfeed_name << "s."; + } else if (xfeed_instrs.size() == 1) { + // kInfeed instructions should have a shape (buffer, token). kOutfeed + // instructions should have operand 0 of shape `buffer`. We want to xfeed + // just `buffer`. + xfeed_shape = is_infeed + ? 
Shape(xfeed_instrs.front().shape()).tuple_shapes(0) + : Shape(find_instruction_from_id_or_die( + xfeed_instrs.front().operand_ids(0)) + .shape()); + LOG(INFO) << "Generating fake " << xfeed_name << " with inferred shape: " + << ShapeUtil::HumanString(*xfeed_shape); + } else { + LOG(ERROR) << "--generate_fake_" << xfeed_name + << " only works if the model has 0 or 1 " << xfeed_name + << " ops, but this model has " << xfeed_instrs.size() + << " of them:"; + log_xfeed_instrs(); + LOG(FATAL) << "Can't run model with --generate_fake_infeed."; + } + } else if (!xfeed_instrs.empty()) { + LOG(ERROR) << "Model contains " << xfeed_instrs.size() << " " << xfeed_name + << " instruction(s), but neither --generate_fake_" << xfeed_name + << " nor --fake_" << xfeed_name + << "_shape was specified. Execution will likely hang."; + log_xfeed_instrs(); + } + + return xfeed_shape; } // Invokes the given computation passing arbitrary data for every (unbound) @@ -118,7 +205,12 @@ StatusOr ReplayComputation(const HloSnapshot& module, std::vector> global_data_arguments; std::vector argument_ptrs; if (opts.use_fake_data) { - global_data_arguments = MakeFakeArgumentsOrDie(computation, client); + // Run fake computations with debug options ignoring XLA_FLAGS. Users very + // likely want XLA_FLAGS only to apply to the "real" computation being run, + // not to the fake computations we use for generating arguments. 
+ auto debug_opts = DefaultDebugOptionsIgnoringFlags(); + global_data_arguments = + MakeFakeArgumentsOrDie(computation, client, &debug_opts); for (const auto& data : global_data_arguments) { argument_ptrs.push_back( client->GlobalDataToShapedBuffer(data->handle(), /*device_ordinal=*/0) @@ -137,55 +229,37 @@ StatusOr ReplayComputation(const HloSnapshot& module, } } - bool provide_infeed = false; - Shape infeed_shape; - if (!opts.fake_infeed_shape.empty()) { - StatusOr shape_status = - ShapeUtil::ParseShapeString(opts.fake_infeed_shape); - TF_CHECK_OK(shape_status.status()); - infeed_shape = std::move(shape_status).ValueOrDie(); - provide_infeed = true; - } else if (opts.generate_fake_infeed) { - for (const auto& comp : computation.proto().computations()) { - for (const auto& instruction : comp.instructions()) { - if (instruction.opcode() == HloOpcodeString(HloOpcode::kInfeed)) { - CHECK(!provide_infeed) - << "--generate_fake_infeed only works if the model has 0 or 1 " - "infeed ops, but this one has >= 2."; - provide_infeed = true; - infeed_shape = Shape(instruction.shape()); - LOG(INFO) << "Generating fake infeed shape for inferred shape: " - << ShapeUtil::HumanString(infeed_shape); - } - } - } - } - // We only instantiate the thread pool if the user has requested that a - // concurrent infeed occur via the fake_infeed_shape, or when - // --generate_fake_infeed is passed and there exists an infeed operation in - // the HloSnapshot. 
- absl::optional pool; - Literal data; - if (provide_infeed) { - data = std::move(MakeFakeLiteral(infeed_shape)).ValueOrDie(); + if (absl::optional infeed_shape = GetXfeedShape( + /*is_infeed=*/true, computation.proto(), opts)) { + auto infeed_data = std::make_shared( + std::move(MakeFakeLiteral(*infeed_shape)).ValueOrDie()); + xla::gpu::GetOrCreateInfeedManager() + ->RegisterBeforeGetNextDestinationCallback([infeed_data, client] { + TF_CHECK_OK(client->TransferToInfeed(*infeed_data)); + }); } - auto transfer_infeed = [&data, client]() { - TF_CHECK_OK(client->TransferToInfeed(data)); - }; - if (provide_infeed) { - pool.emplace(tensorflow::Env::Default(), "infeed", - /*num_threads=*/1); - pool->Schedule([transfer_infeed]() { - // There may be several infeed buffers needed, however we don't know how - // many. If we proactively transfer too many infeed buffers, we may run - // out of memory. If we transfer too few infeed buffers, the program will - // hang. Therefore, we register a callback that is called when the infeed - // becomes empty, and in this callback we will transfer another fake - // infeed. - auto infeed_manager = xla::gpu::GetOrCreateInfeedManager(); - infeed_manager->RegisterOnEmptyCallback(transfer_infeed); - transfer_infeed(); - }); + + absl::optional outfeed_thread_pool; + if (absl::optional outfeed_shape = GetXfeedShape( + /*is_infeed=*/false, computation.proto(), opts)) { + // For each an outfeed that runs, enqueue a task that will consume it. We + // need a thread pool because the act of running an outfeed blocks on there + // being a destination available, and the act of making a destination + // available blocks on there being outfeed data available. 
+ outfeed_thread_pool.emplace(tensorflow::Env::Default(), "infeed", + /*num_threads=*/1); + auto consume_outfeed = [client, outfeed_shape] { + TF_CHECK_OK( + client->TransferFromOutfeedLocal(*outfeed_shape, /*device_ordinal=*/0) + .status()); + VLOG(1) << "Received outfeed data of shape " + << ShapeUtil::HumanStringWithLayout(*outfeed_shape); + }; + xla::gpu::GetOrCreateOutfeedManager() + ->RegisterBeforeGetNextDestinationCallback( + [consume_outfeed, &outfeed_thread_pool] { + outfeed_thread_pool->Schedule(consume_outfeed); + }); } // Do not attempt to run the executable if num_runs is less than 1. @@ -282,7 +356,7 @@ int RealMain(absl::Span args, const Options& opts) { // Compile all the modules in parallel. LOG(INFO) << "Compiling " << snapshots.size() << " modules in parallel."; - std::vector> executables; + std::vector>> executables; { // ThreadPool CHECK-fails if we give it 0 threads. tensorflow::thread::ThreadPool thread_pool( @@ -299,9 +373,16 @@ int RealMain(absl::Span args, const Options& opts) { LOG(INFO) << "Done compiling; now running the modules."; for (int64 i = 0; i < executables.size(); ++i) { - LocalExecutable* executable = executables[i].get(); + if (!executables[i].ok()) { + LOG(ERROR) << "Compilation failed: " << executables[i].status(); + exit_status = EXIT_FAILURE; + continue; + } + LocalExecutable* executable = executables[i].ValueOrDie().get(); + LOG(ERROR) << "Running iteration " << i; StatusOr result_status = ReplayComputation(snapshots[i], executable, client, opts); + LOG(ERROR) << "iteration complete."; if (!result_status.ok()) { fprintf(stderr, "%s: error: %s\n", args[i], result_status.status().ToString().c_str()); @@ -346,9 +427,14 @@ int main(int argc, char** argv) { "Number of times to run each computation"), tensorflow::Flag("fake_infeed_shape", &opts.fake_infeed_shape, "Shape of fake data to construct for (infinite) infeed"), + tensorflow::Flag("fake_outfeed_shape", &opts.fake_outfeed_shape, + "Shape of fake data to outfeed from 
computation"), tensorflow::Flag("generate_fake_infeed", &opts.generate_fake_infeed, - "Whether a fake infeed shape should be generated " - "derived from the computation"), + "Whether a fake infeed shape should be derived " + "from the computation"), + tensorflow::Flag("generate_fake_outfeed", &opts.generate_fake_outfeed, + "Whether a fake outfeed shape should be derived " + "from the computation"), }; xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/types.h b/tensorflow/compiler/xla/types.h index b645acb700b0f168112a40c9c72b4669435f717d..daf678f69017b9eb86cbc464a1f33b434021901d 100644 --- a/tensorflow/compiler/xla/types.h +++ b/tensorflow/compiler/xla/types.h @@ -41,6 +41,7 @@ using ::tensorflow::uint32; using ::tensorflow::uint64; using complex64 = std::complex; +using complex128 = std::complex; using ::Eigen::half; diff --git a/tensorflow/compiler/xla/util.cc b/tensorflow/compiler/xla/util.cc index 68cab7387cf1576072f96878b50f07def6862d8b..34b73b5206fa20d6dff7567afd78fd89897c8c33 100644 --- a/tensorflow/compiler/xla/util.cc +++ b/tensorflow/compiler/xla/util.cc @@ -86,7 +86,7 @@ bool IsPermutation(absl::Span permutation, int64 rank) { CHECK_LT(index, rank); output[index] = 0; } - return std::find(output.begin(), output.end(), -1) == output.end(); + return !absl::c_linear_search(output, -1); } std::vector InversePermutation( diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h index 6722641e9d2c177440361e6f0d1f6c0804eb7cda..f2fd17dc99455a921bf875aad2a3661b4d456823 100644 --- a/tensorflow/compiler/xla/util.h +++ b/tensorflow/compiler/xla/util.h @@ -324,8 +324,7 @@ bool IsIdentityPermutation(absl::Span permutation); template int64 PositionInContainer(const Container& container, int64 value) { - return std::distance(container.begin(), - std::find(container.begin(), container.end(), value)); + return 
std::distance(container.begin(), absl::c_find(container, value)); } // Formats the container as a comma-separated string. StrAppend must support diff --git a/tensorflow/compiler/xla/window_util.cc b/tensorflow/compiler/xla/window_util.cc index 51c73b3d17e4c32d9a8a14d3055ab56f02922af3..e001cc35f9fcea2783b3952e825838af6bbece72 100644 --- a/tensorflow/compiler/xla/window_util.cc +++ b/tensorflow/compiler/xla/window_util.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "absl/algorithm/container.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -137,25 +138,23 @@ bool HasPadding(const Window& window) { } bool HasSymmetricPadding(const Window& window) { - return std::all_of(window.dimensions().begin(), window.dimensions().end(), - [](const WindowDimension& dim) { - return dim.padding_low() == dim.padding_high(); - }); + return absl::c_all_of(window.dimensions(), [](const WindowDimension& dim) { + return dim.padding_low() == dim.padding_high(); + }); } bool HasSymmetricPadding(const PaddingConfig& padding_config) { - return std::all_of(padding_config.dimensions().begin(), - padding_config.dimensions().end(), - [](const PaddingConfig::PaddingConfigDimension& dim) { - return dim.edge_padding_low() == dim.edge_padding_high(); - }); + return absl::c_all_of(padding_config.dimensions(), + [](const PaddingConfig::PaddingConfigDimension& dim) { + return dim.edge_padding_low() == + dim.edge_padding_high(); + }); } bool HasNegativePadding(const Window& window) { - return std::any_of(window.dimensions().begin(), window.dimensions().end(), - [](const WindowDimension& dim) { - return dim.padding_low() < 0 || dim.padding_high() < 0; - }); + return absl::c_any_of(window.dimensions(), [](const WindowDimension& dim) { + return dim.padding_low() < 0 || dim.padding_high() < 0; + }); } bool HasBaseDilation(const Window& window) { @@ -190,10 +189,9 @@ bool AllOrNoneReversed(const Window& 
window) { return true; } bool reversed = window.dimensions()[0].window_reversal(); - return std::all_of(window.dimensions().begin(), window.dimensions().end(), - [&](const WindowDimension& dim) { - return dim.window_reversal() == reversed; - }); + return absl::c_all_of(window.dimensions(), [&](const WindowDimension& dim) { + return dim.window_reversal() == reversed; + }); } bool HasDilation(const Window& window) { diff --git a/tensorflow/compiler/xla/xla.bzl b/tensorflow/compiler/xla/xla.bzl index 1439f1bcc5cec39203a7cb4b1f8604e7349382c6..60adea5a4a242e5843b41927ba77c197e8fac444 100644 --- a/tensorflow/compiler/xla/xla.bzl +++ b/tensorflow/compiler/xla/xla.bzl @@ -1,30 +1,40 @@ """Wrapper around cc_proto_library used inside the XLA codebase.""" -load("//tensorflow/core:platform/default/build_config.bzl", - "cc_proto_library") -load("//tensorflow/core:platform/default/build_config_root.bzl", - "if_static") +load( + "//tensorflow/core:platform/default/build_config.bzl", + "cc_proto_library", +) +load( + "//tensorflow/core:platform/default/build_config_root.bzl", + "if_static", +) +load("//tensorflow:tensorflow.bzl", "if_cuda_is_configured") # xla_proto_library() is a convenience wrapper around cc_proto_library. 
-def xla_proto_library(name, srcs=[], deps=[], visibility=None, testonly=0, **kwargs): - if kwargs.get('use_grpc_plugin'): - kwargs['use_grpc_namespace'] = True - cc_proto_library(name=name, - srcs=srcs, - deps=deps, - cc_libs = if_static( - ["@protobuf_archive//:protobuf"], - otherwise=["@protobuf_archive//:protobuf_headers"], - ), - protoc="@protobuf_archive//:protoc", - testonly=testonly, - visibility=visibility, - **kwargs) +def xla_proto_library(name, srcs = [], deps = [], visibility = None, testonly = 0, **kwargs): + if kwargs.get("use_grpc_plugin"): + kwargs["use_grpc_namespace"] = True + cc_proto_library( + name = name, + srcs = srcs, + deps = deps, + cc_libs = if_static( + ["@protobuf_archive//:protobuf"], + otherwise = ["@protobuf_archive//:protobuf_headers"], + ), + protoc = "@protobuf_archive//:protoc", + testonly = testonly, + visibility = visibility, + **kwargs + ) def xla_py_grpc_library(**kwargs): - # Note: we don't currently define any special targets for Python GRPC in OSS. - _ignore = kwargs - pass - + # Note: we don't currently define any special targets for Python GRPC in OSS. + _ignore = kwargs + pass ORC_JIT_MEMORY_MAPPER_TARGETS = [] + +# We link the GPU plugin into the XLA Python extension if CUDA is enabled. +def xla_python_default_plugins(): + return if_cuda_is_configured(["//tensorflow/compiler/xla/service:gpu_plugin"]) diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index f745fb850655edaba8c95ba0cd3af3cc765b99e6..92834dbb02cdcd6383ceec3ffd079834b163ee6a 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -100,6 +100,14 @@ message DebugOptions { // names as specified by the HloPassInterface::name() method. repeated string xla_disable_hlo_passes = 30; + // Disables all HLO passes. 
Note that some passes are necessary for + // correctness and the invariants that must be satisfied by "fully optimized" + // HLO are different for different devices and may change over time. The only + // "guarantee", such as it is, is that if you compile XLA and dump the + // optimized HLO for some graph, you should be able to run it again on the + // same device with the same build of XLA. + bool xla_disable_all_hlo_passes = 104; + // Numerical optimization level for the XLA compiler backend; the specific // interpretation of this value is left to the backends. int32 xla_backend_optimization_level = 31; @@ -193,7 +201,11 @@ message DebugOptions { // - Assuming that operations never produce or consume NaN or +/- Inf. // - Assuming that +0 and -0 are indistinguishable. bool xla_cpu_enable_fast_math = 99; - bool xla_gpu_enable_fast_math = 100; + + // When true we lower the Minimum and Maximum hlos in the GPU backend such + // that Min(NotNaN, NaN) = Min(NaN, NotNaN) = NotNaN. In other words, if this + // flag is true we don't propagate NaNs through Min and Max. + bool xla_gpu_enable_fast_min_max = 100; // Crashes the program when any kind of verification fails, instead of just // logging the failures. One example is cross checking of convolution results @@ -209,6 +221,21 @@ message DebugOptions { // the host that run models in parallel across multiple devices. int32 xla_force_host_platform_device_count = 102; + // If set to true XLA:GPU invokes `ptxas` with -O0 (default is -O3). + bool xla_gpu_disable_ptxas_optimizations = 103; + + // Dump HLO graphs as an HTML (DOT -> SVG inlined in HTML) + bool xla_hlo_dump_as_html = 105; + + // Enable fast math with eigen in the HLO evaluator. + bool xla_hlo_evaluator_use_fast_path = 106; + + // Temporary option to allow support for both the R1 and the scalar index + // versions of DynamicSlice and DynamicUpdateSlice. Only used for testing. 
+ bool xla_allow_scalar_index_dynamic_ops = 107; + + // Next id: 108 + // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend. map xla_backend_extra_options = 500; @@ -238,6 +265,10 @@ message ExecutionOptions { // computation on. The computation will be partitioned across these devices. // If not provided, the default device will be chosen. repeated DeviceHandle device_handles = 5; + + // Number of replicas of the computation to run. If zero, uses the default + // number of replicas for the XLA service. + int32 num_replicas = 6; } message GetDeviceHandlesRequest { @@ -382,7 +413,7 @@ message WaitForExecutionResponse { message ComputeConstantGraphRequest { HloModuleProto computation = 1; - Layout output_layout = 2; + LayoutProto output_layout = 2; } message ComputeConstantResponse { diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index 85ec83437a10d973687a7fb84285c2e2541a53c7..a64e2f5df5cacca05e83f31c941c57abd5ccf4de 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -56,6 +56,7 @@ enum PrimitiveType { // Complex values of fixed width. C64 = 15; // Paired F32 (real, imag), as in std::complex. + C128 = 18; // Paired F64 (real, imag), as in std::complex. // A tuple is a polymorphic sequence; e.g. a shape that holds different // sub-shapes. They are used for things like returning multiple values from a @@ -75,7 +76,7 @@ enum PrimitiveType { // primitive type will have empty dimensions and tuple_shapes fields. TOKEN = 17; - // Next = 18 + // Next = 19 } // Describes the padding configuration for Pad operation. The padding amount on @@ -100,6 +101,8 @@ message PaddingConfig { // A format specifies the method used by a layout to store an array in memory. enum Format { + // TODO(b/120869032): Rename this to FORMAT_NONE or something else which + // better corresponds to its meaning. 
INVALID_FORMAT = 0; // The default layout, with exactly one storage location per element. DENSE = 1; @@ -109,8 +112,9 @@ enum Format { } // Describes a tile used in tiling-based layout. Refer to -// g3doc/layout_with_tiling.md for details about tiling-based layout. -message Tile { +// g3doc/third_party/tensorflow/compiler/xla/g3doc/layout_with_tiling.md for +// details about tiling-based layout. +message TileProto { // Number of elements in each dimension of the tile. It's ordered from the // most major dimension of the tile to the most minor dimension of the tile. // The dimensions correspond to a suffix of the dimensions of the shape being @@ -128,7 +132,7 @@ message Tile { // See the XLA documentation for more information on shapes and layouts. // // LINT.IfChange -message Layout { +message LayoutProto { // The method used to store the data in memory. The format determines which of // the other fields are used by the layout. Format format = 4; @@ -153,7 +157,7 @@ message Layout { // // TODO(b/119839262): implement tiling in each backend or add Unimplemented // error. - repeated Tile tiles = 6; + repeated TileProto tiles = 6; // Bit size of each element. If the size is bigger than what the element // type requires, the value is stored in the least significant @@ -185,18 +189,27 @@ message ShapeProto { // The element type for this shape. PrimitiveType element_type = 2; - // The size (number of elements) for each dimension. - // In XLA, dimensions are numbered from 0 to N-1 for an - // N-dimensional array. The first element of 'dimensions' is the size of - // dimension 0, the second element is the size of dimension 1, and so forth. - // Empty list indicates a scalar. + // The size (number of elements) for each dimension, or an upper bound on the + // size if the dimension is dynamic. In XLA, dimensions are numbered from 0 + // to N-1 for an N-dimensional array. 
The first element of 'dimensions' is the + // size of dimension 0, the second element is the size of dimension 1, and so + // forth. Empty list indicates a scalar. + // + // If the respective element in 'is_dynamic_dimension' is true then the value + // in this field represents an upper bound on the size of the dimension. repeated int64 dimensions = 3; // For tuples only, the shapes of constitutent shapes in the tuple sequence. repeated ShapeProto tuple_shapes = 4; // The layout used to back this shape. - Layout layout = 5; + LayoutProto layout = 5; + + // For arrays, this indicates whether or not each dimension is + // dynamically-sized. The number of elements in this repeated field should be + // zero (indicating that no dimensions are dynamic) or equal to the number of + // elements in the 'dimensions' field. + repeated bool is_dynamic_dimension = 6; // Important: if any field is added, be sure to modify ShapeUtil::Equal(), // ShapeUtil::Compatible() and ShapeUtil::Hash() appropriately to account for @@ -355,6 +368,7 @@ message LiteralProto { repeated float f32s = 8; repeated double f64s = 9; repeated float c64s = 12; // Stored as interleaved real, imag floats. + repeated double c128s = 18; // Stored as interleaved real, imag doubles. 
repeated LiteralProto tuple_literals = 10; // The F16s, BF16s, U16s and S16s are encoded in little endian byte order bytes f16s = 11; @@ -362,7 +376,7 @@ message LiteralProto { bytes u16s = 16; bytes s16s = 17; repeated int64 sparse_indices = 14; - // Next = 18 + // Next = 19 } message WindowDimension { diff --git a/tensorflow/compiler/xrt/kernels/BUILD b/tensorflow/compiler/xrt/kernels/BUILD index 67f475846e5f16060c1080759b0acb4216c4e72b..dc02fd272fd8700c7f8fa64adf7ab57c88bab706 100644 --- a/tensorflow/compiler/xrt/kernels/BUILD +++ b/tensorflow/compiler/xrt/kernels/BUILD @@ -11,20 +11,15 @@ cc_library( name = "xrt_state_ops", hdrs = ["xrt_state_ops.h"], deps = [ + "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:compile_only_client", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:computation_placer", - "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xrt:xrt_proto", "//tensorflow/compiler/xrt:xrt_utils", "//tensorflow/core:core_cpu_internal", @@ -55,6 +50,7 @@ cc_library( "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:computation_placer", + "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xrt:xrt_proto", "//tensorflow/compiler/xrt:xrt_utils", "//tensorflow/core:core_cpu_internal", @@ -62,7 +58,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", - 
"//tensorflow/stream_executor:stream_executor_headers_lib", + "//tensorflow/stream_executor:stream_executor_headers", "@com_google_absl//absl/strings", ], alwayslink = 1, diff --git a/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc b/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc index 2ccdf0f02d840600d5e0649c4805e3672d4a1286..2ee1a6cd1aebcdbd65892b33e5044489070ab5c4 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc +++ b/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc @@ -215,11 +215,6 @@ XRTReleaseCompilationRefOp::~XRTReleaseCompilationRefOp() = default; void XRTReleaseCompilationRefOp::Compute(OpKernelContext* ctx) { VLOG(1) << "XRTReleaseCompilationRefOp::Compute"; - const Tensor& key_tensor = ctx->input(0); - OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(key_tensor.shape()), - errors::Internal("computation key should be a string scalar")); - int64 uid = key_tensor.scalar()(); - ResourceMgr* rm; OP_REQUIRES_OK(ctx, XRTGenericDeviceAccessor::GetResourceManager(ctx, &rm)); @@ -230,9 +225,13 @@ void XRTReleaseCompilationRefOp::Compute(OpKernelContext* ctx) { kXRTCompilationCacheResourceName, &cache)); core::ScopedUnref cache_unref(cache); - OP_REQUIRES_OK(ctx, cache->Release(uid)); - - VLOG(2) << "Released computation handle " << uid; + const Tensor& keys_tensor = ctx->input(0); + auto flat_keys = keys_tensor.flat(); + for (int64 i = 0; i < flat_keys.size(); ++i) { + int64 key = flat_keys(i); + OP_REQUIRES_OK(ctx, cache->Release(key)); + VLOG(2) << "Released computation handle " << key; + } } } // namespace diff --git a/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc b/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc index 8c6191ddc06ea7d85f5fd21a7d4058c669ffdeb2..116c193cab65410a5a7c3058f98cc2be2cbe9e67 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc +++ b/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/service/computation_placer.h" +#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -229,13 +230,53 @@ Status XRTExecuteOp::DoWork(OpKernelContext* context) { shaped_buffer, device_ref.backend(), device_ref.device_ordinal(), &output_tuple)); - Tensor* output_tensor; - TF_RETURN_IF_ERROR( - context->allocate_output(0, TensorShape({}), &output_tensor)); - int64 key; - TF_RETURN_IF_ERROR(output_tuple->Intern(rm, &key)); - output_tensor->scalar()() = key; - + // The ScopedShapedBuffer returned by the executable Run() API, in case of + // input/output buffer aliasing, might have holes in it, which need to be + // filled using the proper input tuples buffers which are the source of + // aliasing. + const xla::HloInputOutputAliasConfig& input_output_alias = + executable->executable()->module().input_output_alias_config(); + auto alias_function = + [&](const xla::ShapeIndex& output_index, + const xla::HloInputOutputAliasConfig::Alias& alias) -> Status { + TF_RET_CHECK(alias.parameter_number < input_tuples.size()); + return alias.kind == xla::HloInputOutputAliasConfig::AliasKind::kUserAlias + ? 
output_tuple->AliasBufferFrom( + *input_tuples[alias.parameter_number], + alias.parameter_index, output_index) + : Status::OK(); + }; + TF_RETURN_IF_ERROR(input_output_alias.ForEachAliasWithStatus(alias_function)); + + if (config_proto.return_exploded_tuple() && + output_tuple->on_device_shape().IsTuple()) { + int64 tuple_element_count = + xla::ShapeUtil::TupleElementCount(output_tuple->on_device_shape()); + Tensor* output_tensor; + TF_RETURN_IF_ERROR(context->allocate_output( + 0, TensorShape({tuple_element_count}), &output_tensor)); + + for (int64 i = 0; i < tuple_element_count; ++i) { + xla::ShapeIndex shape_index; + shape_index.push_back(i); + + XRTTupleAllocation* suballocation; + TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeSubBuffer( + output_tuple, shape_index, &suballocation, + /*alias_parent_allocation=*/false)); + int64 key; + TF_RETURN_IF_ERROR(suballocation->Intern(rm, &key)); + output_tensor->vec()(i) = key; + } + output_tuple->Unref(); + } else { + Tensor* output_tensor; + TF_RETURN_IF_ERROR( + context->allocate_output(0, TensorShape({}), &output_tensor)); + int64 key; + TF_RETURN_IF_ERROR(output_tuple->Intern(rm, &key)); + output_tensor->scalar()() = key; + } return Status::OK(); } diff --git a/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc b/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc index ffea592491d43788b876a51866dc8a6611e8c734..6a7f10652533920ba3fa48fba1d5161f7c4d4530 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc +++ b/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc @@ -37,6 +37,17 @@ REGISTER_KERNEL_BUILDER(Name("XRTAllocate") .HostMemory("handle"), XRTAllocateOp); +REGISTER_KERNEL_BUILDER(Name("XRTAllocateFromTensor") + .Device(DEVICE_XLA_GPU) + .HostMemory("inputs") + .HostMemory("handle"), + XRTAllocateFromTensorOp); +REGISTER_KERNEL_BUILDER(Name("XRTAllocateFromTensor") + .Device(DEVICE_XLA_CPU) + .HostMemory("inputs") + .HostMemory("handle"), + XRTAllocateFromTensorOp); + REGISTER_KERNEL_BUILDER(Name("XRTSubTuple") 
.Device(DEVICE_XLA_GPU) .HostMemory("base_handle") @@ -87,6 +98,19 @@ REGISTER_KERNEL_BUILDER(Name("XRTReadLiteral") .HostMemory("literal"), XRTReadLiteralOp); +REGISTER_KERNEL_BUILDER(Name("XRTWriteLiteral") + .Device(DEVICE_XLA_GPU) + .HostMemory("handle") + .HostMemory("literal") + .HostMemory("output_handle"), + XRTWriteLiteralOp); +REGISTER_KERNEL_BUILDER(Name("XRTWriteLiteral") + .Device(DEVICE_XLA_CPU) + .HostMemory("handle") + .HostMemory("literal") + .HostMemory("output_handle"), + XRTWriteLiteralOp); + REGISTER_KERNEL_BUILDER(Name("XRTReadLiteralAndRelease") .Device(DEVICE_XLA_GPU) .HostMemory("handle") @@ -107,4 +131,9 @@ REGISTER_KERNEL_BUILDER(Name("XRTReleaseAllocationHandle") .HostMemory("handle"), XRTReleaseAllocationOp); +REGISTER_KERNEL_BUILDER(Name("XRTReleaseAllAllocations").Device(DEVICE_XLA_GPU), + XRTReleaseAllAllocationsOp); +REGISTER_KERNEL_BUILDER(Name("XRTReleaseAllAllocations").Device(DEVICE_XLA_CPU), + XRTReleaseAllAllocationsOp); + } // namespace tensorflow diff --git a/tensorflow/compiler/xrt/kernels/xrt_state_ops.h b/tensorflow/compiler/xrt/kernels/xrt_state_ops.h index 54b06558adcd8ef1f8f1bee52d210d558801afea..e2c223b3dbb2311d0f42e1a36e316fd9d5f66040 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_state_ops.h +++ b/tensorflow/compiler/xrt/kernels/xrt_state_ops.h @@ -19,10 +19,14 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XRT_KERNELS_XRT_STATE_OPS_H_ #define TENSORFLOW_COMPILER_XRT_KERNELS_XRT_STATE_OPS_H_ +#include #include #include +#include "tensorflow/compiler/tf2xla/literal_util.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -30,6 +34,7 @@ limitations under the License. 
#include "tensorflow/compiler/xrt/xrt.pb.h" #include "tensorflow/compiler/xrt/xrt_device.h" #include "tensorflow/compiler/xrt/xrt_state.h" +#include "tensorflow/core/common_runtime/dma_helper.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/tensor.h" @@ -183,9 +188,7 @@ class XRTAllocateOp : public OpKernel { // We are guaranteed that the underlying device object won't be deleted out // from under us, while the ScopedRef is live. class DeviceAccessor::ScopedRef device_ref; - OP_REQUIRES_OK(ctx, - DeviceAccessor::InitScopedRef( - ctx, allocation_proto.device_ordinal(), &device_ref)); + OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(ctx, &device_ref)); XRTTupleAllocation* allocation; OP_REQUIRES_OK(ctx, XRTTupleAllocation::CreateAndTransfer( @@ -202,6 +205,109 @@ class XRTAllocateOp : public OpKernel { } }; +// Op that allocates memory for a tensor (with optional layout) and transfers it +// to the device, returning an allocation handle. 
+template +class XRTAllocateFromTensorOp : public OpKernel { + public: + explicit XRTAllocateFromTensorOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + bool make_tuple = false; + OP_REQUIRES_OK(ctx, ctx->GetAttr("shapes", &tf_shapes_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("dtypes", &dtypes_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("make_tuple", &make_tuple)); + if (ctx->HasAttr("layouts")) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("layouts", &minor_to_major_)); + } + OP_REQUIRES( + ctx, tf_shapes_.size() == dtypes_.size(), + errors::InvalidArgument("shapes and dtypes must be the same length")); + std::vector xla_shapes; + for (int i = 0; i < tf_shapes_.size(); i++) { + xla::Shape xla_shape; + OP_REQUIRES_OK( + ctx, TensorShapeToXLAShape(dtypes_[i], tf_shapes_[i], &xla_shape)); + xla_shapes.push_back(xla_shape); + } + if (xla_shapes.size() > 1 || make_tuple) { + shape_ = xla::ShapeUtil::MakeTupleShape(xla_shapes); + } else { + shape_.Swap(&xla_shapes.front()); + } + if (!minor_to_major_.empty()) { + xla::Shape shape_with_layouts; + OP_REQUIRES_OK(ctx, GetShapeWithLayout(shape_, minor_to_major_, + /*layout_func=*/nullptr, + &shape_with_layouts)); + shape_.Swap(&shape_with_layouts); + } + } + + ~XRTAllocateFromTensorOp() override = default; + XRTAllocateFromTensorOp(const XRTAllocateFromTensorOp&) = delete; + XRTAllocateFromTensorOp& operator=(const XRTAllocateFromTensorOp&) = delete; + + void Compute(OpKernelContext* ctx) override { + VLOG(1) << "XRTAllocateFromTensorOp::Compute"; + + OpInputList values; + OP_REQUIRES_OK(ctx, ctx->input_list("inputs", &values)); + OP_REQUIRES(ctx, values.size() == tf_shapes_.size(), + errors::InvalidArgument( + "Wrong number of inputs to XRTAllocateFromTensor: ", + values.size(), " vs. 
", tf_shapes_.size())); + + std::vector tensors_data; + for (size_t i = 0; i < values.size(); ++i) { + const Tensor& input_tensor = values[i]; + OP_REQUIRES(ctx, input_tensor.dtype() == dtypes_[i], + errors::InvalidArgument( + "Input tensor type and input dtype do not match")); + // We allow the requested on-device shape to differ from the shape of the + // input tensor, as long as they have the same number of elements. + OP_REQUIRES( + ctx, + input_tensor.shape().num_elements() == tf_shapes_[i].num_elements(), + errors::InvalidArgument( + "Input tensor must have the number of elements specified " + "in the matching input shape: ", + input_tensor.shape().num_elements(), " vs. ", + tf_shapes_[i].num_elements(), " at index ", i)); + tensors_data.push_back( + static_cast(DMAHelper::base(&input_tensor))); + } + // Use the buffer straight out of the input tensors to create the literal. + xla::BorrowingLiteral literal = + shape_.IsTuple() ? xla::BorrowingLiteral(tensors_data, shape_) + : xla::BorrowingLiteral(tensors_data.front(), shape_); + ResourceMgr* rm; + OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm)); + + // We are guaranteed that the underlying device object won't be deleted out + // from under us, while the ScopedRef is live. + class DeviceAccessor::ScopedRef device_ref; + OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(ctx, &device_ref)); + + XRTTupleAllocation* allocation; + OP_REQUIRES_OK(ctx, XRTTupleAllocation::CreateAndTransfer( + literal, device_ref.backend(), + device_ref.device_ordinal(), &allocation)); + + // Intern takes ownership of our reference to allocation. 
+ int64 key; + OP_REQUIRES_OK(ctx, allocation->Intern(rm, &key)); + + Tensor output(DT_INT64, TensorShape({})); + output.scalar()() = key; + ctx->set_output(0, output); + } + + private: + std::vector tf_shapes_; + DataTypeVector dtypes_; + std::vector minor_to_major_; + xla::Shape shape_; +}; + // Op that takes a tuple handle input and returns a handle to a sub-tuple of the // input. template @@ -393,6 +499,56 @@ class XRTReadLiteralOp : public OpKernel { } }; +// Op that writes a new literal value into device-resident memory. +template +class XRTWriteLiteralOp : public OpKernel { + public: + explicit XRTWriteLiteralOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + ~XRTWriteLiteralOp() override = default; + XRTWriteLiteralOp(const XRTWriteLiteralOp&) = delete; + XRTWriteLiteralOp& operator=(const XRTWriteLiteralOp&) = delete; + + void Compute(OpKernelContext* ctx) override { + VLOG(1) << "XRTWriteLiteralOp::Compute"; + + const Tensor& handle_tensor = ctx->input(0); + OP_REQUIRES( + ctx, TensorShapeUtils::IsScalar(handle_tensor.shape()), + errors::Internal("computation input should be an int64 scalar")); + int64 allocation_handle = handle_tensor.scalar()(); + + const Tensor& literal_info = ctx->input(1); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(literal_info.shape()), + errors::Internal("literal input should be a string scalar")); + xla::LiteralProto literal_proto; + OP_REQUIRES(ctx, + literal_proto.ParseFromString(literal_info.scalar()()), + errors::InvalidArgument( + "Unable to parse allocation input to LiteralProto")); + xla::Literal literal; + OP_REQUIRES_OK(ctx, XRTStateHelpers::MakeLiteral(literal_proto, &literal)); + + ResourceMgr* rm; + OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm)); + + XRTTupleAllocation* allocation; + OP_REQUIRES_OK( + ctx, XRTTupleAllocation::Lookup(rm, allocation_handle, &allocation)); + core::ScopedUnref allocation_unref(allocation); + // We are guaranteed that the underlying device object won't be deleted out + 
// from under us, while the ScopedRef is live. + typename DeviceAccessor::ScopedRef device_ref; + OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef( + ctx, allocation->device_ordinal(), &device_ref)); + OP_REQUIRES_OK(ctx, + allocation->WriteLiteral(device_ref.backend(), literal)); + + Tensor output(DT_INT64, TensorShape({})); + output.scalar()() = allocation_handle; + ctx->set_output(0, output); + } +}; + // Op that discards a handle to device memory. template class XRTReleaseAllocationOp : public OpKernel { @@ -405,17 +561,37 @@ class XRTReleaseAllocationOp : public OpKernel { void Compute(OpKernelContext* ctx) override { VLOG(1) << "XRTReleaseAllocationOp::Compute"; - const Tensor& allocation_handle = ctx->input(0); - OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(allocation_handle.shape()), - errors::Internal("handle input should be an int64 scalar")); - int64 key = allocation_handle.scalar()(); - ResourceMgr* rm; OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm)); - OP_REQUIRES_OK(ctx, XRTTupleAllocation::DeleteFromResourceManager(rm, key)); + const Tensor& allocation_handle = ctx->input(0); + auto flat_keys = allocation_handle.flat(); + for (int64 i = 0; i < flat_keys.size(); ++i) { + int64 key = flat_keys(i); + OP_REQUIRES_OK(ctx, + XRTTupleAllocation::DeleteFromResourceManager(rm, key)); + VLOG(2) << "Released allocation handle " << key; + } + } +}; + +// Op that discards a handle to device memory. 
+template +class XRTReleaseAllAllocationsOp : public OpKernel { + public: + explicit XRTReleaseAllAllocationsOp(OpKernelConstruction* ctx) + : OpKernel(ctx) {} + ~XRTReleaseAllAllocationsOp() override = default; + XRTReleaseAllAllocationsOp(const XRTReleaseAllAllocationsOp&) = delete; + XRTReleaseAllAllocationsOp& operator=(const XRTReleaseAllAllocationsOp&) = + delete; + + void Compute(OpKernelContext* ctx) override { + VLOG(1) << "XRTReleaseAllAllocationsOp::Compute"; - VLOG(2) << "Released allocation handle " << key; + ResourceMgr* rm; + OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm)); + OP_REQUIRES_OK(ctx, XRTTupleAllocation::ReleaseAllAllocations(rm)); } }; diff --git a/tensorflow/compiler/xrt/ops/xrt_compile_ops.cc b/tensorflow/compiler/xrt/ops/xrt_compile_ops.cc index 7b3b50c69559f6003a108fdf6a1325dbdbaa80a6..9dd964e5467cd855d67764a512e95a6a18f482e1 100644 --- a/tensorflow/compiler/xrt/ops/xrt_compile_ops.cc +++ b/tensorflow/compiler/xrt/ops/xrt_compile_ops.cc @@ -44,10 +44,10 @@ REGISTER_OP("XRTReleaseCompilationHandle") .SetShapeFn(tensorflow::shape_inference::NoOutputs) .Doc( R"( -Discards a computation from the compilation cache. The handle cannot be -subsequently used. +Discards one or more computation handles from the compilation cache. +The handle(s) cannot be subsequently used. -'handle' is an id returned from a XRTCompile Op. +'handle' is an ID (or vector of IDs) returned from a XRTCompile Op. )"); } // namespace tensorflow diff --git a/tensorflow/compiler/xrt/ops/xrt_state_ops.cc b/tensorflow/compiler/xrt/ops/xrt_state_ops.cc index 07d025ce343f229097b557d33ad41bf9612b0696..2e743fec4963a52ee1abf64525f26e3d89479670 100644 --- a/tensorflow/compiler/xrt/ops/xrt_state_ops.cc +++ b/tensorflow/compiler/xrt/ops/xrt_state_ops.cc @@ -26,12 +26,41 @@ REGISTER_OP("XRTAllocate") .SetShapeFn(tensorflow::shape_inference::ScalarShape) .Doc( R"( -Reads a literal proto and transfers it to TPU device memory. 
+Reads a literal proto and transfers it to device memory. -'allocation' is a serialized xrt::TPUAllocation proto. +'allocation' is a serialized xrt::XLAAllocation proto. 'handle' is an id that can be used in other ops to refer to the allocation. )"); +REGISTER_OP("XRTAllocateFromTensor") + .Input("inputs: dtypes") + .Output("handle: int64") + .Attr("dtypes: list(type)") + .Attr("shapes: list(shape)") + .Attr("layouts: list(int) = []") + .Attr("make_tuple: bool = false") + .SetShapeFn(tensorflow::shape_inference::ScalarShape) + .Doc( + R"( +Reads a list of tensors with optional layouts, and transfers it to device +memory. + +inputs: The tensors holding the input data. +shapes: The shapes which the tensors should have on device. The i-th shape +corresponds to the i-th input. The shapes, together with the (optional) +layouts, helps creating the fully qualified shape of the data on the device. +The shapes can differ from the corresponding input one, as long as the total +number of elements matches. In other words, it is possible to feed an input +tensor with shape {8} and have a corresponding shape {2,2,2}. +layouts: A vector holding the requested layout in minor-to-major sequence. +If empty, the default layout wil be used. +For a tuple, the layouts vector holds a linearized minor-to-major numbers +for all the tuple leaves, in the order they appear within the tuple. +The elements within the layouts sequence corresponding to a given tuple +subshape can be set to -1, to leave such subshape to the default shape. +handle: An id that can be used in other ops to refer to the allocation. +)"); + REGISTER_OP("XRTSubTuple") .Input("base_handle: int64") .Input("shape_index: int32") @@ -95,6 +124,20 @@ Copies an allocated tuple from device memory and returns it as a literal. 'literal' is a serialized xla::LiteralProto proto. 
)"); +REGISTER_OP("XRTWriteLiteral") + .Input("handle: int64") + .Input("literal: string") + .Output("output_handle: int64") + .SetShapeFn(tensorflow::shape_inference::ScalarShape) + .Doc( + R"( +Copies the input literal into the device memory pointed to by handle. +Returns the handle itself. + +'handle' is the id returned from the Op that produced the on-device allocation. +'literal' is a serialized xla::LiteralProto proto to be written to device memory. +)"); + REGISTER_OP("XRTReadLiteralAndRelease") .Input("handle: int64") .Output("literal: string") @@ -113,10 +156,18 @@ REGISTER_OP("XRTReleaseAllocationHandle") .SetShapeFn(tensorflow::shape_inference::NoOutputs) .Doc( R"( -Discards an allocation from device memory. The handle cannot be subsequently +Discards one or more device memory handles. The handle(s) cannot be subsequently used. -'handle' is the id returned from the Op that produced the on-device allocation. +'handle' is the ID (or a vector of IDs) returned from the Op that produced the +on-device allocation. +)"); + +REGISTER_OP("XRTReleaseAllAllocations") + .SetShapeFn(tensorflow::shape_inference::NoOutputs) + .Doc( + R"( +Discards all the XRT allocations. All the client held handles will be invalid. 
)"); } // namespace tensorflow diff --git a/tensorflow/compiler/xrt/tests/BUILD b/tensorflow/compiler/xrt/tests/BUILD index be44a3474acdeb9905c1d21b932fa0dd10b5a212..3a19327e5b5d8072fbecdbe10e9959c8491780eb 100644 --- a/tensorflow/compiler/xrt/tests/BUILD +++ b/tensorflow/compiler/xrt/tests/BUILD @@ -24,6 +24,7 @@ cc_library( "//tensorflow/cc:client_session", "//tensorflow/cc:ops", "//tensorflow/cc:scope", + "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", diff --git a/tensorflow/compiler/xrt/tests/raw_api_test.cc b/tensorflow/compiler/xrt/tests/raw_api_test.cc index b9262c1843a7ae48af49acbef5ba4ef58ec0f050..1111f8240512e81c10a42a28c09f5b0a94daf1ee 100644 --- a/tensorflow/compiler/xrt/tests/raw_api_test.cc +++ b/tensorflow/compiler/xrt/tests/raw_api_test.cc @@ -22,6 +22,8 @@ limitations under the License. #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/scope.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/tf2xla/literal_util.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" @@ -53,6 +55,14 @@ string DeviceFromFlag() { return absl::StrCat("/device:", xla_test_device, ":0"); } +std::vector GetAttrLayout(absl::Span minor_to_mayor) { + std::vector layout; + for (auto dim : minor_to_mayor) { + layout.push_back(static_cast(dim)); + } + return layout; +} + xla::LiteralProto TwoElementTuple() { auto array = xla::LiteralUtil::CreateR1({1.0f, 3.0f}); auto matrix = xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}}); @@ -96,14 +106,21 @@ xla::LiteralProto FloatMatrix( return array.ToProto(); } +xla::Literal ReadOutputLiteral(const std::vector& outputs, size_t idx) { + xla::LiteralProto response; + 
CHECK(response.ParseFromString(outputs[idx].scalar()())); + return xla::Literal::CreateFromProto(response).ValueOrDie(); +} + bool CompareLiteralProtos(const xla::LiteralProto& a, const xla::LiteralProto& b) { auto l_a = xla::Literal::CreateFromProto(a).ValueOrDie(); auto l_b = xla::Literal::CreateFromProto(b).ValueOrDie(); bool equal = l_a == l_b; if (!equal) { - LOG(INFO) << "LiteralProtos don't match " << a.DebugString() - << " != " << b.DebugString(); + LOG(INFO) << "LiteralProtos don't match:\n" + << a.DebugString() << "\n!=\n" + << b.DebugString(); } return equal; } @@ -113,8 +130,19 @@ bool CompareLiteralToLiteralProto(const xla::Literal& a, auto l_b = xla::Literal::CreateFromProto(b).ValueOrDie(); bool equal = a == l_b; if (!equal) { - LOG(INFO) << "Literal and LiteralProto don't match " - << a.ToProto().DebugString() << " != " << b.DebugString(); + LOG(INFO) << "Literal and LiteralProto don't match:\n" + << a.ToProto().DebugString() << "\n!=\n" + << b.DebugString(); + } + return equal; +} + +bool CompareLiterals(const xla::Literal& a, const xla::Literal& b) { + bool equal = a == b; + if (!equal) { + LOG(INFO) << "Literals don't match:\n" + << a.ToProto().DebugString() << "\n!=\n" + << b.ToProto().DebugString(); } return equal; } @@ -175,6 +203,18 @@ xla::XlaComputation AddAndTuple() { return builder.Build().ValueOrDie(); } +xla::XlaComputation AddAndSubTuple() { + xla::XlaBuilder builder("AddAndSubTuple"); + auto p0 = xla::Parameter(&builder, 0, xla::ShapeUtil::MakeShape(xla::F32, {}), + "P0"); + auto p1 = xla::Parameter(&builder, 1, xla::ShapeUtil::MakeShape(xla::F32, {}), + "P1"); + auto sum = xla::Add(p0, p1); + auto sub = xla::Sub(p0, p1); + xla::Tuple(&builder, {sum, sub}); + return builder.Build().ValueOrDie(); +} + void StoreComputationSnapshot(const xla::XlaComputation& computation, xla::HloSnapshot* dst) { auto snapshot = computation.Snapshot().ValueOrDie(); @@ -203,9 +243,295 @@ xla::ProgramShape XlaCompiledProgramShape( ->ComputeProgramShape(); 
} +TEST(RawApiTest, AllocFromTensor) { + xla::Literal literal = + xla::LiteralUtil::CreateR2({{4.0f, 5.0f}, {6.0f, 7.0f}}); + Tensor tensor; + TF_ASSERT_OK(LiteralToHostTensor(literal, DT_FLOAT, &tensor)); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + std::vector layout = + GetAttrLayout(literal.shape().layout().minor_to_major()); + ops::XRTAllocateFromTensor::Attrs alloc_attrs = + ops::XRTAllocateFromTensor::Layouts(layout); + auto handle = + ops::XRTAllocateFromTensor(root, {tensor}, {tensor.shape()}, alloc_attrs); + auto read_back = ops::XRTReadLiteralAndRelease(root, handle); + TF_ASSERT_OK(root.status()); + + ClientSession session(root); + std::vector outputs; + TF_EXPECT_OK(session.Run({read_back}, &outputs)); + EXPECT_EQ(outputs.size(), 1); + + xla::LiteralProto response; + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + EXPECT_TRUE(CompareLiteralToLiteralProto(literal, response)); +} + +TEST(RawApiTest, AllocFromTensorTuple) { + xla::Literal literal0 = + xla::LiteralUtil::CreateR2({{4.0f, 5.0f}, {6.0f, 7.0f}}); + xla::Literal literal1 = + xla::LiteralUtil::CreateR2({{14.0f, -5.0f}, {16.0f, 17.0f}}); + xla::Literal literal = xla::LiteralUtil::MakeTuple({&literal0, &literal1}); + Tensor tensor0; + TF_ASSERT_OK(LiteralToHostTensor(literal0, DT_FLOAT, &tensor0)); + Tensor tensor1; + TF_ASSERT_OK(LiteralToHostTensor(literal1, DT_FLOAT, &tensor1)); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + std::vector layout = GetShapeLayoutVector(literal.shape()).ValueOrDie(); + ops::XRTAllocateFromTensor::Attrs alloc_attrs = + ops::XRTAllocateFromTensor::Layouts(layout); + auto handle = ops::XRTAllocateFromTensor(root, {tensor0, tensor1}, + {tensor0.shape(), tensor1.shape()}, + alloc_attrs); + auto read_back = ops::XRTReadLiteralAndRelease(root, handle); + TF_ASSERT_OK(root.status()); + + ClientSession session(root); + std::vector outputs; + TF_EXPECT_OK(session.Run({read_back}, &outputs)); + 
EXPECT_EQ(outputs.size(), 1); + + xla::LiteralProto response; + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + EXPECT_TRUE(CompareLiteralToLiteralProto(literal, response)); +} + +TEST(RawApiTest, AllocFromTensorTupleSingle) { + xla::Literal literal0 = + xla::LiteralUtil::CreateR2({{4.0f, 5.0f}, {6.0f, 7.0f}}); + xla::Literal literal = xla::LiteralUtil::MakeTuple({&literal0}); + Tensor tensor0; + TF_ASSERT_OK(LiteralToHostTensor(literal0, DT_FLOAT, &tensor0)); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + std::vector layout = GetShapeLayoutVector(literal.shape()).ValueOrDie(); + ops::XRTAllocateFromTensor::Attrs alloc_attrs = + ops::XRTAllocateFromTensor::Layouts(layout).MakeTuple(true); + auto handle = ops::XRTAllocateFromTensor(root, {tensor0}, {tensor0.shape()}, + alloc_attrs); + auto read_back = ops::XRTReadLiteralAndRelease(root, handle); + TF_ASSERT_OK(root.status()); + + ClientSession session(root); + std::vector outputs; + TF_EXPECT_OK(session.Run({read_back}, &outputs)); + EXPECT_EQ(outputs.size(), 1); + + xla::LiteralProto response; + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + EXPECT_TRUE(CompareLiteralToLiteralProto(literal, response)); +} + +TEST(RawApiTest, AllocFromTensorRelayout) { + xla::Literal literal = + xla::LiteralUtil::CreateR2({{4.0f, 5.0f}, {6.0f, 7.0f}}); + Tensor tensor; + TF_ASSERT_OK(LiteralToHostTensor(literal, DT_FLOAT, &tensor)); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + // Use inverse array layout with the tensor data above. 
+ std::vector layout({0, 1}); + ops::XRTAllocateFromTensor::Attrs alloc_attrs = + ops::XRTAllocateFromTensor::Layouts(layout); + auto handle = + ops::XRTAllocateFromTensor(root, {tensor}, {tensor.shape()}, alloc_attrs); + auto read_back = ops::XRTReadLiteralAndRelease(root, handle); + TF_ASSERT_OK(root.status()); + + ClientSession session(root); + std::vector outputs; + TF_EXPECT_OK(session.Run({read_back}, &outputs)); + EXPECT_EQ(outputs.size(), 1); + + xla::LiteralProto response; + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + // We have sent literal's data (in array layout) with a attribute layout + // {0,1}, so the expected literal read from device needs to be changed + // accordingly. + xla::Literal expected_literal = + xla::LiteralUtil::CreateR2({{4.0f, 6.0f}, {5.0f, 7.0f}}); + EXPECT_TRUE(CompareLiteralToLiteralProto(expected_literal, response)); +} + +TEST(RawApiTest, AllocAndRewrite) { + xrt::XLAAllocation alloc; + *alloc.mutable_value() = + xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}}).ToProto(); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + auto value = + ops::Const(root.WithDevice("/device:CPU:0"), alloc.SerializeAsString()); + auto handle = ops::XRTAllocate(root, value); + auto read_back = ops::XRTReadLiteral(root, handle); + TF_ASSERT_OK(root.status()); + + tensorflow::ClientSession session(root); + std::vector outputs; + TF_EXPECT_OK(session.Run({read_back, handle}, &outputs)); + EXPECT_EQ(outputs.size(), 2); + + int64 allocation_handle = outputs[1].scalar()(); + xla::LiteralProto response; + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + EXPECT_TRUE(CompareLiteralProtos(alloc.value(), response)); + outputs.clear(); + + xla::LiteralProto new_literal = + xla::LiteralUtil::CreateR2({{9, 2}, {4, 1}}).ToProto(); + auto new_value = ops::Const(root.WithDevice("/device:CPU:0"), + new_literal.SerializeAsString()); + auto write_op = + ops::XRTWriteLiteral(root, Input(allocation_handle), new_value); + 
TF_ASSERT_OK(root.status()); + TF_EXPECT_OK(session.Run({write_op}, &outputs)); + EXPECT_EQ(outputs.size(), 1); + EXPECT_EQ(allocation_handle, outputs[0].scalar()()); + outputs.clear(); + + auto read_after_write = ops::XRTReadLiteral(root, Input(allocation_handle)); + TF_EXPECT_OK(session.Run({read_after_write}, &outputs)); + EXPECT_EQ(outputs.size(), 1); + + xla::LiteralProto new_response; + EXPECT_TRUE(new_response.ParseFromString(outputs[0].scalar()())); + EXPECT_TRUE(CompareLiteralProtos(new_literal, new_response)); + + Tensor release_tensor(DT_INT64, TensorShape({1})); + release_tensor.flat()(0) = allocation_handle; + + auto release = ops::XRTReleaseAllocationHandle(root, release_tensor); + TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(), {}, {release}, + &outputs)); +} + +TEST(RawApiTest, AllocReleaseMany) { + xrt::XLAAllocation alloc1; + *alloc1.mutable_value() = + xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}}).ToProto(); + xrt::XLAAllocation alloc2; + *alloc2.mutable_value() = + xla::LiteralUtil::CreateR2({{6, 7}, {4, 5}}).ToProto(); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + auto value1 = + ops::Const(root.WithDevice("/device:CPU:0"), alloc1.SerializeAsString()); + auto value2 = + ops::Const(root.WithDevice("/device:CPU:0"), alloc2.SerializeAsString()); + auto handle1 = ops::XRTAllocate(root, value1); + auto handle2 = ops::XRTAllocate(root, value2); + TF_ASSERT_OK(root.status()); + + tensorflow::ClientSession session(root); + std::vector outputs; + TF_EXPECT_OK(session.Run({handle1, handle2}, &outputs)); + EXPECT_EQ(outputs.size(), 2); + + int64 allocation_handle1 = outputs[0].scalar()(); + int64 allocation_handle2 = outputs[1].scalar()(); + + Tensor release_tensor(DT_INT64, TensorShape({2})); + release_tensor.flat()(0) = allocation_handle1; + release_tensor.flat()(1) = allocation_handle2; + + auto release = ops::XRTReleaseAllocationHandle(root, release_tensor); + outputs.clear(); + 
TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(), {}, {release}, + &outputs)); +} + +TEST(RawApiTest, CompileAndReleaseMany) { + xrt::XLAComputation c1; + auto config1 = c1.mutable_config(); + auto shapes1 = config1->mutable_program_shape(); + *shapes1->add_parameters() = + xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); + *shapes1->add_parameters() = + xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); + *shapes1->mutable_result() = + xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); + StoreComputationSnapshot(AddAndScale(), c1.mutable_hlo_snapshot()); + + xrt::XLAComputation c2; + auto config2 = c2.mutable_config(); + auto shapes2 = config2->mutable_program_shape(); + *shapes2->add_parameters() = + xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); + *shapes2->add_parameters() = + xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); + *shapes2->mutable_result() = + xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::F32, {2})}) + .ToProto(); + StoreComputationSnapshot(AddAndTuple(), c2.mutable_hlo_snapshot()); + + xrt::XRTExecutionConfig e; + e.set_release_input_handles(true); + e.set_release_compilation_handle(false); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + auto e_config = + ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString()); + auto computation1 = + ops::Const(root.WithDevice("/device:CPU:0"), c1.SerializeAsString()); + auto c_handle1 = ops::XRTCompile(root, computation1); + auto computation2 = + ops::Const(root.WithDevice("/device:CPU:0"), c2.SerializeAsString()); + auto c_handle2 = ops::XRTCompile(root, computation2); + TF_ASSERT_OK(root.status()); + + ClientSession session(root); + std::vector outputs; + TF_EXPECT_OK(session.Run({c_handle1.handle, c_handle2.handle}, &outputs)); + EXPECT_EQ(outputs.size(), 2); + + int64 compilation_handle1 = outputs[0].scalar()(); + int64 compilation_handle2 = outputs[1].scalar()(); + + Tensor release_tensor(DT_INT64, TensorShape({2})); + 
release_tensor.flat()(0) = compilation_handle1; + release_tensor.flat()(1) = compilation_handle2; + + auto release = ops::XRTReleaseCompilationHandle(root, release_tensor); + outputs.clear(); + TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(), {}, {release}, + &outputs)); +} + +TEST(RawApiTest, AllocAndClearAll) { + xrt::XLAAllocation alloc; + *alloc.mutable_value() = + xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}}).ToProto(); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + auto value = + ops::Const(root.WithDevice("/device:CPU:0"), alloc.SerializeAsString()); + auto handle = ops::XRTAllocate(root, value); + TF_ASSERT_OK(root.status()); + + tensorflow::ClientSession session(root); + std::vector outputs; + TF_EXPECT_OK(session.Run({handle}, &outputs)); + EXPECT_EQ(outputs.size(), 1); + + int64 allocation_handle = outputs[0].scalar()(); + + auto clear_all = ops::XRTReleaseAllAllocations(root); + + outputs.clear(); + TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(), {}, + {clear_all}, &outputs)); + EXPECT_EQ(outputs.size(), 0); + + auto read_after_clear = ops::XRTReadLiteral(root, Input(allocation_handle)); + EXPECT_EQ(session.Run({read_after_clear}, &outputs).code(), + tensorflow::error::Code::NOT_FOUND); +} + TEST(RawApiTest, ReadAndWriteState) { xrt::XLAAllocation alloc; - alloc.set_device_ordinal(0); *alloc.mutable_value() = TwoElementTuple(); Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); @@ -230,7 +556,6 @@ TEST(RawApiTest, ReadAndWriteState) { TEST(RawApiTest, ReadAndWriteStateAutoFree) { xrt::XLAAllocation alloc; - alloc.set_device_ordinal(0); *alloc.mutable_value() = TwoElementTuple(); Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); @@ -251,7 +576,6 @@ TEST(RawApiTest, ReadAndWriteStateAutoFree) { TEST(RawApiTest, SubBuffer) { xrt::XLAAllocation alloc; - alloc.set_device_ordinal(0); *alloc.mutable_value() = NestedTuple(); Scope root = 
Scope::NewRootScope().WithDevice(DeviceFromFlag()); @@ -292,10 +616,8 @@ TEST(RawApiTest, SubBuffer) { TEST(RawApiTest, MakeTuple) { xrt::XLAAllocation alloc_0; - alloc_0.set_device_ordinal(0); *alloc_0.mutable_value() = TwoElementTuple(); xrt::XLAAllocation alloc_1; - alloc_1.set_device_ordinal(0); *alloc_1.mutable_value() = ScalarLiteral(); // The trivial tuple that just forwards its input and releases it. @@ -366,10 +688,8 @@ TEST(RawApiTest, MakeTuple) { TEST(RawApiTest, CompileAndExecute) { xrt::XLAAllocation p0; - p0.set_device_ordinal(0); *p0.mutable_value() = FloatVector({1.0f, 2.0f}); xrt::XLAAllocation p1; - p1.set_device_ordinal(0); *p1.mutable_value() = FloatVector({8.0f, 5.0f}); xrt::XLAComputation c; @@ -421,10 +741,8 @@ TEST(RawApiTest, CompileAndExecute) { TEST(RawApiTest, CompileAndExecuteWithArgumentVector) { xrt::XLAAllocation p0; - p0.set_device_ordinal(0); *p0.mutable_value() = FloatVector({1.0f, 2.0f}); xrt::XLAAllocation p1; - p1.set_device_ordinal(0); *p1.mutable_value() = FloatVector({8.0f, 5.0f}); xrt::XLAComputation c; @@ -544,10 +862,8 @@ TEST(RawApiTest, DotGeneralWithLayoutTest) { auto layout = xla::LayoutUtil::MakeLayout({0, 1}); xrt::XLAAllocation p0; - p0.set_device_ordinal(0); *p0.mutable_value() = FloatMatrix({{1.0f, 2.0f}, {3.0f, 4.0f}}, layout); xrt::XLAAllocation p1; - p1.set_device_ordinal(0); *p1.mutable_value() = FloatMatrix({{8.0f}, {5.0f}}, layout); xrt::XLAComputation c; @@ -630,10 +946,8 @@ TEST(RawApiTest, CompileAndExecuteZeroArg) { TEST(RawApiTest, CompileAndExecuteReturnTuple) { xrt::XLAAllocation p0; - p0.set_device_ordinal(0); *p0.mutable_value() = FloatVector({1.0f, 2.0f}); xrt::XLAAllocation p1; - p1.set_device_ordinal(0); *p1.mutable_value() = FloatVector({8.0f, 5.0f}); xrt::XLAComputation c; @@ -681,6 +995,68 @@ TEST(RawApiTest, CompileAndExecuteReturnTuple) { EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); } +TEST(RawApiTest, CompileAndExecuteReturnExplodedTuple) { + xrt::XLAAllocation p0; + 
*p0.mutable_value() = xla::LiteralUtil::CreateR0(12.0f).ToProto(); + + xrt::XLAAllocation p1; + *p1.mutable_value() = xla::LiteralUtil::CreateR0(3.0f).ToProto(); + + xrt::XLAComputation c; + auto config = c.mutable_config(); + auto shapes = config->mutable_program_shape(); + *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {}).ToProto(); + *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {}).ToProto(); + *shapes->mutable_result() = + xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::F32, {}), + xla::ShapeUtil::MakeShape(xla::F32, {})}) + .ToProto(); + StoreComputationSnapshot(AddAndSubTuple(), c.mutable_hlo_snapshot()); + + xrt::XRTExecutionConfig e; + e.set_release_input_handles(true); + e.set_release_compilation_handle(true); + e.set_return_exploded_tuple(true); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + auto e_config = + ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString()); + auto computation = + ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString()); + auto c_handle = ops::XRTCompile(root, computation); + auto p0_value = + ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString()); + auto p0_handle = ops::XRTAllocate(root, p0_value); + auto p1_value = + ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString()); + auto p1_handle = ops::XRTAllocate(root, p1_value); + auto result = ops::XRTExecute(root, c_handle.handle, e_config, + {Output(p0_handle), Output(p1_handle)}); + TF_ASSERT_OK(root.status()); + + ClientSession session(root); + std::vector outputs; + TF_EXPECT_OK(session.Run({result}, &outputs)); + EXPECT_EQ(outputs.size(), 1); + + auto handles_vec = outputs.front().vec(); + EXPECT_EQ(handles_vec.size(), 2); + + const float kResults[2] = {15.0f, 9.0f}; + for (int64 i = 0; i < handles_vec.size(); ++i) { + auto read_back = ops::XRTReadLiteralAndRelease(root, Input(handles_vec(i))); + std::vector voutputs; + 
TF_EXPECT_OK(session.Run({read_back}, &voutputs)); + EXPECT_EQ(voutputs.size(), 1); + + xla::LiteralProto response; + EXPECT_TRUE(response.ParseFromString(voutputs[0].scalar()())); + + auto expected = xla::LiteralUtil::CreateR0(kResults[i]); + EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); + } +} + TEST(RawApiTest, LeakCompilationReference) { xrt::XLAComputation c; auto config = c.mutable_config(); @@ -705,12 +1081,111 @@ TEST(RawApiTest, LeakCompilationReference) { TF_EXPECT_OK(session.Run({c_handle.handle}, &outputs)); } +TEST(RawApiTest, CompileAndExecuteWithReusedBuffers) { + xla::Shape element_shape = xla::ShapeUtil::MakeShape(xla::F32, {2}); + xla::Shape shape = + xla::ShapeUtil::MakeTupleShape({element_shape, element_shape}); + xla::Shape return_shape = xla::ShapeUtil::MakeTupleShape( + {element_shape, element_shape, element_shape, element_shape}); + xla::XlaBuilder builder("ReuseBuffer"); + auto param = xla::Parameter(&builder, 0, shape, "param"); + auto p0 = xla::GetTupleElement(param, 0); + auto p1 = xla::GetTupleElement(param, 1); + auto add = xla::Add(p0, p1); + auto sub = xla::Sub(p0, p1); + xla::Tuple(&builder, {add, sub, p0, p1}); + + // Flip the tuple literals in the input handle. 
+ builder.SetUpAlias({1}, 0, {0}); + builder.SetUpAlias({0}, 0, {1}); + + auto computation = builder.Build().ValueOrDie(); + + auto literal0 = xla::LiteralUtil::CreateR1({1.0f, 2.0f}); + auto literal1 = xla::LiteralUtil::CreateR1({5.0f, 9.0f}); + auto literal = xla::LiteralUtil::MakeTuple({&literal0, &literal1}); + + xrt::XLAAllocation param_alloc; + *param_alloc.mutable_value() = literal.ToProto(); + + xrt::XLAComputation c; + auto config = c.mutable_config(); + auto shapes = config->mutable_program_shape(); + *shapes->add_parameters() = shape.ToProto(); + *shapes->mutable_result() = return_shape.ToProto(); + StoreComputationSnapshot(computation, c.mutable_hlo_snapshot()); + + xrt::XRTExecutionConfig e; + e.set_release_input_handles(false); + e.set_release_compilation_handle(true); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + ClientSession session(root); + auto e_config = + ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString()); + auto c_data = + ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString()); + auto c_handle = ops::XRTCompile(root, c_data); + auto param_value = ops::Const(root.WithDevice("/device:CPU:0"), + param_alloc.SerializeAsString()); + auto param_handle = ops::XRTAllocate(root, param_value); + TF_ASSERT_OK(root.status()); + + std::vector outputs; + TF_EXPECT_OK(session.Run({param_handle}, &outputs)); + + int64 alloc_handle = outputs[0].scalar()(); + + // Note that we release the result handle immediately, but since we aliased + // the output buffers onto the input allocation ones (held in alloc_handle), + // we can fetch the result from there. 
+ auto result = + ops::XRTExecute(root, c_handle.handle, e_config, {Input(alloc_handle)}); + auto read_back = ops::XRTReadLiteral(root, result); + auto release = ops::XRTReleaseAllocationHandle( + root.WithControlDependencies(read_back), result); + TF_ASSERT_OK(root.status()); + + outputs.clear(); + TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(), {read_back}, + {release}, &outputs)); + + xla::Literal exec_literal = ReadOutputLiteral(outputs, 0); + auto exec_literal_parts = exec_literal.DecomposeTuple(); + ASSERT_EQ(exec_literal_parts.size(), 4); + + EXPECT_TRUE(CompareLiterals(exec_literal_parts[2], literal0)); + EXPECT_TRUE(CompareLiterals(exec_literal_parts[3], literal1)); + + // Now we read back the original input handle values, which at this point + // should contain the result of the XLA computation. + auto read_handle = ops::XRTReadLiteral(root, Input(alloc_handle)); + TF_ASSERT_OK(root.status()); + auto release_handle = ops::XRTReleaseAllocationHandle( + root.WithControlDependencies(read_handle), Input(alloc_handle)); + TF_ASSERT_OK(root.status()); + + outputs.clear(); + TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(), {read_handle}, + {release_handle}, &outputs)); + + xla::Literal return_literal = ReadOutputLiteral(outputs, 0); + + auto expected_literal0 = xla::LiteralUtil::CreateR1({6.0f, 11.0f}); + auto expected_literal1 = xla::LiteralUtil::CreateR1({-4.0f, -7.0f}); + // The first element of the computation returned tuple would be the add + // (expected_literal0), but since we flipped the buffers, the sub + // (expected_literal1) should come first. 
+ auto expected_literal = + xla::LiteralUtil::MakeTuple({&expected_literal1, &expected_literal0}); + + EXPECT_TRUE(CompareLiterals(return_literal, expected_literal)); +} + TEST(RawApiTest, CompileAndExecuteWithS64Argument) { xrt::XLAAllocation p0; - p0.set_device_ordinal(0); *p0.mutable_value() = xla::LiteralUtil::CreateR0(11031965).ToProto(); xrt::XLAAllocation p1; - p1.set_device_ordinal(0); *p1.mutable_value() = xla::LiteralUtil::CreateR0(4091934).ToProto(); xrt::XLAComputation c; @@ -724,6 +1199,7 @@ TEST(RawApiTest, CompileAndExecuteWithS64Argument) { xrt::XRTExecutionConfig e; e.set_release_input_handles(true); e.set_release_compilation_handle(true); + e.set_return_exploded_tuple(true); Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); auto e_config = diff --git a/tensorflow/compiler/xrt/xrt.proto b/tensorflow/compiler/xrt/xrt.proto index e149f2f43593ea412ef279b2c99dabac285cdac4..84adee7392825d408dd88dd74dc0c1bc7b06d7c4 100644 --- a/tensorflow/compiler/xrt/xrt.proto +++ b/tensorflow/compiler/xrt/xrt.proto @@ -59,7 +59,7 @@ message XLAComputation { // Literal to allocate space for, and transfer to, device memory. message XLAAllocation { - int32 device_ordinal = 1; + reserved 1; xla.LiteralProto value = 2; } @@ -101,4 +101,8 @@ message XRTExecutionConfig { bool release_input_handles = 5; // If true, release the handle to the computation after running. bool release_compilation_handle = 6; + // If set to true, and the result shape is a tuple, then instead of returning + // a single tuple allocation the execution will return a vector of + // allocations, one for each of the first-level elements of the result tuple. 
+ bool return_exploded_tuple = 7; } diff --git a/tensorflow/compiler/xrt/xrt_compilation_cache.cc b/tensorflow/compiler/xrt/xrt_compilation_cache.cc index d1405eae468492748ae88d842334a922dce272c6..8bf0f28d2233d9e7593365bc42187e327a1c4ac4 100644 --- a/tensorflow/compiler/xrt/xrt_compilation_cache.cc +++ b/tensorflow/compiler/xrt/xrt_compilation_cache.cc @@ -273,6 +273,8 @@ Status XRTCompilationCache::Lookup( return Status::OK(); } -string XRTCompilationCache::DebugString() { return "XRTCompilationCache"; } +string XRTCompilationCache::DebugString() const { + return "XRTCompilationCache"; +} } // namespace tensorflow diff --git a/tensorflow/compiler/xrt/xrt_compilation_cache.h b/tensorflow/compiler/xrt/xrt_compilation_cache.h index c43d0fc47873abdc82ee937c155bebc346a05f17..7398e847d8b744f947adb03e1bcfd5c0a5b2cc55 100644 --- a/tensorflow/compiler/xrt/xrt_compilation_cache.h +++ b/tensorflow/compiler/xrt/xrt_compilation_cache.h @@ -118,7 +118,7 @@ class XRTCompilationCache : public ResourceBase { // EntryRef holding the program is returned in entry. Status Lookup(int64 uid, std::unique_ptr* entry); - string DebugString() override; + string DebugString() const override; private: // An entry in the compilation cache. 
The entry is deleted once it has been diff --git a/tensorflow/compiler/xrt/xrt_device.cc b/tensorflow/compiler/xrt/xrt_device.cc index ea40e6c895c4f6af13b74735685f2c342181ada9..34cb64742a20985b29d8e153bbaf5ee184fd385d 100644 --- a/tensorflow/compiler/xrt/xrt_device.cc +++ b/tensorflow/compiler/xrt/xrt_device.cc @@ -43,4 +43,12 @@ namespace tensorflow { return Status::OK(); } +/*static*/ Status XRTGenericDeviceAccessor::InitScopedRef( + OpKernelContext* ctx, ScopedRef* scoped_ref) { + const XlaDevice::Metadata* metadata; + TF_RETURN_IF_ERROR(XlaDevice::GetMetadata(ctx, &metadata)); + scoped_ref->Acquire(metadata->client()); + return Status::OK(); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/xrt/xrt_device.h b/tensorflow/compiler/xrt/xrt_device.h index 1e3fddd2a72a3657d1e115375133c244772ea9f3..fb010651d9bf76c540517b9596e472c881241d8a 100644 --- a/tensorflow/compiler/xrt/xrt_device.h +++ b/tensorflow/compiler/xrt/xrt_device.h @@ -59,6 +59,8 @@ class XRTGenericDeviceAccessor { static Status InitScopedRef(OpKernelContext* ctx, int device_ordinal, ScopedRef* scoped_ref); + + static Status InitScopedRef(OpKernelContext* ctx, ScopedRef* scoped_ref); }; } // namespace tensorflow diff --git a/tensorflow/compiler/xrt/xrt_state.cc b/tensorflow/compiler/xrt/xrt_state.cc index 3a99820d7aa9e9546cc95385fd98c05f28988e9e..1e2a9584f88b73d7c92a929e93af60376a59170b 100644 --- a/tensorflow/compiler/xrt/xrt_state.cc +++ b/tensorflow/compiler/xrt/xrt_state.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xrt/xrt_state.h" #include +#include #include #include #include @@ -34,6 +35,7 @@ limitations under the License. 
#include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/stream_executor/stream_executor.h" @@ -41,6 +43,34 @@ namespace tensorflow { namespace { +class BufferAllocStats { + public: + struct Stats { + int64 count = 0; + int64 size = 0; + }; + + Stats ReportAlloc(int64 device, int64 msize) { + mutex_lock lock(lock_); + Stats* device_stats = &stats_[device]; + device_stats->count += 1; + device_stats->size += msize; + return *device_stats; + } + + Stats ReportFree(int64 device, int64 msize) { + mutex_lock lock(lock_); + Stats* device_stats = &stats_[device]; + device_stats->count -= 1; + device_stats->size -= msize; + return *device_stats; + } + + private: + mutable mutex lock_; + std::map stats_; +}; + const char* kTupleContainer = "tuples"; int64 get_uid() { @@ -48,6 +78,11 @@ int64 get_uid() { return static_cast(unsigned_rand); } +BufferAllocStats* GetAllocStats() { + static BufferAllocStats* stats = new BufferAllocStats(); + return stats; +} + Status AllocateScopedShapedBuffer( xla::Backend* backend, int device_ordinal, const xla::Shape& shape, std::unique_ptr* buffer) { @@ -98,11 +133,22 @@ Status AllocateScopedShapedBuffer( XRTBufferAllocation::XRTBufferAllocation(const se::DeviceMemoryBase& allocation, int device_ordinal, xla::DeviceMemoryAllocator* allocator) - : allocation_(allocation), + : size_(allocation.size()), + allocation_(allocation), device_ordinal_(device_ordinal), - allocator_(allocator) {} + allocator_(allocator) { + if (VLOG_IS_ON(2)) { + auto stats = + GetAllocStats()->ReportAlloc(device_ordinal_, allocation_.size()); + LOG(INFO) << "XRT Allocation Stats: device=" << device_ordinal_ + << " count=" << stats.count << " size=" << stats.size; + } +} XRTBufferAllocation::~XRTBufferAllocation() { + if (VLOG_IS_ON(2)) { + 
GetAllocStats()->ReportFree(device_ordinal_, allocation_.size()); + } // Deallocate explicitly allows allocation_ to be null. Status s = allocator_->Deallocate(device_ordinal_, allocation_); // Nothing to do but check fail here if memory datastructures are corrupted. @@ -136,7 +182,7 @@ XRTTupleAllocation::~XRTTupleAllocation() { } /*static*/ Status XRTTupleAllocation::CreateAndTransfer( - const xla::Literal& literal, xla::Backend* backend, int device_ordinal, + const xla::LiteralBase& literal, xla::Backend* backend, int device_ordinal, XRTTupleAllocation** allocation) { auto transfer_manager = backend->transfer_manager(); auto allocator = backend->memory_allocator(); @@ -178,11 +224,36 @@ Status XRTTupleAllocation::ToLiteral(xla::Backend* backend, int device_ordinal, xla::Literal* literal) { auto transfer_manager = backend->transfer_manager(); TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal)); + + // Validate the allocation buffers as if nulls gets to + // TransferLiteralFromDevice() a CHECK is issued. 
+ xla::ShapedBuffer shaped_buffer = ToShapedBuffer(); + for (auto& index_buffer : shaped_buffer.buffers()) { + if (index_buffer.second.is_null()) { + return errors::InvalidArgument("Literal buffer at index ", + index_buffer.first.ToString(), + " has been released"); + } + } TF_ASSIGN_OR_RETURN(*literal, transfer_manager->TransferLiteralFromDevice( - stream.get(), ToShapedBuffer())); + stream.get(), shaped_buffer)); return Status::OK(); } +Status XRTTupleAllocation::WriteLiteral(xla::Backend* backend, + const xla::Literal& literal) { + if (!xla::ShapeUtil::Equal(literal.shape(), on_host_shape())) { + return errors::InvalidArgument( + "New literal shape not matching the existing one: literal=", + xla::ShapeUtil::HumanStringWithLayout(literal.shape()), + " device=", xla::ShapeUtil::HumanStringWithLayout(on_host_shape())); + } + auto transfer_manager = backend->transfer_manager(); + TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal())); + return transfer_manager->TransferLiteralToDevice(stream.get(), literal, + ToShapedBuffer()); +} + void XRTTupleAllocation::DiscardAllocation( const xla::ShapeIndex& buffer_index) { buffers_.element(buffer_index)->DiscardAllocation(); @@ -213,6 +284,11 @@ const se::DeviceMemoryBase& XRTTupleAllocation::root_allocation() { return rm->Delete(kTupleContainer, key_string); } +/* static */ Status XRTTupleAllocation::ReleaseAllAllocations(ResourceMgr* rm) { + VLOG(1) << "Releasing all XRT held device memory"; + return rm->Cleanup(kTupleContainer); +} + // Helper typedef to make ShapeTree ForEach helper lambda signatures more // readable. They need a type of const T& where in this case T is the // following pointer. 
@@ -441,11 +517,34 @@ xla::ShapedBuffer XRTTupleAllocation::ToShapedBuffer() { return shaped_buffer; } +Status XRTTupleAllocation::AliasBufferFrom(const XRTTupleAllocation& source, + const xla::ShapeIndex& source_index, + const xla::ShapeIndex& dest_index) { + XRTBufferAllocation* source_buffer = source.buffers_.element(source_index); + XRTBufferAllocation* dest_buffer = buffers_.element(dest_index); + // We allow the destination size being zero, because there are cases where we + // are coming in later filling in null/uninitialized device buffers. + // In all other cases, the size of the new buffer must match. + if (source_buffer->size() != dest_buffer->size() && + dest_buffer->size() != 0) { + return errors::InvalidArgument( + "Source buffer at index ", source_index.ToString(), + " does not match the size of destination buffer at index ", + dest_index.ToString(), ": ", source_buffer->size(), " vs ", + dest_buffer->size()); + } + *buffers_.mutable_element(dest_index) = source_buffer; + source_buffer->Ref(); + dest_buffer->Unref(); + return Status::OK(); +} + xla::ShapeTree -XRTTupleAllocation::ToDeviceMemoryTree(bool release) { +XRTTupleAllocation::ToDeviceMemoryTree( + const std::function& release_checker) { xla::ShapeTree shaped_tree(on_device_shape()); for (const auto& buffer : buffers_) { - if (!release) { + if (!release_checker(buffer.first)) { *shaped_tree.mutable_element(buffer.first) = buffer.second->allocation(); } else { *shaped_tree.mutable_element(buffer.first) = xla::OwningDeviceMemory( diff --git a/tensorflow/compiler/xrt/xrt_state.h b/tensorflow/compiler/xrt/xrt_state.h index 73b5584e38f781343fe6793af7ad28232fbfc184..ddf2656e6f51775024a6d1cd0d7a387605faae6f 100644 --- a/tensorflow/compiler/xrt/xrt_state.h +++ b/tensorflow/compiler/xrt/xrt_state.h @@ -18,6 +18,7 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_XRT_XRT_STATE_H_ #define TENSORFLOW_COMPILER_XRT_XRT_STATE_H_ +#include #include #include #include @@ -58,7 +59,14 @@ class XRTBufferAllocation : public core::RefCounted { // freed when the reference count drops to zero. void DiscardAllocation(); + // Returns the expected size of the allocation. Since DiscardAllocation() will + // set allocation_ to {null,0}, and since later we might want to replace the + // discarded buffer with a new one, we need to be able to verify the size + // compatibility. + uint64 size() const { return size_; } + private: + uint64 size_ = 0; se::DeviceMemoryBase allocation_; int device_ordinal_; xla::DeviceMemoryAllocator* allocator_; @@ -80,7 +88,7 @@ class XRTTupleAllocation : public ResourceBase { // Allocates new device memory buffers sufficient to store literal, transfers // literal to that memory, and returns a XRTTupleAllocation handle to the // allocated buffers. - static Status CreateAndTransfer(const xla::Literal& literal, + static Status CreateAndTransfer(const xla::LiteralBase& literal, xla::Backend* backend, int device_ordinal, XRTTupleAllocation** allocation); @@ -129,6 +137,10 @@ class XRTTupleAllocation : public ResourceBase { // Deletes the reference in the rm to an allocation interned under key. static Status DeleteFromResourceManager(ResourceMgr* rm, int64 key); + // Releases all the device memory allocated by XRT within the resource + // manager. + static Status ReleaseAllAllocations(ResourceMgr* rm); + // Adds the allocation to a ResourceMgr and returns the key that will be used // to retrieve it. Transfers a reference on *this to rm. Status Intern(ResourceMgr* rm, int64* key); @@ -137,6 +149,9 @@ class XRTTupleAllocation : public ResourceBase { Status ToLiteral(xla::Backend* backend, int device_ordinal, xla::Literal* literal); + // Write a new literal value to the allocation. 
+ Status WriteLiteral(xla::Backend* backend, const xla::Literal& literal); + // True if none of the buffers in the allocation are aliased by any other live // handle. bool IsExclusiveOwner(); @@ -161,11 +176,20 @@ class XRTTupleAllocation : public ResourceBase { // the same shape as on_host_shape. xla::ShapedBuffer ToShapedBuffer(); - // Returns the device memory tree of this allocation. If 'release' is set, the - // ownership of the device memory is transferred to the result. - xla::ShapeTree ToDeviceMemoryTree(bool release); + // Aliases the source buffer at source_index into the current tuple allocation + // dest_index. + Status AliasBufferFrom(const XRTTupleAllocation& source, + const xla::ShapeIndex& source_index, + const xla::ShapeIndex& dest_index); + + // Returns the device memory tree of this allocation. If the release_checker + // function returns true for a given index, the ownership of the device memory + // at that index is transferred to the result. Every attempt to read the value + // at that index will fail. + xla::ShapeTree ToDeviceMemoryTree( + const std::function& release_checker); - string DebugString() override { return "XLA allocation handle"; } + string DebugString() const override { return "XLA allocation handle"; } private: // Creates a new handle with (tuple) shape. 
diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index 832db0f4ab46911e067d17b4a125706c276cf798..25f2640e35af5f65eab25dc60c44e3ed7ce4e512 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -63,7 +63,6 @@ py_library( "//tensorflow/contrib/libsvm", "//tensorflow/contrib/linear_optimizer:sdca_estimator_py", "//tensorflow/contrib/linear_optimizer:sdca_ops_py", - "//tensorflow/contrib/lite/python:lite", "//tensorflow/contrib/lookup:lookup_py", "//tensorflow/contrib/losses:losses_py", "//tensorflow/contrib/losses:metric_learning_py", @@ -197,7 +196,7 @@ cc_library( "//tensorflow/contrib/kinesis:dataset_kernels", ], }) + if_not_windows([ - "//tensorflow/contrib/tensorrt:trt_engine_op_kernel", + "//tensorflow/contrib/tensorrt:trt_op_kernels", ]), ) @@ -239,7 +238,7 @@ cc_library( "//tensorflow/contrib/kinesis:dataset_ops_op_lib", ], }) + if_not_windows([ - "//tensorflow/contrib/tensorrt:trt_engine_op_op_lib", + "//tensorflow/compiler/tf2tensorrt:trt_op_libs", ]) + select({ "//tensorflow:android": [], "//tensorflow:ios": [], diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py index 4f1a2a5693235183c8f486817b82c8c81fa389ec..48d5296c71cbdb470fa405b30547a32b7022f29b 100644 --- a/tensorflow/contrib/__init__.py +++ b/tensorflow/contrib/__init__.py @@ -20,13 +20,14 @@ from __future__ import division from __future__ import print_function import os +import platform # Add projects here, they will show up under tf.contrib. 
from tensorflow.contrib import autograph from tensorflow.contrib import batching from tensorflow.contrib import bayesflow from tensorflow.contrib import checkpoint -if os.name != "nt": +if os.name != "nt" and platform.machine() != "s390x": from tensorflow.contrib import cloud from tensorflow.contrib import cluster_resolver from tensorflow.contrib import coder @@ -91,7 +92,6 @@ from tensorflow.contrib import tpu from tensorflow.contrib import training from tensorflow.contrib import util from tensorflow.contrib.eager.python import tfe as eager -from tensorflow.contrib.lite.python import lite from tensorflow.contrib.optimizer_v2 import optimizer_v2_symbols as optimizer_v2 from tensorflow.contrib.receptive_field import receptive_field_api as receptive_field from tensorflow.contrib.recurrent.python import recurrent_api as recurrent @@ -103,6 +103,8 @@ from tensorflow.python.util.lazy_loader import LazyLoader ffmpeg = LazyLoader("ffmpeg", globals(), "tensorflow.contrib.ffmpeg") del os +del platform + del LazyLoader del absolute_import diff --git a/tensorflow/contrib/android/cmake/build.gradle b/tensorflow/contrib/android/cmake/build.gradle index 17a57b99fd6c9efc09bda0ce1249b1f51bd5af5c..ddec08894f34f96b080610f1d27a6a436f7ffa91 100644 --- a/tensorflow/contrib/android/cmake/build.gradle +++ b/tensorflow/contrib/android/cmake/build.gradle @@ -22,8 +22,8 @@ android { } externalNativeBuild { cmake { - arguments '-DANDROID_TOOLCHAIN=gcc', - '-DANDROID_STL=gnustl_static' + arguments '-DANDROID_TOOLCHAIN=clang', + '-DANDROID_STL=c++_static' } } } @@ -70,7 +70,7 @@ if (ndkDir == null || ndkDir == "") { ndkDir = System.getenv('ANDROID_NDK_HOME') } -if(! Os.isFamily(Os.FAMILY_WINDOWS)) { +if (!Os.isFamily(Os.FAMILY_WINDOWS)) { // This script is for non-Windows OS. 
For Windows OS, MANUALLY build // (or copy the built) libs/headers to the // ${TENSORFLOW_ROOT_DIR}/tensorflow/contrib/makefile/gen diff --git a/tensorflow/contrib/autograph/examples/notebooks/rnn_keras_estimator.ipynb b/tensorflow/contrib/autograph/examples/notebooks/rnn_keras_estimator.ipynb index 44532cb078f9bd1578172f8a7d8a4b55cd21a7cb..831c613f2c8c9a4fcc2cb9d313077fe79ee96fd7 100644 --- a/tensorflow/contrib/autograph/examples/notebooks/rnn_keras_estimator.ipynb +++ b/tensorflow/contrib/autograph/examples/notebooks/rnn_keras_estimator.ipynb @@ -186,8 +186,8 @@ "\n", " def __init__(self):\n", " super(RnnColorbot, self).__init__()\n", - " self.lower_cell = tf.contrib.rnn.LSTMBlockCell(256)\n", - " self.upper_cell = tf.contrib.rnn.LSTMBlockCell(128)\n", + " self.lower_cell = tf.contrib.rnn.LSTMBlockCell(256, dtype=tf.float32)\n", + " self.upper_cell = tf.contrib.rnn.LSTMBlockCell(128, dtype=tf.float32)\n", " self.relu_layer = tf.layers.Dense(3, activation=tf.nn.relu)\n", "\n", " def _rnn_layer(self, chars, cell, batch_size, training):\n", @@ -241,7 +241,7 @@ " seq = self._rnn_layer(seq, self.upper_cell, batch_size, training)\n", "\n", " # Grab just the end-of-sequence from each output.\n", - " indices = (length - 1, range(batch_size))\n", + " indices = (length - 1, list(range(batch_size)))\n", " indices = tf.stack(indices, 1)\n", " sequence_ends = tf.gather_nd(seq, indices)\n", " return self.relu_layer(sequence_ends)\n", diff --git a/tensorflow/contrib/batching/BUILD b/tensorflow/contrib/batching/BUILD index 648f3ebb05646a66144bcb118347cbc391909409..5174afe0a63d37e3ea3e19ac9bab644d1d83ecf1 100644 --- a/tensorflow/contrib/batching/BUILD +++ b/tensorflow/contrib/batching/BUILD @@ -37,6 +37,7 @@ py_library( cc_library( name = "batch_ops_kernels", deps = [ + "//tensorflow/core:batch_ops_op_lib", "//tensorflow/core/kernels:batch_kernels", ], alwayslink = 1, diff --git a/tensorflow/contrib/bigtable/README.md b/tensorflow/contrib/bigtable/README.md index 
2c44abed5e1955cc666273e97e6b2378766f13d2..79052bee35c7895cb4048b10c1f73acb036d1587 100644 --- a/tensorflow/contrib/bigtable/README.md +++ b/tensorflow/contrib/bigtable/README.md @@ -51,25 +51,18 @@ BIGTABLE_TABLE_NAME = '' PREFIX = 'train-' def main(): + tf.enable_eager_execution() + client = tf.contrib.cloud.BigtableClient(GCP_PROJECT_ID, BIGTABLE_INSTANCE_ID) table = client.table(BIGTABLE_TABLE_NAME) dataset = table.keys_by_prefix_dataset(PREFIX) - iterator = dataset.make_initializable_iterator() - get_next_op = iterator.get_next() - with tf.Session() as sess: - print('Initializing the iterator.') - sess.run(iterator.initializer) - print('Retrieving rows:') - row_index = 0 - while True: - try: - row_key = sess.run(get_next_op) - print('Row key %d: %s' % (row_index, row_key)) - row_index += 1 - except tf.errors.OutOfRangeError: - print('Finished reading data!') - break + print('Retrieving rows:') + row_index = 0 + for row_key in dataset: + print('Row key %d: %s' % (row_index, row_key)) + row_index += 1 + print('Finished reading data!') if __name__ == '__main__': main() diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_lib.h b/tensorflow/contrib/bigtable/kernels/bigtable_lib.h index 4652021fecabfa11fa6a8754dc884d89e151b590..e3b4535bac4a01a1277290e0d1ea6d3c7613731c 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_lib.h +++ b/tensorflow/contrib/bigtable/kernels/bigtable_lib.h @@ -42,7 +42,7 @@ class BigtableClientResource : public ResourceBase { return client_; } - string DebugString() override { + string DebugString() const override { return strings::StrCat("BigtableClientResource(project_id: ", project_id_, ", instance_id: ", instance_id_, ")"); } @@ -67,7 +67,7 @@ class BigtableTableResource : public ResourceBase { ::google::cloud::bigtable::noex::Table& table() { return table_; } - string DebugString() override { + string DebugString() const override { return strings::StrCat( "BigtableTableResource(client: ", client_->DebugString(), ", table: ", 
table_name_, ")"); diff --git a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc index e95dc577184f7e81d942755b41065f52131ce9f6..e6fda9e61757f1441b3691c2a3d57c6f1a5a0d42 100644 --- a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc +++ b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h" -#include "google/bigtable/v2/data.pb.h" +#include "external/com_github_googleapis_googleapis/google/bigtable/v2/data.pb.h" #include "google/protobuf/wrappers.pb.h" #include "re2/re2.h" #include "tensorflow/core/lib/strings/stringprintf.h" @@ -399,6 +399,28 @@ BigtableTestClient::AsyncMutateRows( return nullptr; } +std::unique_ptr> +BigtableTestClient::AsyncCheckAndMutateRow( + grpc::ClientContext* context, + const google::bigtable::v2::CheckAndMutateRowRequest& request, + grpc::CompletionQueue* cq) { + LOG(WARNING) << "Call to InMemoryDataClient::" << __func__ + << "(); this will likely cause a crash!"; + return nullptr; +} + +std::unique_ptr< + grpc::ClientAsyncReaderInterface> +BigtableTestClient::AsyncReadRows( + grpc::ClientContext* context, + const google::bigtable::v2::ReadRowsRequest& request, + grpc::CompletionQueue* cq, void* tag) { + LOG(WARNING) << "Call to InMemoryDataClient::" << __func__ + << "(); this will likely cause a crash!"; + return nullptr; +} + std::shared_ptr BigtableTestClient::Channel() { LOG(WARNING) << "Call to InMemoryDataClient::Channel(); this will likely " "cause a crash!"; diff --git a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h index c4a1f06bc504c3565c7bb09b42e48e7fbddb9cc6..8e1326f2ce841368ea81fc7194a0588e5d6cd637 100644 --- 
a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h +++ b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h @@ -80,6 +80,19 @@ class BigtableTestClient : public ::google::cloud::bigtable::DataClient { const ::google::bigtable::v2::MutateRowsRequest& request, ::grpc::CompletionQueue* cq, void* tag) override; + std::unique_ptr> + AsyncCheckAndMutateRow( + grpc::ClientContext* context, + const google::bigtable::v2::CheckAndMutateRowRequest& request, + grpc::CompletionQueue* cq) override; + + std::unique_ptr< + grpc::ClientAsyncReaderInterface> + AsyncReadRows(grpc::ClientContext* context, + const google::bigtable::v2::ReadRowsRequest& request, + grpc::CompletionQueue* cq, void* tag) override; + std::shared_ptr Channel() override; private: diff --git a/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py b/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py index 316da9ebe152ef52c7e7f846cf8c3eb1555ee8a6..197f5578eb010bee5a3aad7c05446393193f99e2 100644 --- a/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py +++ b/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py @@ -57,7 +57,7 @@ class BigtableOpsTest(test.TestCase): sess.run(write_op) def runReadKeyTest(self, read_ds): - itr = read_ds.make_initializable_iterator() + itr = dataset_ops.make_initializable_iterator(read_ds) n = itr.get_next() expected = list(self.COMMON_ROW_KEYS) expected.reverse() @@ -78,7 +78,7 @@ class BigtableOpsTest(test.TestCase): self.runReadKeyTest(self._table.keys_by_range_dataset("r1", "r4")) def runScanTest(self, read_ds): - itr = read_ds.make_initializable_iterator() + itr = dataset_ops.make_initializable_iterator(read_ds) n = itr.get_next() expected_keys = list(self.COMMON_ROW_KEYS) expected_keys.reverse() @@ -120,7 +120,7 @@ class BigtableOpsTest(test.TestCase): def testLookup(self): ds = self._table.keys_by_prefix_dataset("r") ds = ds.apply(self._table.lookup_columns(cf1="c1")) 
- itr = ds.make_initializable_iterator() + itr = dataset_ops.make_initializable_iterator(ds) n = itr.get_next() expected_keys = list(self.COMMON_ROW_KEYS) expected_values = list(self.COMMON_VALUES) @@ -141,7 +141,7 @@ class BigtableOpsTest(test.TestCase): def testSampleKeys(self): ds = self._table.sample_keys() - itr = ds.make_initializable_iterator() + itr = dataset_ops.make_initializable_iterator(ds) n = itr.get_next() expected_key = self.COMMON_ROW_KEYS[0] with self.cached_session() as sess: @@ -161,7 +161,7 @@ class BigtableOpsTest(test.TestCase): sess.run(n) def runSampleKeyPairsTest(self, ds, expected_key_pairs): - itr = ds.make_initializable_iterator() + itr = dataset_ops.make_initializable_iterator(ds) n = itr.get_next() with self.cached_session() as sess: self._writeCommonValues(sess) @@ -218,7 +218,7 @@ class BigtableOpsTest(test.TestCase): def testSampleKeyPairsPrefixAndStartKey(self): ds = bigtable_api._BigtableSampleKeyPairsDataset( self._table, prefix="r", start="r1", end="") - itr = ds.make_initializable_iterator() + itr = dataset_ops.make_initializable_iterator(ds) with self.cached_session() as sess: with self.assertRaises(errors.InvalidArgumentError): sess.run(itr.initializer) @@ -226,14 +226,14 @@ class BigtableOpsTest(test.TestCase): def testSampleKeyPairsPrefixAndEndKey(self): ds = bigtable_api._BigtableSampleKeyPairsDataset( self._table, prefix="r", start="", end="r3") - itr = ds.make_initializable_iterator() + itr = dataset_ops.make_initializable_iterator(ds) with self.cached_session() as sess: with self.assertRaises(errors.InvalidArgumentError): sess.run(itr.initializer) def testParallelScanPrefix(self): ds = self._table.parallel_scan_prefix(prefix="r", cf1="c1") - itr = ds.make_initializable_iterator() + itr = dataset_ops.make_initializable_iterator(ds) n = itr.get_next() with self.cached_session() as sess: self._writeCommonValues(sess) @@ -251,7 +251,7 @@ class BigtableOpsTest(test.TestCase): def testParallelScanRange(self): ds = 
self._table.parallel_scan_range(start="r1", end="r4", cf1="c1") - itr = ds.make_initializable_iterator() + itr = dataset_ops.make_initializable_iterator(ds) n = itr.get_next() with self.cached_session() as sess: self._writeCommonValues(sess) diff --git a/tensorflow/contrib/bigtable/python/ops/bigtable_api.py b/tensorflow/contrib/bigtable/python/ops/bigtable_api.py index 7c87b0daeb09950cc44c51f49c16534d413f0376..fa64055dfd65a134afdf46cebccb7f7d96106502 100644 --- a/tensorflow/contrib/bigtable/python/ops/bigtable_api.py +++ b/tensorflow/contrib/bigtable/python/ops/bigtable_api.py @@ -35,8 +35,8 @@ from tensorflow.contrib.util import loader from tensorflow.python.data.experimental.ops import interleave_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest +from tensorflow.python.data.util import structure from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.platform import resource_loader @@ -111,8 +111,7 @@ class BigtableClient(object): class BigtableTable(object): - """BigtableTable is the entrypoint for reading and writing data in Cloud - Bigtable. + """Entry point for reading and writing data in Cloud Bigtable. This BigtableTable class is the Python representation of the Cloud Bigtable table within TensorFlow. Methods on this class allow data to be read from and @@ -222,7 +221,7 @@ class BigtableTable(object): A `tf.data.Dataset`. containing `tf.string` Tensors corresponding to all of the row keys matching that prefix. """ - return _BigtablePrefixKeyDataset(self, prefix) + return dataset_ops.DatasetV1Adapter(_BigtablePrefixKeyDataset(self, prefix)) def sample_keys(self): """Retrieves a sampling of row keys from the Bigtable table. @@ -234,7 +233,7 @@ class BigtableTable(object): Returns: A `tf.data.Dataset` returning string row keys. 
""" - return _BigtableSampleKeysDataset(self) + return dataset_ops.DatasetV1Adapter(_BigtableSampleKeysDataset(self)) def scan_prefix(self, prefix, probability=None, columns=None, **kwargs): """Retrieves row (including values) from the Bigtable service. @@ -279,7 +278,8 @@ class BigtableTable(object): """ probability = _normalize_probability(probability) normalized = _normalize_columns(columns, kwargs) - return _BigtableScanDataset(self, prefix, "", "", normalized, probability) + return dataset_ops.DatasetV1Adapter( + _BigtableScanDataset(self, prefix, "", "", normalized, probability)) def scan_range(self, start, end, probability=None, columns=None, **kwargs): """Retrieves rows (including values) from the Bigtable service. @@ -324,7 +324,8 @@ class BigtableTable(object): """ probability = _normalize_probability(probability) normalized = _normalize_columns(columns, kwargs) - return _BigtableScanDataset(self, "", start, end, normalized, probability) + return dataset_ops.DatasetV1Adapter( + _BigtableScanDataset(self, "", start, end, normalized, probability)) def parallel_scan_prefix(self, prefix, @@ -380,7 +381,8 @@ class BigtableTable(object): """ probability = _normalize_probability(probability) normalized = _normalize_columns(columns, kwargs) - ds = _BigtableSampleKeyPairsDataset(self, prefix, "", "") + ds = dataset_ops.DatasetV1Adapter( + _BigtableSampleKeyPairsDataset(self, prefix, "", "")) return self._make_parallel_scan_dataset(ds, num_parallel_scans, probability, normalized) @@ -442,7 +444,8 @@ class BigtableTable(object): """ probability = _normalize_probability(probability) normalized = _normalize_columns(columns, kwargs) - ds = _BigtableSampleKeyPairsDataset(self, "", start, end) + ds = dataset_ops.DatasetV1Adapter( + _BigtableSampleKeyPairsDataset(self, "", start, end)) return self._make_parallel_scan_dataset(ds, num_parallel_scans, probability, normalized) @@ -486,7 +489,7 @@ class BigtableTable(object): "len(dataset.output_types))") return 
gen_bigtable_ops.dataset_to_bigtable( self._resource, - dataset._as_variant_tensor(), # pylint: disable=protected-access + dataset._variant_tensor, # pylint: disable=protected-access column_families, columns, timestamp) @@ -579,26 +582,19 @@ class _BigtableKeyDataset(dataset_ops.DatasetSource): """_BigtableKeyDataset is an abstract class representing the keys of a table. """ - def __init__(self, table): + def __init__(self, table, variant_tensor): """Constructs a _BigtableKeyDataset. Args: table: a Bigtable class. + variant_tensor: DT_VARIANT representation of the dataset. """ - super(_BigtableKeyDataset, self).__init__() + super(_BigtableKeyDataset, self).__init__(variant_tensor) self._table = table @property - def output_classes(self): - return ops.Tensor - - @property - def output_shapes(self): - return tensor_shape.TensorShape([]) - - @property - def output_types(self): - return dtypes.string + def _element_structure(self): + return structure.TensorStructure(dtypes.string, []) class _BigtablePrefixKeyDataset(_BigtableKeyDataset): @@ -606,13 +602,11 @@ class _BigtablePrefixKeyDataset(_BigtableKeyDataset): """ def __init__(self, table, prefix): - super(_BigtablePrefixKeyDataset, self).__init__(table) self._prefix = prefix - - def _as_variant_tensor(self): - return gen_bigtable_ops.bigtable_prefix_key_dataset( - table=self._table._resource, # pylint: disable=protected-access + variant_tensor = gen_bigtable_ops.bigtable_prefix_key_dataset( + table=table._resource, # pylint: disable=protected-access prefix=self._prefix) + super(_BigtablePrefixKeyDataset, self).__init__(table, variant_tensor) class _BigtableRangeKeyDataset(_BigtableKeyDataset): @@ -620,15 +614,13 @@ class _BigtableRangeKeyDataset(_BigtableKeyDataset): """ def __init__(self, table, start, end): - super(_BigtableRangeKeyDataset, self).__init__(table) self._start = start self._end = end - - def _as_variant_tensor(self): - return gen_bigtable_ops.bigtable_range_key_dataset( - table=self._table._resource, 
# pylint: disable=protected-access + variant_tensor = gen_bigtable_ops.bigtable_range_key_dataset( + table=table._resource, # pylint: disable=protected-access start_key=self._start, end_key=self._end) + super(_BigtableRangeKeyDataset, self).__init__(table, variant_tensor) class _BigtableSampleKeysDataset(_BigtableKeyDataset): @@ -638,11 +630,9 @@ class _BigtableSampleKeysDataset(_BigtableKeyDataset): # TODO(saeta): Expose the data size offsets into the keys. def __init__(self, table): - super(_BigtableSampleKeysDataset, self).__init__(table) - - def _as_variant_tensor(self): - return gen_bigtable_ops.bigtable_sample_keys_dataset( - table=self._table._resource) # pylint: disable=protected-access + variant_tensor = gen_bigtable_ops.bigtable_sample_keys_dataset( + table=table._resource) # pylint: disable=protected-access + super(_BigtableSampleKeysDataset, self).__init__(table, variant_tensor) class _BigtableLookupDataset(dataset_ops.DatasetSource): @@ -656,26 +646,17 @@ class _BigtableLookupDataset(dataset_ops.DatasetSource): self._normalized = normalized self._column_families = [i[0] for i in normalized] self._columns = [i[1] for i in normalized] - - @property - def output_classes(self): - return tuple([ops.Tensor] * self._num_outputs) - - @property - def output_shapes(self): - return tuple([tensor_shape.TensorShape([])] * self._num_outputs) - - @property - def output_types(self): - return tuple([dtypes.string] * self._num_outputs) - - def _as_variant_tensor(self): - # pylint: disable=protected-access - return gen_bigtable_ops.bigtable_lookup_dataset( - keys_dataset=self._dataset._as_variant_tensor(), - table=self._table._resource, + variant_tensor = gen_bigtable_ops.bigtable_lookup_dataset( + keys_dataset=self._dataset._variant_tensor, # pylint: disable=protected-access + table=self._table._resource, # pylint: disable=protected-access column_families=self._column_families, columns=self._columns) + super(_BigtableLookupDataset, self).__init__(variant_tensor) + + 
@property + def _element_structure(self): + return structure.NestedStructure(tuple( + [structure.TensorStructure(dtypes.string, [])] * self._num_outputs)) class _BigtableScanDataset(dataset_ops.DatasetSource): @@ -691,21 +672,7 @@ class _BigtableScanDataset(dataset_ops.DatasetSource): self._columns = [i[1] for i in normalized] self._probability = probability self._num_outputs = len(normalized) + 1 # 1 for row key - - @property - def output_classes(self): - return tuple([ops.Tensor] * self._num_outputs) - - @property - def output_shapes(self): - return tuple([tensor_shape.TensorShape([])] * self._num_outputs) - - @property - def output_types(self): - return tuple([dtypes.string] * self._num_outputs) - - def _as_variant_tensor(self): - return gen_bigtable_ops.bigtable_scan_dataset( + variant_tensor = gen_bigtable_ops.bigtable_scan_dataset( table=self._table._resource, # pylint: disable=protected-access prefix=self._prefix, start_key=self._start, @@ -713,6 +680,13 @@ class _BigtableScanDataset(dataset_ops.DatasetSource): column_families=self._column_families, columns=self._columns, probability=self._probability) + super(_BigtableScanDataset, self).__init__(variant_tensor) + + @property + def _element_structure(self): + return structure.NestedStructure( + tuple( + [structure.TensorStructure(dtypes.string, [])] * self._num_outputs)) class _BigtableSampleKeyPairsDataset(dataset_ops.DatasetSource): @@ -724,23 +698,15 @@ class _BigtableSampleKeyPairsDataset(dataset_ops.DatasetSource): self._prefix = prefix self._start = start self._end = end - - @property - def output_classes(self): - return (ops.Tensor, ops.Tensor) - - @property - def output_shapes(self): - return (tensor_shape.TensorShape([]), tensor_shape.TensorShape([])) - - @property - def output_types(self): - return (dtypes.string, dtypes.string) - - def _as_variant_tensor(self): - # pylint: disable=protected-access - return gen_bigtable_ops.bigtable_sample_key_pairs_dataset( - table=self._table._resource, + 
variant_tensor = gen_bigtable_ops.bigtable_sample_key_pairs_dataset( + table=self._table._resource, # pylint: disable=protected-access prefix=self._prefix, start_key=self._start, end_key=self._end) + super(_BigtableSampleKeyPairsDataset, self).__init__(variant_tensor) + + @property + def _element_structure(self): + return structure.NestedStructure( + (structure.TensorStructure(dtypes.string, []), + structure.TensorStructure(dtypes.string, []))) diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD index d3b23d949ee2c7674c3918d39e8b71d76eefcfec..64e4c4560ba3a1b177db12a09997ff7afe8775a3 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD +++ b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD @@ -193,8 +193,9 @@ py_test( py_test( name = "estimator_test", - size = "large", + size = "medium", srcs = ["estimator_test.py"], + shard_count = 4, srcs_version = "PY2AND3", tags = [ "no_gpu", diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py index ee052ac60387d8f993e4942dd7dff39e191dd3a4..47d910d42a27db4b857eeb12209dfbb429dd1be2 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py @@ -487,8 +487,8 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase): self.assertTrue(frac_below_upper_0 <= 0.98) self.assertTrue(frac_below_upper_1 >= 0.92) self.assertTrue(frac_below_upper_1 <= 0.98) - self.assertTrue(frac_both_below_upper >= 0.92) - self.assertTrue(frac_both_below_upper <= 0.98) + self.assertTrue(frac_both_below_upper >= 0.91) + self.assertTrue(frac_both_below_upper <= 0.99) train_input_fn, test_input_fn, _ = _quantile_regression_input_fns( two_dimension=True) @@ -516,8 +516,8 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase): self.assertTrue(frac_above_lower_0 <= 0.98) 
self.assertTrue(frac_above_lower_1 >= 0.92) self.assertTrue(frac_above_lower_1 <= 0.98) - self.assertTrue(frac_both_above_lower >= 0.92) - self.assertTrue(frac_both_above_lower <= 0.98) + self.assertTrue(frac_both_above_lower >= 0.91) + self.assertTrue(frac_both_above_lower <= 0.99) class CoreGradientBoostedDecisionTreeEstimators(test_util.TensorFlowTestCase): @@ -806,8 +806,8 @@ class CoreGradientBoostedDecisionTreeEstimators(test_util.TensorFlowTestCase): self.assertTrue(frac_below_upper_0 <= 0.98) self.assertTrue(frac_below_upper_1 >= 0.92) self.assertTrue(frac_below_upper_1 <= 0.98) - self.assertTrue(frac_both_below_upper >= 0.92) - self.assertTrue(frac_both_below_upper <= 0.98) + self.assertTrue(frac_both_below_upper >= 0.91) + self.assertTrue(frac_both_below_upper <= 0.99) train_input_fn, test_input_fn, _ = _quantile_regression_input_fns( two_dimension=True) @@ -835,8 +835,8 @@ class CoreGradientBoostedDecisionTreeEstimators(test_util.TensorFlowTestCase): self.assertTrue(frac_above_lower_0 <= 0.98) self.assertTrue(frac_above_lower_1 >= 0.92) self.assertTrue(frac_above_lower_1 <= 0.98) - self.assertTrue(frac_both_above_lower >= 0.92) - self.assertTrue(frac_both_above_lower <= 0.98) + self.assertTrue(frac_both_above_lower >= 0.91) + self.assertTrue(frac_both_above_lower <= 0.99) if __name__ == "__main__": diff --git a/tensorflow/contrib/boosted_trees/kernels/stats_accumulator_ops.cc b/tensorflow/contrib/boosted_trees/kernels/stats_accumulator_ops.cc index e446c411a8d5075563b8f8b912b29df310e16c8c..6faf6963011b698a3b233329d87471da7608e44a 100644 --- a/tensorflow/contrib/boosted_trees/kernels/stats_accumulator_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/stats_accumulator_ops.cc @@ -96,7 +96,7 @@ class StatsAccumulatorResource : public boosted_trees::StampedResource { TensorShapeUtils::IsScalar(hessian_shape)); } - string DebugString() override { + string DebugString() const override { return strings::StrCat("StatsAccumulatorResource[size=", 
values_.size(), "]"); } diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/model_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/model_ops_test.py index 42d69645acaae063fcd46bd1f6c819ccb68f48bd..aa3f24f08a0f762507df83def72e7d595265221f 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/model_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/model_ops_test.py @@ -227,7 +227,7 @@ class ModelOpsTest(test_util.TensorFlowTestCase): tree_ensemble_config=tree_ensemble_config.SerializeToString(), name="restore_tree") resources.initialize_resources(resources.shared_resources()).run() - variables.initialize_all_variables().run() + variables.global_variables_initializer().run() my_saver = saver.Saver() # Add the second tree and replace the ensemble of the handle. diff --git a/tensorflow/contrib/boosted_trees/python/ops/model_ops.py b/tensorflow/contrib/boosted_trees/python/ops/model_ops.py index fca22c71a83459cb290eaebcf107cf1c14c222b7..c3685b54e201f73039f6623443c67ba2b217a51e 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/model_ops.py +++ b/tensorflow/contrib/boosted_trees/python/ops/model_ops.py @@ -62,8 +62,8 @@ class TreeEnsembleVariableSavable(saver.BaseSaverBuilder.SaveableObject): saver.BaseSaverBuilder.SaveSpec(ensemble_config, slice_spec, name + "_config"), ] - super(TreeEnsembleVariableSavable, - self).__init__(tree_ensemble_handle, specs, name) + super(TreeEnsembleVariableSavable, self).__init__(tree_ensemble_handle, + specs, name) self._tree_ensemble_handle = tree_ensemble_handle self._create_op = create_op @@ -115,7 +115,7 @@ class TreeEnsembleVariable(tracking.TrackableResource): def _gather_saveables_for_checkpoint(self): return { - "tree_ensemble_variable": + self.resource_handle.op.name + "/tree_ensemble_variable": functools.partial( TreeEnsembleVariableSavable, tree_ensemble_handle=self.resource_handle, @@ -131,8 +131,8 @@ def tree_ensemble_variable(stamp_token, Args: 
stamp_token: The initial stamp token value for the ensemble resource. - tree_ensemble_config: A `Tensor` of type `string`. - Serialized proto of the tree ensemble. + tree_ensemble_config: A `Tensor` of type `string`. Serialized proto of the + tree ensemble. name: A name for the ensemble variable. container: An optional `string`. Defaults to `""`. diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py index 9fdc2fc0c2c7b85502f7a3f9ae7c85cf05d5916c..e78ec476ab3b43e5eb56a2502008bb8020ae97e0 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py @@ -566,9 +566,10 @@ class GradientBoostedDecisionTreeModel(object): # Determine if ensemble is colocated with the inputs. if self._ensemble_handle.device != input_deps[0].device: # Create a local ensemble and get its local stamp. - with ops.name_scope("local_ensemble", "TreeEnsembleVariable") as name: + with ops.name_scope("local_ensemble", "TreeEnsembleVariable"): local_ensemble_handle = ( - gen_model_ops.decision_tree_ensemble_resource_handle_op(name=name)) + gen_model_ops.decision_tree_ensemble_resource_handle_op( + self._ensemble_handle.op.name + "/local_ensemble")) create_op = gen_model_ops.create_tree_ensemble_variable( local_ensemble_handle, stamp_token=-1, tree_ensemble_config="") with ops.control_dependencies([create_op]): @@ -614,13 +615,19 @@ class GradientBoostedDecisionTreeModel(object): predictions_dict[NUM_TREES_ATTEMPTED] % self._logits_dimension) return constant_op.constant(-1, dtype=dtypes.int32) - def update_stats(self, loss, predictions_dict): + def update_stats(self, loss, predictions_dict, gradients=None, hessians=None): """Update the accumulators with stats from this batch. Args: loss: A scalar tensor representing average loss of examples. 
predictions_dict: Dictionary of Rank 2 `Tensor` representing information about predictions per example. + gradients: A tensor with the gradients with the respect to logits from + predictions_dict. If not provided, tensorflow will do + autodifferentiation. + hessians: A tensor with the hessians with the respect to logits from + predictions_dict. If not provided, tensorflow will do + autodifferentiation. Returns: Three values: @@ -642,13 +649,14 @@ class GradientBoostedDecisionTreeModel(object): predictions = predictions_dict[PREDICTIONS] partition_ids = predictions_dict[PARTITION_IDS] ensemble_stamp = predictions_dict[ENSEMBLE_STAMP] - gradients = gradients_impl.gradients( - loss, - predictions, - name="Gradients", - colocate_gradients_with_ops=False, - gate_gradients=0, - aggregation_method=None)[0] + if gradients is None: + gradients = gradients_impl.gradients( + loss, + predictions, + name="Gradients", + colocate_gradients_with_ops=False, + gate_gradients=0, + aggregation_method=None)[0] strategy = self._learner_config.multi_class_strategy class_id = self._get_class_id(predictions_dict) @@ -657,17 +665,20 @@ class GradientBoostedDecisionTreeModel(object): # We build one vs rest trees. if self._logits_dimension == 1: # We have only 1 score, gradients is of shape [batch, 1]. - hessians = gradients_impl.gradients( - gradients, - predictions, - name="Hessian", - colocate_gradients_with_ops=False, - gate_gradients=0, - aggregation_method=None)[0] + if hessians is None: + hessians = gradients_impl.gradients( + gradients, + predictions, + name="Hessian", + colocate_gradients_with_ops=False, + gate_gradients=0, + aggregation_method=None)[0] squeezed_gradients = array_ops.squeeze(gradients, axis=[1]) squeezed_hessians = array_ops.squeeze(hessians, axis=[1]) else: + if hessians is not None: + raise ValueError("Providing hessians is not yet supported here.") hessian_list = self._diagonal_hessian(gradients, predictions) # Assemble hessian list into a tensor. 
hessians = array_ops.stack(hessian_list, axis=1) @@ -678,6 +689,8 @@ class GradientBoostedDecisionTreeModel(object): squeezed_hessians = array_ops.squeeze( _get_column_by_index(hessians, class_id)) else: + if hessians is not None: + raise ValueError("Providing hessians is not yet supported here.") # Other multiclass strategies. if strategy == learner_pb2.LearnerConfig.FULL_HESSIAN: hessian_list = self._full_hessian(gradients, predictions) @@ -835,9 +848,9 @@ class GradientBoostedDecisionTreeModel(object): stats_update_ops.append( control_flow_ops.cond( continue_centering, - self._make_update_bias_stats_fn( - ensemble_stamp, predictions, gradients, - bias_stats_accumulator), control_flow_ops.no_op)) + self._make_update_bias_stats_fn(ensemble_stamp, predictions, + gradients, bias_stats_accumulator, + hessians), control_flow_ops.no_op)) # Update handler stats. handler_reads = collections.OrderedDict() @@ -1162,7 +1175,8 @@ class GradientBoostedDecisionTreeModel(object): def get_max_tree_depth(self): return self._max_tree_depth - def train(self, loss, predictions_dict, labels): + def train(self, loss, predictions_dict, labels, gradients=None, + hessians=None): """Updates the accumalator stats and grows the ensemble. Args: @@ -1171,6 +1185,12 @@ class GradientBoostedDecisionTreeModel(object): about predictions per example. labels: Rank 2 `Tensor` representing labels per example. Has no effect on the training and is only kept for backward compatibility. + gradients: A tensor with the gradients with the respect to logits from + predictions_dict. If not provided, tensorflow will do + autodifferentiation. + hessians: A tensor with the hessians with the respect to logits from + predictions_dict. If not provided, tensorflow will do + autodifferentiation. Returns: An op that adds a new tree to the ensemble. @@ -1179,7 +1199,8 @@ class GradientBoostedDecisionTreeModel(object): ValueError: if inputs are not valid. """ del labels # unused; kept for backward compatibility. 
- update_op, _, training_state = self.update_stats(loss, predictions_dict) + update_op, _, training_state = self.update_stats(loss, predictions_dict, + gradients, hessians) with ops.control_dependencies(update_op): return self.increment_step_counter_and_maybe_update_ensemble( predictions_dict, training_state) @@ -1271,21 +1292,28 @@ class GradientBoostedDecisionTreeModel(object): ps_ops=ps_ops, ps_strategy=ps_strategy) - def _make_update_bias_stats_fn(self, ensemble_stamp, predictions, gradients, - bias_stats_accumulator): + def _make_update_bias_stats_fn(self, + ensemble_stamp, + predictions, + gradients, + bias_stats_accumulator, + hessians=None): """A method to create the function which updates the bias stats.""" def _update_bias_stats(): """A method to update the bias stats.""" # Get reduced gradients and hessians. grads_sum = math_ops.reduce_sum(gradients, 0) - hess = gradients_impl.gradients( - grads_sum, - predictions, - name="Hessians", - colocate_gradients_with_ops=False, - gate_gradients=0, - aggregation_method=None)[0] + if hessians is not None: + hess = hessians + else: + hess = gradients_impl.gradients( + grads_sum, + predictions, + name="Hessians", + colocate_gradients_with_ops=False, + gate_gradients=0, + aggregation_method=None)[0] hess_sum = math_ops.reduce_sum(hess, 0) # Accumulate gradients and hessians. 
diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py index 92068e88a76cb8bfdd394c1093347a8fb8a63449..7e45d0b2cecefa4bdec77d6cf7cfca7dba04db9c 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py @@ -43,7 +43,7 @@ from tensorflow.python.platform import googletest def _squared_loss(label, unused_weights, predictions): """Unweighted loss implementation.""" loss = math_ops.reduce_sum( - math_ops.square(predictions - label), 1, keepdims=True) + math_ops.squared_difference(predictions, label), 1, keepdims=True) return loss diff --git a/tensorflow/contrib/boosted_trees/python/utils/losses.py b/tensorflow/contrib/boosted_trees/python/utils/losses.py index 220e981618b7c0bfb1e4e98c087d83b451b9b3cf..1ad40aca2880940c78d746674c7378ff0427c057 100644 --- a/tensorflow/contrib/boosted_trees/python/utils/losses.py +++ b/tensorflow/contrib/boosted_trees/python/utils/losses.py @@ -166,7 +166,7 @@ def per_example_squared_loss(labels, weights, predictions): update_op: An update operation to update the loss's internal state. 
""" unweighted_loss = math_ops.reduce_sum( - math_ops.square(predictions - labels), 1, keepdims=True) + math_ops.squared_difference(predictions, labels), 1, keepdims=True) return unweighted_loss * weights, control_flow_ops.no_op() diff --git a/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h b/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h index 94aeb2c7bb48c6eddb6c7894f8bf6f1567470113..0fe57c0a4e8375cc7ec7aca9553bded87e238b33 100644 --- a/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h +++ b/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h @@ -34,7 +34,7 @@ class DecisionTreeEnsembleResource : public StampedResource { protobuf::Arena::CreateMessage< boosted_trees::trees::DecisionTreeEnsembleConfig>(&arena_)) {} - string DebugString() override { + string DebugString() const override { return strings::StrCat("GTFlowDecisionTreeEnsemble[size=", decision_tree_ensemble_->trees_size(), "]"); } diff --git a/tensorflow/contrib/boosted_trees/resources/quantile_stream_resource.h b/tensorflow/contrib/boosted_trees/resources/quantile_stream_resource.h index fdaaae7f472c8f564ab45a8366d3746cbf1158ee..574e3065e7f46049815897ef73e44d33f0d23f0f 100644 --- a/tensorflow/contrib/boosted_trees/resources/quantile_stream_resource.h +++ b/tensorflow/contrib/boosted_trees/resources/quantile_stream_resource.h @@ -43,7 +43,7 @@ class QuantileStreamResource : public StampedResource { set_stamp(stamp_token); } - string DebugString() override { return "QuantileStreamResource"; } + string DebugString() const override { return "QuantileStreamResource"; } tensorflow::mutex* mutex() { return &mu_; } diff --git a/tensorflow/contrib/checkpoint/__init__.py b/tensorflow/contrib/checkpoint/__init__.py index 94b7f4f867655bf7fdf94e8488eeae7088c41622..99ed4959fad9699f265183d71a1f3b609d7e6d30 100644 --- a/tensorflow/contrib/checkpoint/__init__.py +++ b/tensorflow/contrib/checkpoint/__init__.py 
@@ -51,11 +51,11 @@ from tensorflow.contrib.checkpoint.python.split_dependency import split_dependen from tensorflow.contrib.checkpoint.python.visualize import dot_graph_from_checkpoint from tensorflow.core.protobuf.checkpointable_object_graph_pb2 import CheckpointableObjectGraph from tensorflow.python.training.checkpoint_management import CheckpointManager -from tensorflow.python.training.checkpointable.base import CheckpointableBase +from tensorflow.python.training.checkpointable.base import Checkpointable as CheckpointableBase from tensorflow.python.training.checkpointable.data_structures import List from tensorflow.python.training.checkpointable.data_structures import Mapping from tensorflow.python.training.checkpointable.data_structures import NoDependency -from tensorflow.python.training.checkpointable.tracking import Checkpointable +from tensorflow.python.training.checkpointable.tracking import AutoCheckpointable as Checkpointable from tensorflow.python.training.checkpointable.util import capture_dependencies from tensorflow.python.training.checkpointable.util import list_objects from tensorflow.python.training.checkpointable.util import object_metadata diff --git a/tensorflow/contrib/checkpoint/python/BUILD b/tensorflow/contrib/checkpoint/python/BUILD index ada41687261ab63286933d01da4e286173042e0c..4e529322c7c76797938468b405cd175609dc0a73 100644 --- a/tensorflow/contrib/checkpoint/python/BUILD +++ b/tensorflow/contrib/checkpoint/python/BUILD @@ -2,7 +2,7 @@ licenses(["notice"]) # Apache 2.0 package(default_visibility = ["//tensorflow:internal"]) -load("//tensorflow:tensorflow.bzl", "py_test") +load("//tensorflow:tensorflow.bzl", "tf_py_test") py_library( name = "checkpoint", @@ -27,17 +27,17 @@ py_library( ], ) -py_test( +tf_py_test( name = "containers_test", srcs = ["containers_test.py"], - deps = [ + additional_deps = [ ":containers", + "@six_archive//:six", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_test_lib", 
"//tensorflow/python:resource_variable_ops", "//tensorflow/python/training/checkpointable:base", "//tensorflow/python/training/checkpointable:util", - "@six_archive//:six", ], ) @@ -53,18 +53,18 @@ py_library( ], ) -py_test( +tf_py_test( name = "python_state_test", srcs = ["python_state_test.py"], - deps = [ + additional_deps = [ ":python_state", + "//third_party/py/numpy", "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", "//tensorflow/python:session", "//tensorflow/python:variables", "//tensorflow/python/eager:test", "//tensorflow/python/training/checkpointable:util", - "//third_party/py/numpy", ], ) @@ -80,10 +80,10 @@ py_library( ], ) -py_test( +tf_py_test( name = "split_dependency_test", srcs = ["split_dependency_test.py"], - deps = [ + additional_deps = [ ":split_dependency", "//tensorflow/python:array_ops", "//tensorflow/python:framework_test_lib", @@ -106,10 +106,10 @@ py_library( ], ) -py_test( +tf_py_test( name = "visualize_test", srcs = ["visualize_test.py"], - deps = [ + additional_deps = [ ":visualize", "//tensorflow/python:constant_op", "//tensorflow/python:resource_variable_ops", diff --git a/tensorflow/contrib/checkpoint/python/containers.py b/tensorflow/contrib/checkpoint/python/containers.py index 5418e2605b724edb60878e250d2c50fcc6ff5633..97936d9e9dfd5d6e62fdf8312707a276b63e1267 100644 --- a/tensorflow/contrib/checkpoint/python/containers.py +++ b/tensorflow/contrib/checkpoint/python/containers.py @@ -63,7 +63,7 @@ class UniqueNameTracker(data_structures.CheckpointableDataStructure): ValueError: If `checkpointable` is not a checkpointable object. 
""" - if not isinstance(checkpointable, checkpointable_lib.CheckpointableBase): + if not isinstance(checkpointable, checkpointable_lib.Checkpointable): raise ValueError( ("Expected a checkpointable value, got %s which does not inherit " "from CheckpointableBase.") % (checkpointable,)) diff --git a/tensorflow/contrib/checkpoint/python/containers_test.py b/tensorflow/contrib/checkpoint/python/containers_test.py index ac85c7be803cd4c2f8ba19d3ef887a3c65a15933..a2d453ec6eb3dcf9aba4c52fe866756a92673c63 100644 --- a/tensorflow/contrib/checkpoint/python/containers_test.py +++ b/tensorflow/contrib/checkpoint/python/containers_test.py @@ -52,7 +52,7 @@ class UniqueNameTrackerTests(test.TestCase): save_root = util.Checkpoint(slots=slots) save_path = save_root.save(checkpoint_prefix) - restore_slots = tracking.Checkpointable() + restore_slots = tracking.AutoCheckpointable() restore_root = util.Checkpoint( slots=restore_slots) status = restore_root.restore(save_path) @@ -68,7 +68,7 @@ class UniqueNameTrackerTests(test.TestCase): @test_util.run_in_graph_and_eager_modes def testExample(self): - class SlotManager(tracking.Checkpointable): + class SlotManager(tracking.AutoCheckpointable): def __init__(self): self.slotdeps = containers.UniqueNameTracker() diff --git a/tensorflow/contrib/checkpoint/python/python_state.py b/tensorflow/contrib/checkpoint/python/python_state.py index 302d5cfb79a08b6adf52ebd44533152c5454eadc..969c90c78871ebff02b360f8f09623df56c9c077 100644 --- a/tensorflow/contrib/checkpoint/python/python_state.py +++ b/tensorflow/contrib/checkpoint/python/python_state.py @@ -34,7 +34,7 @@ except ImportError: # pylint: enable=g-import-not-at-top -class NumpyState(base.CheckpointableBase): +class NumpyState(base.Checkpointable): """A checkpointable object whose NumPy array attributes are saved/restored. 
Example usage: @@ -130,7 +130,7 @@ class NumpyState(base.CheckpointableBase): @six.add_metaclass(abc.ABCMeta) -class PythonStateWrapper(base.CheckpointableBase): +class PythonStateWrapper(base.Checkpointable): """Wraps a Python object for storage in an object-based checkpoint.""" @abc.abstractmethod diff --git a/tensorflow/contrib/checkpoint/python/split_dependency.py b/tensorflow/contrib/checkpoint/python/split_dependency.py index 7e77453f3d848c2e321ed2ba66917a742d95459a..3e9700ad74618e24843181d169f3fb39ac96bff6 100644 --- a/tensorflow/contrib/checkpoint/python/split_dependency.py +++ b/tensorflow/contrib/checkpoint/python/split_dependency.py @@ -43,7 +43,7 @@ class _CallbackSaveable(saver_lib.BaseSaverBuilder.SaveableObject): return self._restore_callback(tensor) -class _SplitDependency(checkpointable.CheckpointableBase): +class _SplitDependency(checkpointable.Checkpointable): """Looks like a regular variable while synchronizing save/restores.""" def __init__(self, save_buffer, restore_buffer, name, dtype, num_components, diff --git a/tensorflow/contrib/checkpoint/python/split_dependency_test.py b/tensorflow/contrib/checkpoint/python/split_dependency_test.py index 00a805af25d5d0ea723db5d015fb12bf45c53857..664a4e76ab31bf31c7a57924e4af866f2d746804 100644 --- a/tensorflow/contrib/checkpoint/python/split_dependency_test.py +++ b/tensorflow/contrib/checkpoint/python/split_dependency_test.py @@ -44,7 +44,7 @@ def _combine_variable_closure(variable): return _consume_restore_buffer_fn -class SaveTensorSlicesAsDeps(base.CheckpointableBase): +class SaveTensorSlicesAsDeps(base.Checkpointable): def __init__(self): self.combined = resource_variable_ops.ResourceVariable([0., 0., 0., 0.]) @@ -59,14 +59,14 @@ class SaveTensorSlicesAsDeps(base.CheckpointableBase): self._track_checkpointable(dep, name=name) -class HasRegularDeps(tracking.Checkpointable): +class HasRegularDeps(tracking.AutoCheckpointable): def __init__(self): self.first_half = 
resource_variable_ops.ResourceVariable([0., 0.]) self.second_half = resource_variable_ops.ResourceVariable([0., 0.]) -class OnlyOneDep(tracking.Checkpointable): +class OnlyOneDep(tracking.AutoCheckpointable): def __init__(self): self.first_half = resource_variable_ops.ResourceVariable([0., 0.]) diff --git a/tensorflow/contrib/cloud/kernels/BUILD b/tensorflow/contrib/cloud/kernels/BUILD index 1311063ec023bdaa2588d6f1c826bf900f7dea09..20f8c2b2453a58fdbe5a3587fa6687debd9c06d3 100644 --- a/tensorflow/contrib/cloud/kernels/BUILD +++ b/tensorflow/contrib/cloud/kernels/BUILD @@ -27,7 +27,6 @@ tf_kernel_library( deps = [ ":bigquery_table_accessor", ":bigquery_table_partition_proto_cc", - "//tensorflow/contrib/cloud:bigquery_reader_ops_op_lib", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:reader_base", @@ -79,7 +78,6 @@ tf_kernel_library( srcs = ["gcs_config_ops.cc"], visibility = ["//tensorflow:internal"], deps = [ - "//tensorflow/contrib/cloud:gcs_config_ops_op_lib", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core/platform/cloud:curl_http_request", diff --git a/tensorflow/contrib/cluster_resolver/__init__.py b/tensorflow/contrib/cluster_resolver/__init__.py index 390b3e7550b3d991269bb84707c3500f2fa33290..a4dea85efd98893c881abbd3f7ebda78755b8189 100644 --- a/tensorflow/contrib/cluster_resolver/__init__.py +++ b/tensorflow/contrib/cluster_resolver/__init__.py @@ -23,7 +23,7 @@ from __future__ import print_function from tensorflow.python.distribute.cluster_resolver.cluster_resolver import ClusterResolver from tensorflow.python.distribute.cluster_resolver.cluster_resolver import SimpleClusterResolver from tensorflow.python.distribute.cluster_resolver.cluster_resolver import UnionClusterResolver -from tensorflow.python.distribute.cluster_resolver.gce_cluster_resolver import GceClusterResolver +from tensorflow.python.distribute.cluster_resolver.gce_cluster_resolver import GCEClusterResolver from 
tensorflow.python.distribute.cluster_resolver.kubernetes_cluster_resolver import KubernetesClusterResolver from tensorflow.python.distribute.cluster_resolver.slurm_cluster_resolver import SlurmClusterResolver from tensorflow.python.distribute.cluster_resolver.tfconfig_cluster_resolver import TFConfigClusterResolver @@ -36,7 +36,7 @@ _allowed_symbols = [ 'ClusterResolver', 'SimpleClusterResolver', 'UnionClusterResolver', - 'GceClusterResolver', + 'GCEClusterResolver', 'KubernetesClusterResolver', 'TFConfigClusterResolver', 'TPUClusterResolver', diff --git a/tensorflow/contrib/cluster_resolver/python/training/__init__.py b/tensorflow/contrib/cluster_resolver/python/training/__init__.py index 10d93549ebbd4f7e900796d0516b0af1744224af..ef1e9f11a07a5be6c0b181f5e0b80e0e2214f972 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/__init__.py +++ b/tensorflow/contrib/cluster_resolver/python/training/__init__.py @@ -25,7 +25,7 @@ from __future__ import print_function from tensorflow.python.distribute.cluster_resolver.cluster_resolver import ClusterResolver from tensorflow.python.distribute.cluster_resolver.cluster_resolver import SimpleClusterResolver from tensorflow.python.distribute.cluster_resolver.cluster_resolver import UnionClusterResolver -from tensorflow.python.distribute.cluster_resolver.gce_cluster_resolver import GceClusterResolver +from tensorflow.python.distribute.cluster_resolver.gce_cluster_resolver import GCEClusterResolver from tensorflow.python.distribute.cluster_resolver.kubernetes_cluster_resolver import KubernetesClusterResolver from tensorflow.python.distribute.cluster_resolver.slurm_cluster_resolver import SlurmClusterResolver from tensorflow.python.distribute.cluster_resolver.tfconfig_cluster_resolver import TFConfigClusterResolver @@ -43,7 +43,7 @@ _allowed_symbols = [ 'ClusterResolver', 'SimpleClusterResolver', 'UnionClusterResolver', - 'GceClusterResolver', + 'GCEClusterResolver', 'KubernetesClusterResolver', 'TFConfigClusterResolver', 
'TPUClusterResolver', diff --git a/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver.py index 55e61155c683c928efab9bb018868faec3e3df8c..5b49116ff6a4e17a774ea79b33ae1b948ba9f187 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Stub file for GceClusterResolver to maintain backwards compatibility.""" +"""Stub file for GCEClusterResolver to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division @@ -23,13 +23,14 @@ from __future__ import print_function # existing OSS code will not be broken. # pylint: disable=unused-import -from tensorflow.python.distribute.cluster_resolver.gce_cluster_resolver import GceClusterResolver +from tensorflow.python.distribute.cluster_resolver.gce_cluster_resolver import GCEClusterResolver # pylint: enable=unused-import from tensorflow.python.util.all_util import remove_undocumented + _allowed_symbols = [ - 'GceClusterResolver', + 'GCEClusterResolver', ] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/cmake/CMakeLists.txt b/tensorflow/contrib/cmake/CMakeLists.txt index a63366e1361effe20787c197eddd66b5c0c96410..2ad9ae42a16f690d38b8e2652e853012ec1dd267 100644 --- a/tensorflow/contrib/cmake/CMakeLists.txt +++ b/tensorflow/contrib/cmake/CMakeLists.txt @@ -3,16 +3,16 @@ cmake_minimum_required(VERSION 3.5) if(WIN32) if(${CMAKE_VERSION} VERSION_LESS "3.8") - message(WARNING "Your current cmake version is ${CMAKE_VERSION} which does not support setting the toolset architecture to x64. 
This may cause \"compiler out of heap space\" errors when building. Consider upgrading your cmake to > 3.8 and using the flag -Thost=x64 when running cmake.") + message(WARNING "Your current cmake version is ${CMAKE_VERSION} which does not support setting the toolset architecture to x64. This may cause \"compiler out of heap space\" errors when building. Consider upgrading your cmake to > 3.8 and using the flag -Thost=x64 when running cmake. Ignore this if you are on CMake GUI.") else() if(NOT CMAKE_VS_PLATFORM_TOOLSET_HOST_ARCHITECTURE OR NOT "${CMAKE_VS_PLATFORM_TOOLSET_HOST_ARCHITECTURE}" STREQUAL "x64") - message(WARNING "Your current cmake generator is set to use 32 bit toolset architecture. This may cause \"compiler out of heap space\" errors when building. Consider using the flag -Thost=x64 when running cmake.") + message(WARNING "Your current cmake generator is set to use 32 bit toolset architecture. This may cause \"compiler out of heap space\" errors when building. Consider using the flag -Thost=x64 when running cmake. Ignore this if you are on CMake GUI.") endif() endif() endif() # Project -project(tensorflow C CXX) +project(tensorflow VERSION 1.12.0 LANGUAGES C CXX) # Set C++14 as standard for the whole project set(CMAKE_CXX_STANDARD 14) @@ -52,11 +52,17 @@ option(tensorflow_OPTIMIZE_FOR_NATIVE_ARCH "Enable compiler optimizations for th option(tensorflow_ENABLE_SNAPPY_SUPPORT "Enable SNAPPY compression support" ON) option(tensorflow_DISABLE_EIGEN_FORCEINLINE "Disable forceinline, to speed up build on windows." 
OFF) +if (WIN32) +SET(tensorflow_WIN_CPU_SIMD_OPTIONS "/arch:AVX" CACHE STRING "Enables CPU SIMD instructions") +SET_PROPERTY(CACHE tensorflow_WIN_CPU_SIMD_OPTIONS PROPERTY STRINGS /arch:AVX) +endif() + # SIMD, MKL and MKLDNN options option(tensorflow_WIN_CPU_SIMD_OPTIONS "Enables CPU SIMD instructions" OFF) option(tensorflow_ENABLE_MKL_SUPPORT "Enable Intel MKL support" OFF) option(tensorflow_ENABLE_MKLDNN_SUPPORT "Enable Intel MKLDNN support, requires MKL enabled" OFF) + # GPU, CUDA and cuDNN options option(tensorflow_ENABLE_GPU "Enable GPU support" OFF) @@ -79,6 +85,11 @@ if (NOT WIN32) # option's default value is OFF. Fill it with real default values set(tensorflow_CUDNN_INCLUDE /usr/include) endif (NOT tensorflow_CUDNN_INCLUDE) + option(tensorflow_NCCL_INCLUDE "nccl.h header install path" /usr/include/) + if (NOT tensorflow_NCCL_INCLUDE) + # option's default value is OFF. Fill it with real default values + set(tensorflow_NCCL_INCLUDE /usr/include) + endif (NOT tensorflow_NCCL_INCLUDE) option(tensorflow_PATH_CUDNN_LIB "Override PATH_CUDA_LIB for cudnn" ${tensorflow_PATH_CUDA_LIB}) if (NOT tensorflow_PATH_CUDNN_LIB) # option's default value is OFF. 
Fill it with real default values @@ -193,6 +204,7 @@ if(WIN32) set(CMAKE_SUPPRESS_REGENERATION ON) endif() + if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-exceptions -std=c++11") endif() @@ -281,6 +293,14 @@ else (systemlib_ZLIB) ${zlib_STATIC_LIBRARIES}) endif (systemlib_ZLIB) +if (systemlib_ABSEIL_CPP) + set(tensorflow_EXTERNAL_LIBRARIES ${tensorflow_EXTERNAL_LIBRARIES} + ${abseil_cpp_LIBRARIES}) +else (systemlib_ABSEIL_CPP) + set(tensorflow_EXTERNAL_LIBRARIES ${tensorflow_EXTERNAL_LIBRARIES} + ${abseil_cpp_STATIC_LIBRARIES}) +endif (systemlib_ABSEIL_CPP) + set(tensorflow_EXTERNAL_DEPENDENCIES zlib_copy_headers_to_destination gif_copy_headers_to_destination @@ -378,8 +398,8 @@ if (tensorflow_ENABLE_GPU) list(APPEND CMAKE_LIBRARY_PATH "${tensorflow_CUDA_LIBRARY_PATH}/stubs") endif (NOT WIN32) - # minimum 9.1 in cuda version - find_package(CUDA 9.1 REQUIRED) + # minimum 9.0 in cuda version + find_package(CUDA 9.0 REQUIRED) if(NOT CUDA_FOUND) message(FATAL_ERROR "CUDA not found.") endif() @@ -394,6 +414,7 @@ if (tensorflow_ENABLE_GPU) set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};--include-path ${PROJECT_BINARY_DIR}/$\{build_configuration\};--expt-relaxed-constexpr) set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-ftz=true) # Flush denormals to zero set(CUDA_INCLUDE ${CUDA_TOOLKIT_TARGET_DIR} ${CUDA_TOOLKIT_TARGET_DIR}/extras/CUPTI/include) + include_directories(${CUDA_INCLUDE}) if (WIN32) add_definitions(-DGOOGLE_CUDA=1 -DTF_EXTRA_CUDA_CAPABILITIES=3.7,5.2,6.0,6.1,7.0) @@ -546,14 +567,20 @@ if (tensorflow_ENABLE_GPU) cudnn_version_number=${tensorflow_CUDNN_VERSION}) endif(WIN32) else(tensorflow_ENABLE_GPU) - set(tensorflow_BUILD_INFO_FLAGS --build_config cpu --key_value - msvcp_dll_name=msvcp140.dll) + if(WIN32) + set(tensorflow_BUILD_INFO_FLAGS --build_config cpu --key_value + msvcp_dll_name=msvcp140.dll) + else() + set(tensorflow_BUILD_INFO_FLAGS --build_config cpu) + endif() endif(tensorflow_ENABLE_GPU) -# Find python executable 
-include(FindPythonInterp) -if(NOT ${PYTHONINTERP_FOUND}) - message(FATAL_ERROR "CMake was unable to find a python interpreter.") +if(tensorflow_BUILD_PYTHON_BINDINGS) + # Find python executable + include(FindPythonInterp) + if(NOT ${PYTHONINTERP_FOUND}) + message(FATAL_ERROR "CMake was unable to find a python interpreter.") + endif() endif() # Let's get to work! @@ -574,6 +601,7 @@ include(tf_cc_ops.cmake) include(tf_c.cmake) include(tf_grappler.cmake) include(tf_core_profiler.cmake) +include(tf_core_eager_runtime.cmake) if(tensorflow_BUILD_CC_EXAMPLE) include(tf_tutorials.cmake) include(tf_label_image_example.cmake) @@ -587,4 +615,4 @@ if(tensorflow_BUILD_SHARED_LIB) endif() if(tensorflow_BUILD_CC_TESTS OR tensorflow_BUILD_PYTHON_TESTS) include(tf_tests.cmake) -endif() +endif() \ No newline at end of file diff --git a/tensorflow/contrib/cmake/README.md b/tensorflow/contrib/cmake/README.md index 84c679162c3ed8ffc9babcd3af583b26fb62c2d6..60ee1b4b3fd7d0b6afaefcc05effd3bbae00cf2c 100644 --- a/tensorflow/contrib/cmake/README.md +++ b/tensorflow/contrib/cmake/README.md @@ -5,10 +5,10 @@ CMAKE build is deprecated for TensorFlow. Please use `bazel` to build TF for all platforms. For details, see the [TensorFlow install guide](https://www.tensorflow.org/install/). -This directory contains CMake files for building TensorFlow on Microsoft -Windows. [CMake](https://cmake.org) is a cross-platform tool that can -generate build scripts for multiple build systems, including Microsoft -Visual Studio. +This directory contains CMake files for building TensorFlow on Microsoft Windows +and Linux. [CMake](https://cmake.org) is a cross-platform tool that can generate +build scripts for multiple build systems, including Microsoft Visual Studio and +GCC. The method has not been tested on Mac OS X. **N.B.** We provide Linux build instructions primarily for the purpose of testing the build. We recommend using the standard Bazel-based build on @@ -17,12 +17,17 @@ Linux. 
Current Status -------------- -CMake can be used to build TensorFlow on Windows. See the [getting started documentation](https://www.tensorflow.org/install/source_windows) -for instructions on how to install a pre-built TensorFlow package on Windows. +CMake can be used to build TensorFlow on all platforms. See the +[getting started documentation](https://www.tensorflow.org/install/install_windows) +for instructions on how to install a pre-built TensorFlow package on Windows and +Linux. The procedure in MacOS is similar to the Linux build. ### Current known limitations -* It is not possible to load a custom Op library. -* GCS file system is not supported. + +* It is not possible to load a custom Op library. +* GCS file system is not supported. +* Debug build is not available since Python for Windows is no longer + distributed with a debug library. ## Building with CMake @@ -32,70 +37,88 @@ bindings. ### Prerequisites -* CMake version 3.5 or later. +* CMake version 3.5 or later. + +* [Git](https://git-scm.com) + +* [SWIG](http://www.swig.org/download.html) + +* [Perl](https://www.perl.org/get.html) (optional, for SSL support build) + +* [Go](https://golang.org/) (optional, for SSL support build) + +* [NASM](http://www.nasm.us/)/[YASM](http://yasm.tortall.net/) (optional, for + SSL support build) + +* Additional pre-requisites for Microsoft Windows: + + - Visual Studio 2015 (latest version of MSVC 2017 is not supported by CUDA + yet, try it on your own risk) -* [Git](https://git-scm.com) + - Python 3.5 -* [SWIG](http://www.swig.org/download.html) +* Additional prerequisites for Linux: -* Additional prerequisites for Microsoft Windows: - - Visual Studio 2015 - - Python 3.5 + - Python 2.7 or later + - [Docker](https://www.docker.com/) (for automated testing) -* Additional prerequisites for Linux: - - Python 2.7 or later - - [Docker](https://www.docker.com/) (for automated testing) +* Python dependencies: -* Python dependencies: - - wheel - - NumPy 1.11.0 or later + - 
wheel + - NumPy 1.11.0 or later ### Known-good configurations -* Microsoft Windows 10 - - Microsoft Visual Studio Enterprise 2015 with Visual C++ 2015 - - [Anaconda 4.1.1 (Python 3.5 64-bit)](https://www.anaconda.com/download/) - - [Git for Windows version 2.9.2.windows.1](https://git-scm.com/download/win) - - [swigwin-3.0.10](http://www.swig.org/download.html) - - [NVidia CUDA Toolkit 8.0](https://developer.nvidia.com/cuda-downloads) - - [NVidia CUDNN 5.1](https://developer.nvidia.com/cudnn) - - [CMake 3.6](https://cmake.org/files/v3.6/cmake-3.6.3-win64-x64.msi) +* Microsoft Windows 10 -* Ubuntu 14.04 - - Makefile generator - - Docker 1.9.1 (for automated testing) + - Microsoft Visual Studio Enterprise/ Community 2015 with Visual C++ 2015 + - [Anaconda 4.1.1 (Python 3.5 64-bit)](https://www.anaconda.com/download/) + - [Git for Windows version 2.9.2.windows.1](https://git-scm.com/download/win) + - [swigwin-3.0.10](http://www.swig.org/download.html) + - [NVidia CUDA Toolkit 9.0](https://developer.nvidia.com/cuda-downloads) + - [NVidia CUDNN 7](https://developer.nvidia.com/cudnn) + - [CMake 3.6](https://cmake.org/files/v3.6/cmake-3.6.3-win64-x64.msi) + +* Ubuntu 14.04 + + - Makefile generator + - Docker 1.9.1 (for automated testing) ### Current known limitations - - The Python package supports **Python 3.5 only**, because that is the only - version for which standard Python binaries exist and those binaries are - compatible with the TensorFlow runtime. (On Windows, the standard Python + +- The Python package supports **Python 3.5/3.6 only**, because these are the + only versions for which standard Python binaries exist and those binaries + are compatible with the TensorFlow runtime. (On Windows, the standard Python binaries for versions earlier than 3.5 were compiled with older compilers that do not have all of the features (e.g. C++11 support) needed to compile - TensorFlow. 
We welcome patches for making TensorFlow work with Python 2.7 - on Windows, but have not yet committed to supporting that configuration.) - - - The following Python APIs are not currently implemented: - * Loading custom op libraries via `tf.load_op_library()`. In order to use your - custom op, please put the source code under the tensorflow/core/user_ops - directory, and a shape function is required (not optional) for each op. - * Path manipulation functions (such as `tf.gfile.ListDirectory()`) are not - functional. - - - The `tf.contrib` libraries are not currently included in the PIP package. - - - The following operations are not currently implemented: - * `DepthwiseConv2dNative` - * `Digamma` - * `Erf` - * `Erfc` - * `Igamma` - * `Igammac` - * `ImmutableConst` - * `Lgamma` - * `Polygamma` - * `Zeta` - - - Google Cloud Storage support is not currently implemented. The GCS library + TensorFlow. We welcome patches for making TensorFlow work with Python 2.7 on + Windows, but have not yet committed to supporting that configuration.) + +- The following Python APIs are not currently implemented: + + * Loading custom op libraries via `tf.load_op_library()`. In order to use + your custom op, please put the source code under the + tensorflow/core/user_ops directory, and a shape function is required + (not optional) for each op. + * Path manipulation functions (such as `tf.gfile.ListDirectory()`) are not + functional. + +- The `tf.contrib` libraries are not currently included in the PIP package. + +- The following operations are not currently implemented: + + * `DepthwiseConv2dNative` + * `Digamma` + * `Erf` + * `Erfc` + * `Igamma` + * `Igammac` + * `ImmutableConst` + * `Lgamma` + * `Polygamma` + * `Zeta` + +- Google Cloud Storage support is not currently implemented. The GCS library currently depends on `libcurl` and `boringssl`, and the Windows version could use standard Windows APIs for making HTTP requests and cryptography (for OAuth). 
Contributions are welcome for this feature. @@ -104,9 +127,211 @@ We are actively working on improving CMake and Windows support, and addressing these limitations. We would appreciate pull requests that implement missing ops or APIs. +# CMake GUI build (all platforms) + +Building from the CMake GUI is a convenient way to generate C++ build projects. +The software supports Windows, MacOS and Linux, while POSIX platforms +provide an extra ccmake binary to run a command-line GUI. The working principle +of cmake, ccmake and cmake-gui is the same; the only difference is the +interface provided for project configuration and dependency setting. + +1. Pre-build checklist: The following binaries/libraries should be set in the + system path, otherwise you need to set them manually via cmake. + * Compiler (GCC for Linux, MSVC for Windows) + * Make sure compiler directory has been set to system path + * CUDA 9.0 (GPU build) + * CUDNN (GPU build) + * NCCL (GPU build on Linux) + * SWIG (python binding) + * Perl (required if you need ssl support, optional) + * Go (required if you need ssl support, optional) + * NASM/YASM (required by grpc for ssl support, optional) +2. Start CMake GUI +3. Click on `Browse Source` and direct to the folder + `/tensorflow/contrib/cmake` +4. Click on `Browse Build` and specify a location where you want tensorflow to + be built +5. Click on `Configure`; a new window will pop up. Specify the + generator mode for the project generation. For Windows, choose `Visual + Studio Win64`, for Linux, choose `Unix Makefiles`, then + press `Finish`. Wait for a moment; the default project dependencies will be + generated automatically. +6. There are a few options with which you can customize your own build. **The setting + here is crucial for a successful build, please check all items carefully.** + + * `tensorflow_BUILD_ALL_KERNELS` should always be `on` + * `tensorflow_BUILD_CC_EXAMPLE` is default to be `on`. 
This can help you + to test build (optional) + * `tensorflow_BUILD_CONTRIB_KERNELS` is default to be `on`, but it won't + affect tensorflow function, turn it to `off` if you want a slim build. + (optional) + * `tensorflow_BUILD_PYTHON_BINDINGS` is default to be `on`. Set to `off` if + you don't need python interface. If SWIG is not in system path, you + need to set it manually. (optional) + * `tensorflow_BUILD_SHARED_LIB` is default to be `off`. Set to `on` if you + want the c++ interface. (optional) + * `tensorflow_ENABLE_GPU` is default to be `off`. Set to `on` if you want + GPU support. It will search CUDA and CUDNN dependencies if you have set + them to system path, otherwise CMake would report an error and request you + to set them manually. (optional) + * `tensorflow_ENABLE_GRPC_SUPPORT` is default to be `on`. For Linux build, + this option must always be `on`. This needs to be `on` for a gpu build. + Note that Perl, Go and NASM/YASM are required for this option if you + want to build grpc with official SSL support. + * `tensorflow_ENABLE_POSITION_INDEPENDENT_CODE` should always be `on` + * `tensorflow_ENABLE_SNAPPY_SUPPORT` should always be `on` + * `tensorflow_OPTIMIZE_FOR_NATIVE_ARCH` should always be `on` + * `CMAKE_INSTALL_PREFIX` is the location where the final package will be + installed. You may change it to your own preferred path (optional) + +7. After changing the configuration in step 6, press `Configure` again + +8. If no error is found, press `Generate` + +#### Windows + +1. Open `tensorflow.sln` in the build folder (Windows). Change build type from + `Debug` to `Release`. Choose `Build`->`Build Solution`. This may take more + than hours of compilation. If everything is alright, the output window would + show no error. + + ##### Python + + In solution explorer, right click on `tf_python_build_pip_package` -> + `build`. It will generate the wheel file in + `/tf_python/dist`. 
Install with following command: + + `pip install --upgrade tensorflow-.whl` + + ***The wheel name varies depends on you config. Change to your own wheel + filename.*** + + Reminded that some pip installation requires administrator right command + prompt. + + ##### C++ + + You can directly use the build folder tree for C++ interface with cmake. If + you want to do installation for api releasing, right click on `Install` -> + `build`. The headers and library will be installed in the directory specify + by `CMAKE_INSTALL_PREFIX` during configuration. + +1. For smaller RAM computer, it is noticed that out of heap space error + appears. Change to command prompt build is an alternative to do step 1. + + Open `VS2015 x64 Native Tools Command Prompt`. You can open it by press + `Start`, then type the binary name. Use `VS2017 x64 Native Tools Command + Prompt` if you are using MSVC 2017. + + ##### Python + + Directly build python wheel package by following command: + + `MSBuild /p:Configuration=Release + ` + + Remember to change `` to the + actual path of the file, it can be found at the root of build directory + + Install the wheel file generated as instructed by step 1. + + ##### C++ interface + + Build from VS native toolchain with following command: `MSBuild + /p:Configuration=Release ` + + Headers are discretely located in the build folders. Tensorflow library can + be found at `/Release`, namely `tensorflow.dll` and + `tensorflow.lib`. + + * Build to install for api release (optional): `MSBuild + /p:Configuration=Release ` + + Remember to change `` and + `` to the actual path of the file, it can be found + at the root of build directory. + +#### Linux/MacOS (command line GNU build) + +1. Open the terminal, change working directory to the one specified in step 3. + +2. Type the following command: + + `make -sj all` + + ##### Python + + **Important Note** CMake generated python wheel for Linux/MacOs is currently + under development. Please use bazel build. 
+ + Follow code is an expected Linux/MacOS python package build after + development work is completed. + + ``` + make -sj tf_python_build_pip_package + cd tf_python + pip install --upgrade tensorflow-.whl + ``` + + ##### C++ interface + + `make -sj install` + + Where `` is the threads used for the compilation, change + to any integer less or equal to your computer's maximum thread number. + + Headers are discretely located in the build folders. Tensorflow library can + be found at ``, namely `tensorflow.so` (Linux) or + `tensorflow.dylib` (MacOS). + +#### Start a Tensorflow C++ project with CMake + +Here we assume that you have basic knowledge on gathering dependency with +`CMakeLists.txt`. Here we introduce how the C++ api works with +[official hello world tutorial](https://www.tensorflow.org/api_guides/cc/guide). + +1. Create a new working directory and create a new text file named + `CMakeLists.txt` and the c++ file `main.cxx` +2. Fill in the `main.cxx` with the code provided in + [official c++ api basic](https://www.tensorflow.org/api_guides/cc/guide). +3. Fill in the `CMakeLists.txt` with following code: ``` cmake + cmake_minimum_required (VERSION 2.6) project (tf_hello) + + # Tensorflow + + find_package(Tensorflow REQUIRED) + include_directories(${TENSORFLOW_INCLUDE_DIRS}) + + # compiler setting required by tensorflow, to be tested on all compilers + + # currently only tested on MSVC and GCC + + if (${CMAKE_CXX_COMPILER_ID} STREQUAL MSVC) add_definitions(-DCOMPILER_MSVC) + elseif (${CMAKE_CXX_COMPILER_ID} STREQUAL GNU) if + (${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS "3") + add_definitions(-DCOMPILER_GCC3) else() add_definitions(-D__GNUC__) endif() + else() message(ERROR " compiler ${CMAKE_CXX_COMPILER_ID} not supported by + this CMakeList.txt, under development") endif() + + add_executable(tf_hello main.cxx) target_link_libraries(tf_hello + ${TENSORFLOW_LIBRARIES}) ``` + +4. 
Configure the folder with cmake-gui; an error should be reported, + requesting you to locate the folder containing `TensorflowConfig.cmake`. + This file can be found at `` or `` (for + those who ran the install build in previous steps). + +5. Configure again, generate the project. + +6. Compile the project with `Release` config (Windows). For Linux users, just + compile the project. + +7. Copy the `tensorflow.dll`(Windows)/`tensorflow.so`(Linux) from build + directory to the build folder containing `tf_hello` binary. + +8. Run `tf_hello` binary -Step-by-step Windows build -========================== +# Step-by-step Windows build (command prompt) 1. Install the prerequisites detailed above, and set up your environment. diff --git a/tensorflow/contrib/cmake/TensorflowConfig.cmake.in b/tensorflow/contrib/cmake/TensorflowConfig.cmake.in new file mode 100644 index 0000000000000000000000000000000000000000..cc04db6e952f53b8bb5416dde60b8173e60bf60e --- /dev/null +++ b/tensorflow/contrib/cmake/TensorflowConfig.cmake.in @@ -0,0 +1,16 @@ +# - Config file for the Tensorflow package +# It defines the following variables +# TENSORFLOW_INCLUDE_DIRS - include directories for Tensorflow +# TENSORFLOW_LIBRARIES - libraries to link against + +# Compute paths +get_filename_component(TENSORFLOW_CMAKE_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH) +set(TENSORFLOW_INCLUDE_DIRS "@CONF_INCLUDE_DIRS@") + +# Our library dependencies (contains definitions for IMPORTED targets) +if(NOT TENSORFLOW_BINARY_DIR) + include("${TENSORFLOW_CMAKE_DIR}/TensorflowTargets.cmake") +endif() + +# These are IMPORTED targets created by TensorflowTargets.cmake +set(TENSORFLOW_LIBRARIES tensorflow) \ No newline at end of file diff --git a/tensorflow/contrib/cmake/TensorflowConfigVersion.cmake.in b/tensorflow/contrib/cmake/TensorflowConfigVersion.cmake.in new file mode 100644 index 0000000000000000000000000000000000000000..2a9609ddb9c4ca864651818bdfae0f8fe290de31 --- /dev/null +++ 
b/tensorflow/contrib/cmake/TensorflowConfigVersion.cmake.in @@ -0,0 +1,11 @@ +set(PACKAGE_VERSION "@TENSORFLOW_VERSION@") + +# Check whether the requested PACKAGE_FIND_VERSION is compatible +if("${PACKAGE_VERSION}" VERSION_LESS "${PACKAGE_FIND_VERSION}") + set(PACKAGE_VERSION_COMPATIBLE FALSE) +else() + set(PACKAGE_VERSION_COMPATIBLE TRUE) + if ("${PACKAGE_VERSION}" VERSION_EQUAL "${PACKAGE_FIND_VERSION}") + set(PACKAGE_VERSION_EXACT TRUE) + endif() +endif() \ No newline at end of file diff --git a/tensorflow/contrib/cmake/external/abseil_cpp.cmake b/tensorflow/contrib/cmake/external/abseil_cpp.cmake index 4546dbdecc0dbc36f17cc727345e0762718b5165..6c6a5df7f76723800740a81ccdcb137a0ec33846 100644 --- a/tensorflow/contrib/cmake/external/abseil_cpp.cmake +++ b/tensorflow/contrib/cmake/external/abseil_cpp.cmake @@ -39,21 +39,21 @@ else (systemlib_ABSEIL_CPP) include (ExternalProject) set(abseil_cpp_INCLUDE_DIR ${CMAKE_BINARY_DIR}/abseil_cpp/src/abseil_cpp_build) - set(abseil_cpp_URL https://github.com/abseil/abseil-cpp/archive/e01d95528ea2137a4a27a88d1f57c6cb260aafed.tar.gz) - set(abseil_cpp_HASH SHA256=84043ed402d2a2a6ba4cdddb7e85118b1158fd81fe4ac3a14adc343d054c1e2e) + set(abseil_cpp_URL https://github.com/abseil/abseil-cpp.git) + set(abseil_cpp_TAG master) set(abseil_cpp_BUILD ${CMAKE_BINARY_DIR}/abseil_cpp/src/abseil_cpp_build) if(WIN32) if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") set(abseil_cpp_STATIC_LIBRARIES ${abseil_cpp_BUILD}/absl/base/Release/absl_base.lib - ${abseil_cpp_BUILD}/absl/base/Release/absl_spinlock_wait.lib ${abseil_cpp_BUILD}/absl/base/Release/absl_dynamic_annotations.lib - ${abseil_cpp_BUILD}/absl/base/Release/absl_malloc_internal.lib - ${abseil_cpp_BUILD}/absl/base/Release/absl_throw_delegate.lib + ${abseil_cpp_BUILD}/absl/base/Release/absl_internal_malloc_internal.lib + ${abseil_cpp_BUILD}/absl/base/Release/absl_internal_throw_delegate.lib ${abseil_cpp_BUILD}/absl/numeric/Release/absl_int128.lib 
${abseil_cpp_BUILD}/absl/strings/Release/absl_strings.lib ${abseil_cpp_BUILD}/absl/strings/Release/str_format_internal.lib + ${abseil_cpp_BUILD}/absl/time/Release/absl_time.lib ${abseil_cpp_BUILD}/absl/types/Release/absl_bad_optional_access.lib) else() set(abseil_cpp_STATIC_LIBRARIES @@ -65,6 +65,7 @@ else (systemlib_ABSEIL_CPP) ${abseil_cpp_BUILD}/absl/numeric/absl_int128.lib ${abseil_cpp_BUILD}/absl/strings/absl_strings.lib ${abseil_cpp_BUILD}/absl/strings/str_format_internal.lib + ${abseil_cpp_BUILD}/absl/time/absl_time.lib ${abseil_cpp_BUILD}/absl/types/absl_bad_optional_access.lib) endif() else() @@ -77,13 +78,13 @@ else (systemlib_ABSEIL_CPP) ${abseil_cpp_BUILD}/absl/numeric/libabsl_int128.a ${abseil_cpp_BUILD}/absl/strings/libabsl_strings.a ${abseil_cpp_BUILD}/absl/strings/libstr_format_internal.a + ${abseil_cpp_BUILD}/absl/time/libabsl_time.a ${abseil_cpp_BUILD}/absl/types/libabsl_bad_optional_access.a) endif() ExternalProject_Add(abseil_cpp_build PREFIX abseil_cpp - URL ${abseil_cpp_URL} - URL_HASH ${abseil_cpp_HASH} + GIT_REPOSITORY ${abseil_cpp_URL} DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" BUILD_IN_SOURCE 1 BUILD_BYPRODUCTS ${abseil_cpp_STATIC_LIBRARIES} @@ -97,8 +98,10 @@ else (systemlib_ABSEIL_CPP) ) include_directories(${abseil_cpp_INCLUDE_DIR}) + message(STATUS ${abseil_cpp_INCLUDE_DIR}) + list(APPEND tensorflow_EXTERNAL_LIBRARIES ${abseil_cpp_STATIC_LIBRARIES}) list(APPEND tensorflow_EXTERNAL_DEPENDENCIES abseil_cpp_build) -endif (systemlib_ABSEIL_CPP) +endif (systemlib_ABSEIL_CPP) \ No newline at end of file diff --git a/tensorflow/contrib/cmake/external/grpc.cmake b/tensorflow/contrib/cmake/external/grpc.cmake index b1e64aa55c80ad59cfdc0f4767c0282b4f73367f..e570c09ecb5e64130ed6f3375a51d74850cc3989 100644 --- a/tensorflow/contrib/cmake/external/grpc.cmake +++ b/tensorflow/contrib/cmake/external/grpc.cmake @@ -17,7 +17,7 @@ include (ExternalProject) set(GRPC_INCLUDE_DIRS ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/include) set(GRPC_URL 
https://github.com/grpc/grpc.git) set(GRPC_BUILD ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc) -set(GRPC_TAG d184fa229d75d336aedea0041bd59cb93e7e267f) +set(GRPC_TAG 69b6c047bc767b4d80e7af4d00ccb7c45b683dae) if(WIN32) # We use unsecure gRPC because boringssl does not build on windows @@ -26,9 +26,9 @@ if(WIN32) set(grpc_SSL_PROVIDER NONE) if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") set(grpc_STATIC_LIBRARIES - ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/Release/grpc++_unsecure.lib - ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/Release/grpc_unsecure.lib - ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/Release/gpr.lib) + ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/$(Configuration)/grpc++_unsecure.lib + ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/$(Configuration)/grpc_unsecure.lib + ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/$(Configuration)/gpr.lib) else() set(grpc_STATIC_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/grpc++_unsecure.lib @@ -43,8 +43,9 @@ else() ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgrpc++.a ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgrpc.a ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libaddress_sorting.a + ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgpr.a ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/third_party/cares/cares/lib/libcares.a - ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgpr.a) + ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/third_party/zlib/libz.a) endif() add_definitions(-DGRPC_ARES=0) @@ -66,7 +67,7 @@ ExternalProject_Add(grpc -DPROTOBUF_INCLUDE_DIRS:STRING=${PROTOBUF_INCLUDE_DIRS} -DPROTOBUF_LIBRARIES:STRING=${protobuf_STATIC_LIBRARIES} -DZLIB_ROOT:STRING=${ZLIB_INSTALL} - -DgRPC_SSL_PROVIDER:STRING=${grpc_SSL_PROVIDER} + -DgRPC_SSL_PROVIDER:STRING=${grpc_SSL_PROVIDER} ) # grpc/src/core/ext/census/tracing.c depends on the existence of openssl/rand.h. 
diff --git a/tensorflow/contrib/cmake/external/nsync.cmake b/tensorflow/contrib/cmake/external/nsync.cmake index 479609458c64f7c7bd7b3ce6b23aceaa3db17f21..b15143bfc1cd787b156c9d6dd724a17730f0f8fb 100644 --- a/tensorflow/contrib/cmake/external/nsync.cmake +++ b/tensorflow/contrib/cmake/external/nsync.cmake @@ -16,7 +16,7 @@ include (ExternalProject) set(nsync_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/nsync/public) set(nsync_URL https://github.com/google/nsync) -set(nsync_TAG 1.20.1) +set(nsync_TAG 1.20.2) set(nsync_BUILD ${CMAKE_CURRENT_BINARY_DIR}/nsync/src/nsync) set(nsync_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/nsync/install) diff --git a/tensorflow/contrib/cmake/external/png.cmake b/tensorflow/contrib/cmake/external/png.cmake index 1a147e9c8e5a9fee17a81e37c9babe3c9ec0290b..32e6d78e508e25f76bd263e9d52b6574ca315f6c 100644 --- a/tensorflow/contrib/cmake/external/png.cmake +++ b/tensorflow/contrib/cmake/external/png.cmake @@ -59,6 +59,7 @@ ExternalProject_Add(png -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF -DCMAKE_INSTALL_PREFIX:STRING=${png_INSTALL} -DZLIB_ROOT:STRING=${ZLIB_INSTALL} + -DPNG_TESTS:BOOL=OFF ) ## put png includes in the directory where they are expected diff --git a/tensorflow/contrib/cmake/external/protobuf.cmake b/tensorflow/contrib/cmake/external/protobuf.cmake index 56a57a2340ddc7f923c611c222a0399e279ad58a..773c37b309b1dff4ed28d24cd7d6140a63ec5bc6 100644 --- a/tensorflow/contrib/cmake/external/protobuf.cmake +++ b/tensorflow/contrib/cmake/external/protobuf.cmake @@ -16,7 +16,18 @@ include (ExternalProject) set(PROTOBUF_INCLUDE_DIRS ${CMAKE_CURRENT_BINARY_DIR}/protobuf/src/protobuf/src) set(PROTOBUF_URL https://github.com/google/protobuf.git) -set(PROTOBUF_TAG v3.6.1) + +# Selectable protobuf version (cache option); each listed version maps to a git tag below +SET(PROTOBUF_VERSION "3.6.1" CACHE STRING "Protobuf version") +SET_PROPERTY(CACHE PROTOBUF_VERSION PROPERTY STRINGS "3.4.0" "3.5.0" "3.6.1") + +if(${PROTOBUF_VERSION} STREQUAL "3.6.1") + set(PROTOBUF_TAG v3.6.1) +elseif(${PROTOBUF_VERSION} STREQUAL 
"3.5.0") + set(PROTOBUF_TAG 2761122b810fe8861004ae785cc3ab39f384d342) +elseif(${PROTOBUF_VERSION} STREQUAL "3.4.0") + set(PROTOBUF_TAG b04e5cba356212e4e8c66c61bbe0c3a20537c5b9) +endif() if(WIN32) if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") diff --git a/tensorflow/contrib/cmake/modules/FindAbseilCpp.cmake b/tensorflow/contrib/cmake/modules/FindAbseilCpp.cmake index d4f8bb1bec9ae8eff58dfe78168d8e71319c85e1..944ae3997a9489c13f65f93d9a7e61c21dd975c1 100644 --- a/tensorflow/contrib/cmake/modules/FindAbseilCpp.cmake +++ b/tensorflow/contrib/cmake/modules/FindAbseilCpp.cmake @@ -24,10 +24,10 @@ if(EXISTS "${ABSEIL_CPP_INCLUDE_DIR}" AND NOT "${ABSEIL_CPP_INCLUDE_DIR}" STREQU # search all libraries if no COMPONENTS was requested set(AbseilCpp_FIND_COMPONENTS "absl_algorithm;absl_any;absl_bad_any_cast" - "absl_bad_optional_access;absl_base absl_container;absl_debugging" + "absl_bad_optional_access;absl_base;absl_container;absl_debugging" "absl_dynamic_annotations;absl_examine_stack;absl_failure_signal_handler" - "absl_int128;absl_leak_check;absl_malloc_internal;absl_memory;absl_meta" - "absl_numeric;absl_optional;absl_span;absl_spinlock_wait;absl_stack_consumption" + "absl_int128;absl_leak_check;absl_internal_malloc_internal;absl_memory;absl_meta" + "absl_numeric;absl_optional;absl_span;absl_internal_spinlock_wait;absl_stack_consumption" "absl_stacktrace;absl_str_format;absl_strings;absl_symbolize;absl_synchronization" "absl_throw_delegate;absl_time;absl_utility;str_format_extension_internal" "str_format_internal;test_instance_tracker_lib") diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt index 96160568fa79291a7b391761373e1eaf0f70974e..21ae9a08a6bb8f71e5935ddde2d7bb3ed0cd8bbc 100644 --- a/tensorflow/contrib/cmake/python_modules.txt +++ b/tensorflow/contrib/cmake/python_modules.txt @@ -1,6 +1,9 @@ # python_sanity_test.py will complain about invalid or missing entries # problematic entries can be commented for 
temporary whitelisting tensorflow +tensorflow/compiler +tensorflow/compiler/xla +tensorflow/compiler/xla/service tensorflow/core tensorflow/core/example tensorflow/core/framework diff --git a/tensorflow/contrib/cmake/tf_c.cmake b/tensorflow/contrib/cmake/tf_c.cmake index 7a30eb94f54b18a2a517615a315e23e09e1170d0..a04142bd249ed5e16beba11057d0efc1e191e31b 100644 --- a/tensorflow/contrib/cmake/tf_c.cmake +++ b/tensorflow/contrib/cmake/tf_c.cmake @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== + ######################################################## # tf_c_framework library ######################################################## diff --git a/tensorflow/contrib/cmake/tf_cc_ops.cmake b/tensorflow/contrib/cmake/tf_cc_ops.cmake index 6c90cf398c69c8c1b22ea75e0c407f258e2535f9..6514ae50a4a35b35ba100af6997079294c22f9b8 100644 --- a/tensorflow/contrib/cmake/tf_cc_ops.cmake +++ b/tensorflow/contrib/cmake/tf_cc_ops.cmake @@ -149,11 +149,7 @@ add_library(tf_cc OBJECT ${tf_cc_srcs}) add_dependencies(tf_cc tf_cc_framework tf_cc_ops) if (WIN32) - if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") - set (pywrap_tensorflow_lib "${CMAKE_CURRENT_BINARY_DIR}/${CMAKE_BUILD_TYPE}/pywrap_tensorflow_internal.lib") - else() - set (pywrap_tensorflow_lib "${CMAKE_CURRENT_BINARY_DIR}/pywrap_tensorflow_internal.lib") - endif() + set (pywrap_tensorflow_lib "${CMAKE_CURRENT_BINARY_DIR}/$(Configuration)/pywrap_tensorflow_internal.lib") else (WIN32) set (pywrap_tensorflow_lib "${CMAKE_CURRENT_BINARY_DIR}/libpywrap_tensorflow_internal${CMAKE_SHARED_LIBRARY_SUFFIX}") endif (WIN32) diff --git a/tensorflow/contrib/cmake/tf_core_cpu.cmake b/tensorflow/contrib/cmake/tf_core_cpu.cmake index a54cbff33b66d63d7229fa2f50b8a4ca962111ed..d8884d464fb5974d77506561a9ed36110a3804c0 100644 --- a/tensorflow/contrib/cmake/tf_core_cpu.cmake +++ 
b/tensorflow/contrib/cmake/tf_core_cpu.cmake @@ -39,6 +39,8 @@ file(GLOB_RECURSE tf_core_cpu_exclude_srcs "${tensorflow_source_dir}/tensorflow/core/*test*.h" "${tensorflow_source_dir}/tensorflow/core/*test*.cc" "${tensorflow_source_dir}/tensorflow/core/*main.cc" + "${tensorflow_source_dir}/tensorflow/core/common_runtime/eager/*.cc" + "${tensorflow_source_dir}/tensorflow/core/common_runtime/eager/*.h" "${tensorflow_source_dir}/tensorflow/core/common_runtime/gpu/*.cc" "${tensorflow_source_dir}/tensorflow/core/common_runtime/gpu_device_factory.cc" "${tensorflow_source_dir}/tensorflow/core/common_runtime/direct_session.cc" diff --git a/tensorflow/contrib/cmake/tf_core_eager_runtime.cmake b/tensorflow/contrib/cmake/tf_core_eager_runtime.cmake new file mode 100644 index 0000000000000000000000000000000000000000..78e4c0d3035cdaefa1d0950f4270d60152c805af --- /dev/null +++ b/tensorflow/contrib/cmake/tf_core_eager_runtime.cmake @@ -0,0 +1,57 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +######################################################## +# tf_core_eager_runtime library +######################################################## +file(GLOB_RECURSE tf_core_eager_runtime_srcs + "${tensorflow_source_dir}/tensorflow/core/common_runtime/eager/*.cc" + "${tensorflow_source_dir}/tensorflow/core/common_runtime/eager/*.h" +) + +file(GLOB_RECURSE tf_core_eager_runtime_exclude_srcs + "${tensorflow_source_dir}/tensorflow/core/common_runtime/eager/*test*.h" + "${tensorflow_source_dir}/tensorflow/core/common_runtime/eager/*test*.cc" +) + +list(REMOVE_ITEM tf_core_eager_runtime_srcs ${tf_core_eager_runtime_exclude_srcs}) + +add_library(tf_core_eager_runtime OBJECT ${tf_core_eager_runtime_srcs}) +add_dependencies( + tf_core_eager_runtime + tf_c + tf_core_lib) + + +file(GLOB_RECURSE tf_c_eager_srcs + "${tensorflow_source_dir}/tensorflow/c/eager/*.cc" + "${tensorflow_source_dir}/tensorflow/c/eager/*.h" +) + +file(GLOB_RECURSE tf_c_eager_exlclude_srcs + "${tensorflow_source_dir}/tensorflow/c/eager/*test*.h" + "${tensorflow_source_dir}/tensorflow/c/eager/*test*.cc" +) + +list(REMOVE_ITEM tf_c_eager_srcs ${tf_c_eager_exlclude_srcs}) + +add_library(tf_c_eager OBJECT ${tf_c_eager_srcs}) +add_dependencies( + tf_c_eager + tf_core_eager_runtime + tf_c + tf_cc_framework + tf_cc_while_loop + tf_core_lib + tf_protos_cc) \ No newline at end of file diff --git a/tensorflow/contrib/cmake/tf_core_framework.cmake b/tensorflow/contrib/cmake/tf_core_framework.cmake index 7e806685b8448cbd629985cdc00ed1193857abe6..d8d1cc3aa2ca4fff3c950654b7cbd7085c76010c 100644 --- a/tensorflow/contrib/cmake/tf_core_framework.cmake +++ b/tensorflow/contrib/cmake/tf_core_framework.cmake @@ -140,6 +140,7 @@ set(tf_proto_text_srcs "tensorflow/core/example/example.proto" "tensorflow/core/example/feature.proto" "tensorflow/core/framework/allocation_description.proto" + "tensorflow/core/framework/api_def.proto" 
"tensorflow/core/framework/attr_value.proto" "tensorflow/core/framework/cost_graph.proto" "tensorflow/core/framework/device_attributes.proto" @@ -150,6 +151,7 @@ set(tf_proto_text_srcs "tensorflow/core/framework/log_memory.proto" "tensorflow/core/framework/node_def.proto" "tensorflow/core/framework/op_def.proto" + "tensorflow/core/framework/reader_base.proto" "tensorflow/core/framework/remote_fused_graph_execute_info.proto" "tensorflow/core/framework/resource_handle.proto" "tensorflow/core/framework/step_stats.proto" @@ -159,6 +161,7 @@ set(tf_proto_text_srcs "tensorflow/core/framework/tensor_shape.proto" "tensorflow/core/framework/tensor_slice.proto" "tensorflow/core/framework/types.proto" + "tensorflow/core/framework/variable.proto" "tensorflow/core/framework/versions.proto" "tensorflow/core/lib/core/error_codes.proto" "tensorflow/core/protobuf/cluster.proto" @@ -204,10 +207,10 @@ file(GLOB tf_core_platform_srcs "${tensorflow_source_dir}/tensorflow/core/framework/resource_handle.h" "${tensorflow_source_dir}/tensorflow/core/framework/resource_handle.cc") if (NOT tensorflow_ENABLE_GPU) - file(GLOB tf_core_platform_gpu_srcs + file(GLOB tf_core_platform_gpu_srcs_exclude "${tensorflow_source_dir}/tensorflow/core/platform/cuda_libdevice_path.*" "${tensorflow_source_dir}/tensorflow/core/platform/default/cuda_libdevice_path.*") - list(REMOVE_ITEM tf_core_platform_srcs ${tf_core_platform_gpu_srcs}) + list(REMOVE_ITEM tf_core_platform_srcs ${tf_core_platform_gpu_srcs_exclude}) else() file(GLOB tf_core_platform_srcs_exclude "${tensorflow_source_dir}/tensorflow/core/platform/default/device_tracer.cc") @@ -298,8 +301,8 @@ file(GLOB_RECURSE tf_core_framework_srcs "${tensorflow_source_dir}/tensorflow/core/common_runtime/session.cc" "${tensorflow_source_dir}/tensorflow/core/common_runtime/session_factory.cc" "${tensorflow_source_dir}/tensorflow/core/common_runtime/session_options.cc" - "${tensorflow_source_dir}/tensorflow/contrib/tensorboard/db/*.cc" - 
"${tensorflow_source_dir}/tensorflow/contrib/tensorboard/db/*.h" + "${tensorflow_source_dir}/tensorflow/core/summary/*.cc" + "${tensorflow_source_dir}/tensorflow/core/summary/*.h" "${tensorflow_source_dir}/public/*.h" ) @@ -313,14 +316,14 @@ file(GLOB_RECURSE tf_core_framework_exclude_srcs "${tensorflow_source_dir}/tensorflow/core/util/*test*.h" "${tensorflow_source_dir}/tensorflow/core/util/*test*.cc" "${tensorflow_source_dir}/tensorflow/core/util/*main.cc" - "${tensorflow_source_dir}/tensorflow/contrib/tensorboard/db/*test*.cc" - "${tensorflow_source_dir}/tensorflow/contrib/tensorboard/db/loader.cc" - "${tensorflow_source_dir}/tensorflow/contrib/tensorboard/db/vacuum.cc" + "${tensorflow_source_dir}/tensorflow/core/summary/*test*.cc" + "${tensorflow_source_dir}/tensorflow/core/summary/loader.cc" + "${tensorflow_source_dir}/tensorflow/core/summary/vacuum.cc" ) # TODO(jart): Why doesn't this work? # set_source_files_properties( -# ${tensorflow_source_dir}/tensorflow/contrib/tensorboard/db/snapfn.cc +# ${tensorflow_source_dir}/tensorflow/core/lib/db/snapfn.cc # PROPERTIES COMPILE_FLAGS -DSQLITE_OMIT_LOAD_EXTENSION) list(REMOVE_ITEM tf_core_framework_srcs ${tf_core_framework_exclude_srcs}) diff --git a/tensorflow/contrib/cmake/tf_core_ops.cmake b/tensorflow/contrib/cmake/tf_core_ops.cmake index 9cfa8b90749280b6aa815cc210941c75bd5e16c5..310eed4ecbfdd30a3b3bdd4728c030fe70930797 100644 --- a/tensorflow/contrib/cmake/tf_core_ops.cmake +++ b/tensorflow/contrib/cmake/tf_core_ops.cmake @@ -13,13 +13,14 @@ # limitations under the License. 
# ============================================================================== set(tf_op_lib_names - "audio_ops" "array_ops" + "audio_ops" "batch_ops" "bitwise_ops" "boosted_trees_ops" "candidate_sampling_ops" "checkpoint_ops" + "collective_ops" "control_flow_ops" "ctc_ops" "cudnn_rnn_ops" @@ -27,13 +28,14 @@ set(tf_op_lib_names "dataset_ops" "decode_proto_ops" "encode_proto_ops" + "function_ops" "functional_ops" "image_ops" "io_ops" "linalg_ops" "list_ops" - "lookup_ops" "logging_ops" + "lookup_ops" "manip_ops" "math_ops" "nn_ops" @@ -43,10 +45,11 @@ set(tf_op_lib_names "remote_fused_graph_ops" "resource_variable_ops" "rpc_ops" + "scoped_allocator_ops" "script_ops" "sdca_ops" - "set_ops" "sendrecv_ops" + "set_ops" "sparse_ops" "spectral_ops" "state_ops" @@ -54,6 +57,7 @@ set(tf_op_lib_names "string_ops" "summary_ops" "training_ops" + "word2vec_ops" ) foreach(tf_op_lib_name ${tf_op_lib_names}) diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index df7b854afcca1a0bed660624152f465d4bf3b25f..1fe8795ddf00232eba5a60a130e0845a6f6a8e17 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -313,15 +313,14 @@ function(GENERATE_PYTHON_OP_LIB tf_python_op_lib_name) ${GENERATE_PYTHON_OP_LIB_DESTINATION} PARENT_SCOPE) endfunction() -GENERATE_PYTHON_OP_LIB("audio_ops") GENERATE_PYTHON_OP_LIB("array_ops") +GENERATE_PYTHON_OP_LIB("audio_ops") GENERATE_PYTHON_OP_LIB("batch_ops") GENERATE_PYTHON_OP_LIB("bitwise_ops") GENERATE_PYTHON_OP_LIB("boosted_trees_ops") -GENERATE_PYTHON_OP_LIB("math_ops") -GENERATE_PYTHON_OP_LIB("functional_ops") GENERATE_PYTHON_OP_LIB("candidate_sampling_ops") GENERATE_PYTHON_OP_LIB("checkpoint_ops") +GENERATE_PYTHON_OP_LIB("collective_ops") GENERATE_PYTHON_OP_LIB("control_flow_ops" ADDITIONAL_LIBRARIES $) GENERATE_PYTHON_OP_LIB("ctc_ops") @@ -332,14 +331,18 @@ GENERATE_PYTHON_OP_LIB("decode_proto_ops" DESTINATION 
${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/proto/python/ops/gen_decode_proto_op.py) GENERATE_PYTHON_OP_LIB("encode_proto_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/proto/python/ops/gen_encode_proto_op.py) +GENERATE_PYTHON_OP_LIB("function_ops") +GENERATE_PYTHON_OP_LIB("functional_ops") GENERATE_PYTHON_OP_LIB("image_ops") GENERATE_PYTHON_OP_LIB("io_ops") GENERATE_PYTHON_OP_LIB("linalg_ops") GENERATE_PYTHON_OP_LIB("list_ops") GENERATE_PYTHON_OP_LIB("logging_ops") GENERATE_PYTHON_OP_LIB("lookup_ops") -GENERATE_PYTHON_OP_LIB("nn_ops") GENERATE_PYTHON_OP_LIB("manip_ops") +GENERATE_PYTHON_OP_LIB("math_ops") +GENERATE_PYTHON_OP_LIB("nn_ops") +GENERATE_PYTHON_OP_LIB("no_op") GENERATE_PYTHON_OP_LIB("parsing_ops") GENERATE_PYTHON_OP_LIB("random_ops") GENERATE_PYTHON_OP_LIB("remote_fused_graph_ops" @@ -347,17 +350,21 @@ GENERATE_PYTHON_OP_LIB("remote_fused_graph_ops" GENERATE_PYTHON_OP_LIB("resource_variable_ops") GENERATE_PYTHON_OP_LIB("rpc_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/rpc/python/ops/gen_rpc_op.py) +GENERATE_PYTHON_OP_LIB("scoped_allocator_ops") GENERATE_PYTHON_OP_LIB("script_ops") GENERATE_PYTHON_OP_LIB("sdca_ops") +GENERATE_PYTHON_OP_LIB("sendrecv_ops") GENERATE_PYTHON_OP_LIB("set_ops") -GENERATE_PYTHON_OP_LIB("state_ops") GENERATE_PYTHON_OP_LIB("sparse_ops") GENERATE_PYTHON_OP_LIB("spectral_ops") +GENERATE_PYTHON_OP_LIB("state_ops") +GENERATE_PYTHON_OP_LIB("stateless_random_ops") GENERATE_PYTHON_OP_LIB("string_ops") GENERATE_PYTHON_OP_LIB("summary_ops") GENERATE_PYTHON_OP_LIB("user_ops") GENERATE_PYTHON_OP_LIB("training_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/training/gen_training_ops.py) +GENERATE_PYTHON_OP_LIB("word2vec_ops") GENERATE_PYTHON_OP_LIB("contrib_boosted_trees_model_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/boosted_trees/python/ops/gen_model_ops.py) @@ -391,11 +398,8 @@ 
GENERATE_PYTHON_OP_LIB("contrib_layers_sparse_feature_cross_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/layers/ops/gen_sparse_feature_cross_op.py) GENERATE_PYTHON_OP_LIB("contrib_memory_stats_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/memory_stats/ops/gen_memory_stats_ops.py) -GENERATE_PYTHON_OP_LIB("contrib_nccl_ops" - DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/nccl/ops/gen_nccl_ops.py) GENERATE_PYTHON_OP_LIB("contrib_periodic_resample_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/periodic_resample/python/ops/gen_periodic_resample_op.py) - GENERATE_PYTHON_OP_LIB("contrib_nearest_neighbor_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/nearest_neighbor/ops/gen_nearest_neighbor_ops.py) GENERATE_PYTHON_OP_LIB("contrib_resampler_ops" @@ -420,8 +424,6 @@ GENERATE_PYTHON_OP_LIB("contrib_bigquery_reader_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/cloud/python/ops/gen_bigquery_reader_ops.py) GENERATE_PYTHON_OP_LIB("contrib_gcs_config_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/cloud/python/ops/gen_gcs_config_ops.py) -GENERATE_PYTHON_OP_LIB("stateless_random_ops" - DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/stateless/gen_stateless_random_ops.py) GENERATE_PYTHON_OP_LIB("debug_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/debug/ops/gen_debug_ops.py) @@ -524,11 +526,13 @@ if(WIN32) add_library(pywrap_tensorflow_internal_static STATIC ${pywrap_tensorflow_internal_src} $ + $ $ $ $ $ $ + $ $ $ $ @@ -581,11 +585,13 @@ endif(WIN32) add_library(pywrap_tensorflow_internal SHARED ${pywrap_tensorflow_internal_src} $ + $ $ $ $ $ $ + $ $ $ $ @@ -615,13 +621,28 @@ target_include_directories(pywrap_tensorflow_internal PUBLIC ${NUMPY_INCLUDE_DIR} ) -target_link_libraries(pywrap_tensorflow_internal PRIVATE +if(CMAKE_COMPILER_IS_GNUCC AND 
CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 5.0) + # There is a bug in GCC 5 resulting in undefined reference to a __cpu_model function when + # linking to the tensorflow library. Adding the following libraries fixes it. + # See issue on github: https://github.com/tensorflow/tensorflow/issues/9593 + target_link_libraries(pywrap_tensorflow_internal PRIVATE ${tf_core_gpu_kernels_lib} ${tensorflow_EXTERNAL_LIBRARIES} tf_protos_cc tf_python_protos_cc ${PYTHON_LIBRARIES} + gcc_s + gcc ) +else() + target_link_libraries(pywrap_tensorflow_internal PRIVATE + ${tf_core_gpu_kernels_lib} + ${tensorflow_EXTERNAL_LIBRARIES} + tf_protos_cc + tf_python_protos_cc + ${PYTHON_LIBRARIES} +) +endif() if(WIN32) @@ -781,6 +802,7 @@ add_custom_command( # tensorflow/__init__.py depends on files generated in this step. So, remove it while # this step is running since the files aren't there yet. COMMAND ${CMAKE_COMMAND} -E remove -f ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py + COMMAND ${CMAKE_COMMAND} -E touch ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py # Run create_python_api.py to generate API init files. COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}/tf_python "${PY_RUNTIME_ENV}" ${PYTHON_EXECUTABLE} @@ -806,10 +828,10 @@ add_dependencies(tf_python_api tf_python_ops) ######################################################## # Parse tensorflow/python/tools/api/generator/BUILD to get list of generated files. 
-FILE(READ ${tensorflow_source_dir}/tensorflow/python/tools/api/generator/api_gen.bzl api_generator_BUILD_text) -STRING(REGEX MATCH "# BEGIN GENERATED ESTIMATOR FILES.*# END GENERATED ESTIMATOR FILES" api_init_files_text ${api_generator_BUILD_text}) -string(REPLACE "# BEGIN GENERATED ESTIMATOR FILES" "" api_init_files_text ${api_init_files_text}) -string(REPLACE "# END GENERATED ESTIMATOR FILES" "" api_init_files_text ${api_init_files_text}) +FILE(READ ${tensorflow_source_dir}/tensorflow/python/tools/api/generator/api_init_files.bzl api_generator_BUILD_text) +STRING(REGEX MATCH "# BEGIN GENERATED FILES.*# END GENERATED FILES" api_init_files_text ${api_generator_BUILD_text}) +string(REPLACE "# BEGIN GENERATED FILES" "" api_init_files_text ${api_init_files_text}) +string(REPLACE "# END GENERATED FILES" "" api_init_files_text ${api_init_files_text}) string(REPLACE "," ";" api_init_files_list ${api_init_files_text}) set(api_init_files "") diff --git a/tensorflow/contrib/cmake/tf_shared_lib.cmake b/tensorflow/contrib/cmake/tf_shared_lib.cmake index fdf522f1fd90ffc64acbe82381ef57a389645d61..62005dd113bfb80fbdf23afb6d4aa5f90a1e32de 100644 --- a/tensorflow/contrib/cmake/tf_shared_lib.cmake +++ b/tensorflow/contrib/cmake/tf_shared_lib.cmake @@ -23,6 +23,8 @@ if(WIN32) # we need. # add_library(tensorflow_static STATIC + $ + $ $ $ $ @@ -65,6 +67,8 @@ endif(WIN32) # tensorflow is a shared library containing all of the # TensorFlow runtime and the standard ops and kernels. 
add_library(tensorflow SHARED + $ + $ $ $ $ @@ -96,6 +100,27 @@ if(CMAKE_COMPILER_IS_GNUCC AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 5.0) target_link_libraries(tensorflow PRIVATE gcc_s gcc) endif() +# Offer the user the choice of overriding the installation directories +set(INSTALL_LIB_DIR lib CACHE PATH "Installation directory for libraries") +set(INSTALL_BIN_DIR bin CACHE PATH "Installation directory for executables") +set(INSTALL_INCLUDE_DIR include CACHE PATH + "Installation directory for header files") +if(WIN32 AND NOT CYGWIN) + set(DEF_INSTALL_CMAKE_DIR cmake) +else() + set(DEF_INSTALL_CMAKE_DIR lib/cmake) +endif() +set(INSTALL_CMAKE_DIR ${DEF_INSTALL_CMAKE_DIR} CACHE PATH + "Installation directory for CMake files") + +# Make relative paths absolute (needed later on) +foreach(p LIB BIN INCLUDE CMAKE) + set(var INSTALL_${p}_DIR) + if(NOT IS_ABSOLUTE "${${var}}") + set(${var} "${CMAKE_INSTALL_PREFIX}/${${var}}") + endif() +endforeach() + if(WIN32) add_dependencies(tensorflow tensorflow_static) endif(WIN32) @@ -103,14 +128,57 @@ endif(WIN32) target_include_directories(tensorflow PUBLIC $) -install(TARGETS tensorflow EXPORT tensorflow_export - RUNTIME DESTINATION bin - LIBRARY DESTINATION lib - ARCHIVE DESTINATION lib) +# Add all targets to build-tree export set +export(TARGETS tensorflow + FILE ${PROJECT_BINARY_DIR}/TensorflowTargets.cmake) + +# Export the package for use from the build-tree +export(PACKAGE Tensorflow) + +# Create the TensorflowConfig.cmake and TensorflowConfigVersion files +file(RELATIVE_PATH REL_INCLUDE_DIR "${INSTALL_CMAKE_DIR}" + "${INSTALL_INCLUDE_DIR}") +# for the build tree +set(CONF_INCLUDE_DIRS "${tensorflow_source_dir}" + "${PROJECT_BINARY_DIR}" + "${CMAKE_CURRENT_BINARY_DIR}/protobuf/src/protobuf/src" + "${CMAKE_CURRENT_BINARY_DIR}/nsync/install/include" # Please if there is a better directory + "${CMAKE_CURRENT_BINARY_DIR}/eigen/src/eigen/Eigen/" + "${CMAKE_CURRENT_BINARY_DIR}/external/eigen_archive/" + 
"${tensorflow_source_dir}/third_party/eigen3/" + "${CMAKE_CURRENT_BINARY_DIR}/eigen/src/eigen/unsupported/Eigen/") +configure_file(TensorflowConfig.cmake.in + "${PROJECT_BINARY_DIR}/TensorflowConfig.cmake" @ONLY) +# for the install tree, yet to be complete +set(CONF_INCLUDE_DIRS "\${TENSORFLOW_CMAKE_DIR}/${REL_INCLUDE_DIR}") +configure_file(TensorflowConfig.cmake.in + "${PROJECT_BINARY_DIR}/${CMAKE_FILES_DIRECTORY}/TensorflowConfig.cmake" @ONLY) +# for both +configure_file(TensorflowConfigVersion.cmake.in + "${PROJECT_BINARY_DIR}/TensorflowConfigVersion.cmake" @ONLY) + +# install(TARGETS tensorflow EXPORT tensorflow_export +# RUNTIME DESTINATION ${INSTALL_BIN_DIR} +# LIBRARY DESTINATION ${INSTALL_LIB_DIR} +# ARCHIVE DESTINATION ${INSTALL_LIB_DIR}) + +# install(EXPORT tensorflow_export +# FILE TensorflowConfig.cmake +# DESTINATION ${INSTALL_CMAKE_DIR}) -install(EXPORT tensorflow_export - FILE TensorflowConfig.cmake - DESTINATION lib/cmake) +install(FILES + "${PROJECT_BINARY_DIR}/${CMAKE_FILES_DIRECTORY}/TensorflowConfig.cmake" + "${PROJECT_BINARY_DIR}/TensorflowConfigVersion.cmake" + DESTINATION "${INSTALL_CMAKE_DIR}" COMPONENT dev) + +# install the export set for use with the install-tree +install(EXPORT TensorflowTargets + DESTINATION ${INSTALL_CMAKE_DIR}) + +install(TARGETS tensorflow EXPORT TensorflowTargets + RUNTIME DESTINATION ${INSTALL_BIN_DIR} + LIBRARY DESTINATION ${INSTALL_LIB_DIR} + ARCHIVE DESTINATION ${INSTALL_LIB_DIR}) # install necessary headers # tensorflow headers @@ -145,6 +213,10 @@ install(DIRECTORY ${tensorflow_source_dir}/third_party/eigen3/ # unsupported Eigen directory install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/eigen/src/eigen/unsupported/Eigen/ DESTINATION include/unsupported/Eigen) +# absl directory +install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/abseil_cpp/src/abseil_cpp/absl/ + DESTINATION include/absl + FILES_MATCHING PATTERN "*.h") # mkl if (tensorflow_ENABLE_MKL_SUPPORT) install(DIRECTORY 
${CMAKE_CURRENT_BINARY_DIR}/mkl/src/mkl/include/ diff --git a/tensorflow/contrib/compiler/BUILD b/tensorflow/contrib/compiler/BUILD index e4566437c60ebb2da039e61c171fbe954a7355c9..e32097ceddfec95b8677fc762d641d09078e5343 100644 --- a/tensorflow/contrib/compiler/BUILD +++ b/tensorflow/contrib/compiler/BUILD @@ -70,22 +70,30 @@ py_library( ], ) -tf_py_test( +cuda_py_test( name = "xla_test", srcs = ["xla_test.py"], additional_deps = [ ":xla", - "@six_archive//:six", + "@absl_py//absl/testing:parameterized", + "//tensorflow/compiler/tests:xla_test", + "//tensorflow/contrib/tpu:tpu_estimator", + "//tensorflow/contrib/tpu:tpu_lib", + "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:control_flow_ops", "//tensorflow/python:control_flow_util", "//tensorflow/python:math_ops", "//tensorflow/python:platform", - "//tensorflow/contrib/tpu:tpu_lib", "//tensorflow/python:state_ops", "//tensorflow/python:summary", "//tensorflow/python:training", "//tensorflow/python:variable_scope", + "//tensorflow/python/data/ops:dataset_ops", + ], + tags = [ + "no_mac", + "no_windows", ], - tags = ["no_pip"], + xla_enabled = True, ) diff --git a/tensorflow/contrib/compiler/xla.py b/tensorflow/contrib/compiler/xla.py index f867cd15b67dbd43650d8012b4299845af7200a8..238c6ab1366a50710efabea2f33eb1bd06fe9423 100644 --- a/tensorflow/contrib/compiler/xla.py +++ b/tensorflow/contrib/compiler/xla.py @@ -34,6 +34,7 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import compat from tensorflow.python.util import function_utils +from tensorflow.python.util import nest from tensorflow.python.util import tf_decorator from tensorflow.python.util import tf_inspect @@ -76,10 +77,22 @@ def compile(computation, inputs=None): # pylint: disable=redefined-builtin All `Operation`s returned from `computation` will be executed when evaluating any of the returned output tensors. 
- inputs: A list of input tensors or `None` (equivalent to an empty list). + inputs: A list of inputs or `None` (equivalent to an empty list). Each input + can be a nested structure containing values that are convertible to + tensors. Note that passing an N-dimension list of compatible values will + result in a N-dimension list of scalar tensors rather than a single Rank-N + tensors. If you need different behavior, convert part of inputs to tensors + with `tf.convert_to_tensor`. Returns: - A list of output tensors. + Same data structure as if computation(*inputs) is called directly with some + exceptions for correctness. Exceptions include: + 1) None output: a NoOp would be returned which control-depends on + computation. + 2) Single value output: A tuple containing the value would be returned. + 3) Operation-only outputs: a NoOp would be returned which + control-depends on computation. + TODO(b/121383831): Investigate into removing these special cases. """ # pylint: disable=protected-access return _compile_internal(computation, inputs) @@ -131,6 +144,30 @@ class XLACompileContext(control_flow_ops.XLAControlFlowContext): logging.warning('... 
and %d more', len(self._unsupported_ops) - _MAX_WARNING_LINES) + def _RemoveExternalControlEdges(self, op): + """Remove any external control dependency on this op.""" + internal_control_inputs = [] + external_control_inputs = [] + for x in op.control_inputs: + # pylint: disable=protected-access + is_internal_op = False + ctxt = x._get_control_flow_context() + while ctxt is not None: + if ctxt == self: + is_internal_op = True + break + ctxt = ctxt._outer_context + if is_internal_op: + internal_control_inputs.append(x) + else: + external_control_inputs.append(x) + # pylint: enable=protected-access + # pylint: disable=protected-access + op._remove_all_control_inputs() + op._add_control_inputs(internal_control_inputs) + # pylint: enable=protected-access + return internal_control_inputs, external_control_inputs + def AddOp(self, op): """Create op in XLACompileContext and notifies outer context recursively.""" # pylint: disable=protected-access @@ -180,11 +217,14 @@ class XLACompileContext(control_flow_ops.XLAControlFlowContext): if external_control_inputs: # Use an identity to pull control inputs as data inputs. Note that we # ignore ops which don't have outputs. TODO(phawkins): fix that. - external_control_inputs = [ - array_ops.identity(x.outputs[0]).op - for x in external_control_inputs - if x.outputs - ] + with ops.control_dependencies(None): + self.Enter() + external_control_inputs = [ + array_ops.identity(x.outputs[0]).op + for x in external_control_inputs + if x.outputs + ] + self.Exit() # pylint: disable=protected-access op._add_control_inputs(external_control_inputs) # pylint: enable=protected-access @@ -245,13 +285,21 @@ def _compile_internal(computation, inputs=None): Args: computation: A Python function that builds the computation to compile and execute. - inputs: A list of input tensors or `None` (equivalent to `[]`). Its order - should match ordering of computation arguments. + inputs: A list of inputs or `None` (equivalent to an empty list). 
Each input + can be a nested structure containing values that are convertible to + tensors. Note that passing an N-dimension list of compatible values will + result in a N-dimension list of scalar tensors rather than a single Rank-N + tensors. If you need different behavior, convert part of inputs to tensors + with `tf.convert_to_tensor`. + Returns: - A list of output tensors from computation. + Same data structure as if computation(*inputs) is called directly with some + exceptions for correctness. Exceptions include: 1) None output 2) Single + value output 3) Operation-only outputs Raises: ValueError: If any element in computation outputs is neither an operations or a value that can be converted to tensor. + ValueError: If computation outputs is non-flat and contains any Operations. TypeError: If `inputs` is not a list or tuple. """ if inputs is None: @@ -260,17 +308,10 @@ def _compile_internal(computation, inputs=None): if not isinstance(inputs, collections.Sequence): raise TypeError('inputs must be a list') + # Flatten inputs. + flat_inputs = nest.flatten(inputs) # Converts inputs to Tensors. - inputs = [ops.convert_to_tensor(x) for x in inputs] - input_arity = len(inputs) - - arg_error = check_function_argument_count( - computation, input_arity, infeed_queue=None) - if arg_error is not None: - raise TypeError( - 'Supplied computation cannot be called with the specified inputs. You ' - 'specified %d inputs: %s, but the computation needs %s' % - (input_arity, str([i.name for i in inputs]), arg_error)) + flat_inputs = [ops.convert_to_tensor(x) for x in flat_inputs] cluster_name = ops.get_default_graph().unique_name('cluster') pivot = control_flow_ops.no_op(name=cluster_name + '/pivot') @@ -280,11 +321,15 @@ def _compile_internal(computation, inputs=None): # Add identity ops so even unused inputs are 'consumed' by the # computation. 
- computation_inputs = [ + flat_inputs = [ array_ops.identity(x, name='input_{}'.format(i)) - for i, x in enumerate(inputs) + for i, x in enumerate(flat_inputs) ] + # Re-pack flat_inputs in same structure as 'inputs'. + computation_inputs = nest.pack_sequence_as( + structure=inputs, flat_sequence=flat_inputs) + # Only resource variables work inside an XLA computation, so turn on # resource variables for the computation. vscope = variable_scope.get_variable_scope() @@ -297,66 +342,166 @@ def _compile_internal(computation, inputs=None): # Restore variable scope after computation. vscope.set_use_resource(saved_use_resource) - # If the computation returns `None`, make it an empty tuple. - if outputs is None: - outputs = tuple() - # If the computation only returned one value, make it a tuple. - if not isinstance(outputs, collections.Sequence): - outputs = (outputs,) - - # Append `no_op` here so that return value of this function always contains - # at least one op that can trigger XlaLaunch node. - outputs += (control_flow_ops.no_op(),) - try: - outputs = [ - o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o) - for o in outputs - ] - except Exception as e: - raise ValueError( - 'XLA computation function return values must all either be Operations' - ' or convertible to Tensors. Got error: "%s"' % str(e)) - - # Separates the returned Operations and Tensors. 
- output_operations = [o for o in outputs if isinstance(o, ops.Operation)] - output_tensors = [o for o in outputs if not isinstance(o, ops.Operation)] - - if outputs != output_tensors + output_operations: - raise ValueError( - 'XLA computation function must return zero or more Tensor values ' - 'followed by zero or more Operations.') - output_arity = len(output_tensors) - - new_output_tensors = [] - for t in output_tensors: - with ops.device(t.device if t.device else ''): - new_output_tensors.append(array_ops.identity(t)) + outputs_is_flat = is_flat(outputs) + if outputs_is_flat: + output_tensors, control_deps = _postprocess_flat_outputs(outputs) + else: + output_tensors, control_deps = _postprocess_non_flat_outputs(outputs) - output_tensors = new_output_tensors context.ExitResult(output_tensors) finally: context.report_unsupported_operations() context.Exit() - outputs = [ - xla_ops.xla_cluster_output(output_tensors[i], name='output{}'.format(i)) - for i in xrange(output_arity) + # When XLA computation returns only operations and no tensors, a NoOp + # dependent on the operations in outputs is returned. Otherwise final + # outputs would be empty and there is no way to trigger returned + # operations. + if not output_tensors: + return control_flow_ops.group(control_deps, name='output_0') + + output_tensors = [ + xla_ops.xla_cluster_output(o, name='output{}'.format(i)) + for i, o in enumerate(output_tensors) ] - with ops.control_dependencies(output_operations): - if output_arity == 0: - # When XLA computation returns only operations and no tensors, a NoOp - # dependent on the operations in outputs is returned. Otherwise final - # outputs would be empty and there is no way to trigger returned - # operations. - return control_flow_ops.no_op(name='output_0') - else: - # Wraps the outputs in identity operators that carries control - # dependencies. 
- return [ - array_ops.identity(outputs[i], name='output_%d' % i) - for i in xrange(output_arity) - ] + with ops.control_dependencies(control_deps): + # Wraps the outputs in identity operators that carries control + # dependencies. + output_tensors = [ + array_ops.identity(o, name='output_%d' % i) + for i, o in enumerate(output_tensors) + ] + + # If `computation` returned non-flat output structure, pack output tensors + # back into same structure. + if not outputs_is_flat: + output_tensors = nest.pack_sequence_as( + structure=outputs, flat_sequence=output_tensors) + + return output_tensors + + +def is_flat(outputs): + """Checks if outputs is a flat structure. + + Following structures and values are considered flat: + 1) None + 2) A single object + 3) A list or tuple of Tensors/Operations + + The only structures that this function understands are sequences and + dictionaries. E.g. this means that if outputs contains a single + user-defined Object, it is considered to be flat. Errors are raised later on + if that Object cannot be converted to a Tensor. + + Args: + outputs: Output from `computation` inside `xla.compile`. + + Returns: + A boolean indicates whether outputs is flat. + """ + # If outputs is a list or tuple, check if it has any nested structure. If + # there is, then outputs is non-flat. + if isinstance(outputs, collections.Sequence): + for o in outputs: + if isinstance(o, collections.Sequence) or isinstance(o, dict): + return False + + # If outputs is a dict, it is non-flat. + if isinstance(outputs, dict): + return False + + # Getting here means either outputs itself is a single non-structured value + # or it is a flat list of single non-structured values. + return True + + +def _postprocess_flat_outputs(outputs): + """Validates flat outputs and adds back device assignments. + + Args: + outputs: Output from `computation` inside `xla.compile`. + + Returns: + Tensors and Operations extracted from outputs. 
+ """ + # Following code segment is to preserve legacy behavior. Previously we only + # supported flat outputs and thus for consistency it was nice to convert even + # single element into a tuple. But now that we support arbitrary output + # structure, this is no longer necessary. + # TODO(b/121383831): Migrate all legacy use cases and delete this special + # case. + # If the computation returns `None`, make it an empty tuple. + if outputs is None: + outputs = tuple() + # If the computation only returned one value, make it a tuple. + if not isinstance(outputs, collections.Sequence): + outputs = (outputs,) + + # Append `no_op` here so that return value of this function always contains + # at least one op that can trigger XlaLaunch node. + outputs += (control_flow_ops.no_op(),) + try: + outputs = [ + o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o) + for o in outputs + ] + except Exception as e: + raise ValueError( + 'XLA computation function return values must all either be Operations' + ' or convertible to Tensors. Got error: "%s"' % str(e)) + + # Separates the returned Operations and Tensors. + output_operations = [o for o in outputs if isinstance(o, ops.Operation)] + output_tensors = [o for o in outputs if not isinstance(o, ops.Operation)] + + if outputs != output_tensors + output_operations: + raise ValueError( + 'XLA computation function must return zero or more Tensor values ' + 'followed by zero or more Operations.') + + new_output_tensors = [] + for t in output_tensors: + with ops.device(t.device if t.device else ''): + new_output_tensors.append(array_ops.identity(t)) + + return new_output_tensors, output_operations + + +def _postprocess_non_flat_outputs(outputs): + """Validates non-flat outputs and adds back device assignments. + + Args: + outputs: Output from `computation` inside `xla.compile`. + + Returns: + Tensors extracted from outputs and an empty list because Operations are not + allowed in non-flat outputs.. 
+ """ + # Convert all non-Operation outputs to Tensors. + new_output_tensors = [] + for o in nest.flatten(outputs): + if isinstance(o, ops.Operation): + raise ValueError( + 'xla.compile does not support Operation as return value in non-flat ' + 'output structure. You can set returned Operations as control ' + 'dependencies of returned Tensors so Operations are triggered when ' + 'Tensors are evaluated. Operation found: "%s"' % o.name) + + try: + o = ops.convert_to_tensor(o) + except Exception as e: + raise ValueError( + 'XLA computation function return values must all either be ' + 'Operations or convertible to Tensors. Got error: "%s"' % str(e)) + + # Makes sure even pass-through inputs/outputs are touched in compile + # context by creating an Identity node inside compile context. + with ops.device(o.device if o.device else ''): + new_output_tensors.append(array_ops.identity(o)) + + return new_output_tensors, [] @contextlib.contextmanager diff --git a/tensorflow/contrib/compiler/xla_test.py b/tensorflow/contrib/compiler/xla_test.py index 3b49755afcf0753d31c0ce506dce42709b1ee8bc..c4384dcde75035dc55e67bd503e348fe19b97025 100644 --- a/tensorflow/contrib/compiler/xla_test.py +++ b/tensorflow/contrib/compiler/xla_test.py @@ -18,11 +18,19 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import re +from absl.testing import parameterized + from tensorflow.contrib.compiler import xla +from tensorflow.contrib.tpu.python.tpu import tpu_estimator from tensorflow.contrib.tpu.python.tpu import tpu_feed +from tensorflow.contrib.training.python.training import hparam from tensorflow.python import summary +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from 
tensorflow.python.ops import control_flow_util from tensorflow.python.ops import logging_ops @@ -30,6 +38,14 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.platform import test +from tensorflow.python.training import training + + +_TRAIN = model_fn_lib.ModeKeys.TRAIN +_EVAL = model_fn_lib.ModeKeys.EVAL +_EXPECTED_LOSS = 1 +_EXPECTED_FEATURE = 2 +_EXPECTED_LABEL = 3 class XLACompileContextTest(test.TestCase): @@ -252,5 +268,329 @@ class CheckFunctionArgumentCountTest(test.TestCase): xla.check_function_argument_count(func, 0, queue)) +def _test_train_model_fn(features, labels, mode, params): + """A dummy model_fn for testing purpose.""" + del features, labels, params + loss = constant_op.constant(_EXPECTED_LOSS) + return model_fn_lib.EstimatorSpec( + mode=mode, loss=loss, train_op=array_ops.identity(loss)) + + +@xla.estimator_model_fn +def decorated_model_fn(features, labels, mode, params): + return _test_train_model_fn(features, labels, mode, params) + + +def make_dummy_features_labels(): + # XLA CPU/GPU backend doesn't support guaranteed constant, thus use dataset + # container to work around. 
+ features_dataset = dataset_ops.Dataset.from_tensors( + constant_op.constant(_EXPECTED_FEATURE)).repeat(10) + features_op = features_dataset.make_one_shot_iterator().get_next() + labels_dataset = dataset_ops.Dataset.from_tensors( + constant_op.constant(_EXPECTED_LABEL)).repeat(10) + labels_op = labels_dataset.make_one_shot_iterator().get_next() + return features_op, labels_op + + +class XlaDecoratorTest(test.TestCase, parameterized.TestCase): + + @parameterized.named_parameters( + ('test_use_as_decorator', decorated_model_fn, None), + ('test_use_as_function', xla.estimator_model_fn(_test_train_model_fn), + None), + ('test_use_tpu_false_hparams', decorated_model_fn, + hparam.HParams(use_tpu=False)), + ('test_use_tpu_false_dict_params', decorated_model_fn, { + 'use_tpu': False + }), + ) + def test_compile(self, model_fn, params): + """Calls model_fn and verifies it is compiled.""" + with test.mock.patch.object(xla, 'compile') as mock_xla_compile: + loss = constant_op.constant(_EXPECTED_LOSS) + mock_xla_compile.return_value = [loss] + + features, labels = make_dummy_features_labels() + estimator_spec = model_fn( + features=features, labels=labels, mode=_TRAIN, params=params or {}) + + self.assertEqual(mock_xla_compile.call_count, 1) + self.assertEqual(estimator_spec.mode, _TRAIN) + + with self.test_session() as sess: + self.assertEqual(sess.run(estimator_spec.loss), sess.run(loss)) + self.assertEqual(sess.run(estimator_spec.train_op), sess.run(loss)) + + @parameterized.named_parameters( + ('test_use_tpu_true_hparams', decorated_model_fn, + hparam.HParams(use_tpu=True)), + ('test_use_tpu_true_dict_params', decorated_model_fn, { + 'use_tpu': True + }), + ) + def test_not_compile(self, model_fn, params): + """Calls model_fn and verifies it is NOT compiled.""" + with test.mock.patch.object(xla, 'compile') as mock_xla_compile: + loss = constant_op.constant(_EXPECTED_LOSS) + mock_xla_compile.return_value = [loss] + + features, labels = make_dummy_features_labels() + 
estimator_spec = model_fn( + features=features, labels=labels, mode=_TRAIN, params=params or {}) + + mock_xla_compile.assert_not_called() + self.assertEqual(estimator_spec.mode, _TRAIN) + + with self.test_session() as sess: + self.assertEqual(sess.run(estimator_spec.loss), sess.run(loss)) + self.assertEqual(sess.run(estimator_spec.train_op), sess.run(loss)) + + def test_model_with_summary(self): + """Tests that summary ops are disabled.""" + + @xla.estimator_model_fn + def model_fn_with_summary(features, labels, mode, params): + del features, labels, params + loss = constant_op.constant(_EXPECTED_LOSS) + summary.scalar('loss_scalar_summary', loss) + summary.histogram('loss_histogram_summary', loss) + summary.image('loss_image_summary', loss) + return model_fn_lib.EstimatorSpec( + mode=mode, loss=loss, train_op=array_ops.identity(loss)) + + features, labels = make_dummy_features_labels() + estimator_spec = model_fn_with_summary( + features=features, labels=labels, mode=_TRAIN, params={}) + + with self.test_session() as sess: + self.assertEqual(sess.run(estimator_spec.loss), _EXPECTED_LOSS) + + +def _test_eval_metric_fn(eval_tensor_1, eval_tensor_2): + return { + 'metric_1': (eval_tensor_1, eval_tensor_1), + 'metric_2': (eval_tensor_2, eval_tensor_2), + } + + +class XlaDecoratorEvaluationTest(test.TestCase): + + def _verify_evaluation_result(self, eval_model_fn): + features, labels = make_dummy_features_labels() + estimator_spec = eval_model_fn( + features=features, labels=labels, mode=_EVAL, params={}) + + with self.test_session() as sess: + self.assertEqual(sess.run(estimator_spec.loss), _EXPECTED_LOSS) + self.assertEqual( + sess.run(estimator_spec.eval_metric_ops['metric_1'][0]), + _EXPECTED_FEATURE + _EXPECTED_LABEL) + self.assertEqual( + sess.run(estimator_spec.eval_metric_ops['metric_1'][1]), + _EXPECTED_FEATURE + _EXPECTED_LABEL) + self.assertEqual( + sess.run(estimator_spec.eval_metric_ops['metric_2'][0]), + _EXPECTED_FEATURE - _EXPECTED_LABEL) + 
self.assertEqual( + sess.run(estimator_spec.eval_metric_ops['metric_2'][1]), + _EXPECTED_FEATURE - _EXPECTED_LABEL) + + def test_eval_base_estimator_spec_eval_metric_ops_disallowed(self): + + @xla.estimator_model_fn + def eval_model_fn_return_estimator_spec(features, labels, mode, params): + del features, labels, params + loss = constant_op.constant(_EXPECTED_LOSS) + return model_fn_lib.EstimatorSpec( + mode=mode, + loss=loss, + eval_metric_ops={ + 'metric': (array_ops.identity(loss), control_flow_ops.no_op()) + }) + + with self.assertRaisesRegexp( + ValueError, 'EstimatorSpec.eval_metric_ops is not supported with XLA ' + 'compilation. Please use TPUEstimatorSpec.eval_metrics instead.'): + self._verify_evaluation_result(eval_model_fn_return_estimator_spec) + + def test_eval_base_estimator_spec_no_eval_metric_ops(self): + + @xla.estimator_model_fn + def eval_model_fn_no_eval_metric_ops(features, labels, mode, params): + del features, labels, params + return model_fn_lib.EstimatorSpec( + mode=mode, loss=constant_op.constant(_EXPECTED_LOSS)) + + features, labels = make_dummy_features_labels() + estimator_spec = eval_model_fn_no_eval_metric_ops( + features=features, labels=labels, mode=_EVAL, params={}) + with self.test_session() as sess: + self.assertEqual(sess.run(estimator_spec.loss), _EXPECTED_LOSS) + + def test_eval_no_eval_metrics(self): + + @xla.estimator_model_fn + def eval_model_fn_no_eval_metrics(features, labels, mode, params): + del features, labels, params + return tpu_estimator.TPUEstimatorSpec( + mode=mode, loss=constant_op.constant(_EXPECTED_LOSS)) + + features, labels = make_dummy_features_labels() + estimator_spec = eval_model_fn_no_eval_metrics( + features=features, labels=labels, mode=_EVAL, params={}) + + self.assertEqual(estimator_spec.eval_metric_ops, {}) + with self.test_session() as sess: + self.assertEqual(sess.run(estimator_spec.loss), _EXPECTED_LOSS) + + def test_eval_fn_missing_input_tensor(self): + + @xla.estimator_model_fn + def 
eval_model_fn(features, labels, mode, params): + del params + dummy_eval_metric_fn_tensors_dict = { + 'eval_tensor_1': features + labels, + } + return tpu_estimator.TPUEstimatorSpec( + mode=mode, + loss=constant_op.constant(_EXPECTED_LOSS), + eval_metrics=(_test_eval_metric_fn, + dummy_eval_metric_fn_tensors_dict)) + + with self.assertRaisesRegexp( + ValueError, + re.escape("Arguments ['eval_tensor_2'] are needed by metric_fn (first " + 'element of TPUEstimatorSpec.eval_metrics) but they are not ' + 'provided by evaluation tensors (second element of ' + 'TPUEstimatorSpec.eval_metrics).')): + self._verify_evaluation_result(eval_model_fn) + + def test_eval_fn_extraneous_input_tensor(self): + + @xla.estimator_model_fn + def eval_model_fn(features, labels, mode, params): + del params + dummy_eval_metric_fn_tensors_dict = { + 'eval_tensor_1': features + labels, + 'eval_tensor_2': features - labels, + 'extra_tensor': features * 2 - labels, + } + return tpu_estimator.TPUEstimatorSpec( + mode=mode, + loss=constant_op.constant(_EXPECTED_LOSS), + eval_metrics=(_test_eval_metric_fn, + dummy_eval_metric_fn_tensors_dict)) + + with self.assertRaisesRegexp( + ValueError, + re.escape("Arguments ['extra_tensor'] are provided by evaluation " + 'tensors (second element of TPUEstimatorSpec.eval_metrics) ' + 'but they are not needed by metric_fn (first element of ' + 'TPUEstimatorSpec.eval_metrics).')): + self._verify_evaluation_result(eval_model_fn) + + def test_eval_tensors_as_list(self): + + @xla.estimator_model_fn + def eval_model_fn(features, labels, mode, params): + del params + dummy_eval_metric_fn_tensors = [features + labels, features - labels] + return tpu_estimator.TPUEstimatorSpec( + mode=mode, + loss=constant_op.constant(_EXPECTED_LOSS), + eval_metrics=(_test_eval_metric_fn, dummy_eval_metric_fn_tensors)) + + self._verify_evaluation_result(eval_model_fn) + + def test_eval_tensors_as_dict(self): + + @xla.estimator_model_fn + def eval_model_fn(features, labels, mode, 
params): + del params + dummy_eval_metric_fn_tensors_dict = { + 'eval_tensor_1': features + labels, + 'eval_tensor_2': features - labels, + } + return tpu_estimator.TPUEstimatorSpec( + mode=mode, + loss=constant_op.constant(_EXPECTED_LOSS), + eval_metrics=(_test_eval_metric_fn, + dummy_eval_metric_fn_tensors_dict)) + + self._verify_evaluation_result(eval_model_fn) + + def test_model_with_summary(self): + """Tests that summary ops are disabled.""" + + @xla.estimator_model_fn + def model_fn_with_summary(features, labels, mode, params): + del features, labels, params + loss = constant_op.constant(_EXPECTED_LOSS) + summary.scalar('loss_scalar_summary', loss) + summary.histogram('loss_histogram_summary', loss) + summary.image('loss_image_summary', loss) + return tpu_estimator.TPUEstimatorSpec(mode=mode, loss=loss) + + features, labels = make_dummy_features_labels() + estimator_spec = model_fn_with_summary( + features=features, labels=labels, mode=_EVAL, params={}) + + with self.test_session() as sess: + self.assertEqual(sess.run(estimator_spec.loss), _EXPECTED_LOSS) + + +class XlaDecoratorScaffoldTest(test.TestCase, parameterized.TestCase): + + def _make_scaffold_fn(self, mode): + + def _scaffold_fn_on_cpu(): + scaffold = training.Scaffold() + self.assertNotIn(mode, self.is_scaffold_fn_called) + self.is_scaffold_fn_called[mode] = True + return scaffold + + return _scaffold_fn_on_cpu + + def test_scaffold_fn_return_none(self): + + @xla.estimator_model_fn + def model_fn(features, labels, mode, params): + del features, labels, params + return tpu_estimator.TPUEstimatorSpec( + mode=mode, + loss=constant_op.constant(_EXPECTED_LOSS), + train_op=control_flow_ops.no_op(), + scaffold_fn=lambda: None) + + features, labels = make_dummy_features_labels() + with self.assertRaisesRegexp( + ValueError, + 'TPUEstimatorSpec.scaffold_fn returns None, which is not allowed'): + model_fn(features=features, labels=labels, mode=_TRAIN, params={}) + + @parameterized.named_parameters( + 
('train_mode', _TRAIN), + ('eval_mode', _EVAL), + # TODO(ycao): Add predict_mode test after PREDICT mode is implemented. + ) + def test_scaffold_fn_in_mode(self, mode): + + @xla.estimator_model_fn + def model_fn(features, labels, mode, params): + del features, labels, params + return tpu_estimator.TPUEstimatorSpec( + mode=mode, + loss=constant_op.constant(_EXPECTED_LOSS), + train_op=control_flow_ops.no_op(), + scaffold_fn=self._make_scaffold_fn(mode)) + + features, labels = make_dummy_features_labels() + + self.is_scaffold_fn_called = {} + model_fn(features=features, labels=labels, mode=mode, params={}) + self.assertTrue(self.is_scaffold_fn_called[mode]) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/constrained_optimization/README.md b/tensorflow/contrib/constrained_optimization/README.md index cb1dd7d836ae11700b2ffaaff4fda5b7f943f87d..7ffb6894d37444fd78015b6c124c46f2855c1cde 100644 --- a/tensorflow/contrib/constrained_optimization/README.md +++ b/tensorflow/contrib/constrained_optimization/README.md @@ -1,5 +1,10 @@ +**NOTE: As tensorflow.contrib is being +[deprecated](https://github.com/tensorflow/community/pull/18), TFCO is moving to +its own repository on +[github](https://github.com/google-research/tensorflow_constrained_optimization).** + # ConstrainedOptimization (TFCO) TFCO is a library for optimizing inequality-constrained problems in TensorFlow. 
diff --git a/tensorflow/contrib/constrained_optimization/python/candidates_test.py b/tensorflow/contrib/constrained_optimization/python/candidates_test.py index a4c49d48bc5c763489215261a909573af0f19055..280e9acd88638a9385bfd9128ba6d3739879aab2 100644 --- a/tensorflow/contrib/constrained_optimization/python/candidates_test.py +++ b/tensorflow/contrib/constrained_optimization/python/candidates_test.py @@ -52,12 +52,12 @@ class CandidatesTest(test.TestCase): distribution = candidates.find_best_candidate_distribution( objective_vector, constraints_matrix) # Verify that the solution is a probability distribution. - self.assertTrue(np.all(distribution >= 0)) + self.assertTrue(np.all(distribution >= -1e-6)) self.assertAlmostEqual(np.sum(distribution), 1.0) # Verify that the solution satisfies the constraints. maximum_constraint_violation = np.amax( np.dot(constraints_matrix, distribution)) - self.assertLessEqual(maximum_constraint_violation, 0) + self.assertLessEqual(maximum_constraint_violation, 1e-6) # Verify that the solution matches that which we expect. expected_distribution = np.array([0.37872711, 0.62127289, 0, 0]) self.assertAllClose(expected_distribution, distribution, rtol=0, atol=1e-6) diff --git a/tensorflow/contrib/crf/python/ops/crf.py b/tensorflow/contrib/crf/python/ops/crf.py index 656633f0bf21a4d46cb85547241ef0fd42807ed6..40e159b8fcbd1864284e208cb15d9ed96119f840 100644 --- a/tensorflow/contrib/crf/python/ops/crf.py +++ b/tensorflow/contrib/crf/python/ops/crf.py @@ -38,12 +38,12 @@ tf_unary_scores, tf_sequence_lengths, tf_transition_params, _ = session.run( [unary_scores, sequence_lengths, transition_params, train_op]) for tf_unary_scores_, tf_sequence_length_ in zip(tf_unary_scores, tf_sequence_lengths): -# Remove padding. -tf_unary_scores_ = tf_unary_scores_[:tf_sequence_length_] + # Remove padding. + tf_unary_scores_ = tf_unary_scores_[:tf_sequence_length_] -# Compute the highest score and its tag sequence. 
-tf_viterbi_sequence, tf_viterbi_score = tf.contrib.crf.viterbi_decode( - tf_unary_scores_, tf_transition_params) + # Compute the highest score and its tag sequence. + tf_viterbi_sequence, tf_viterbi_score = tf.contrib.crf.viterbi_decode( + tf_unary_scores_, tf_transition_params) """ from __future__ import absolute_import diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py index a268415f0e65206294431a537be18cadbe1a1e84..f5219eb134d07c09b16a544f71d4c18986c19681 100644 --- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py @@ -68,6 +68,7 @@ def RunLSTM(sess, batch_size, time, num_layers=1, + variable_seq_lengths=False, is_training=True, dropout=0., num_dirs=True, @@ -99,6 +100,13 @@ def RunLSTM(sess, num_units).astype(dtype.as_numpy_dtype), dtype=dtype) + if variable_seq_lengths: + lengths_v = np.random.randint(low=1, high=time + 1, size=batch_size) + lengths_v[0] = time # make sure the max sequence has 'time' elems + lengths = ops.convert_to_tensor(lengths_v.astype(np.int32)) + else: + lengths = None + initializer = init_ops.random_uniform_initializer( -0.01, 0.01, dtype=dtype, seed=19980904) @@ -115,6 +123,7 @@ def RunLSTM(sess, outputs_op, state_tuple_op = rnn.dynamic_rnn( cell, inputs, + sequence_length=lengths, initial_state=rnn_cell_impl.LSTMStateTuple( h=initial_h_op, c=initial_c_op), dtype=dtype, @@ -133,6 +142,7 @@ def RunLSTM(sess, cu_initial_h_op, cu_initial_c_op, opaque_params, + sequence_lengths=lengths, dropout=dropout, is_training=is_training, rnn_mode=cudnn_rnn_ops.CUDNN_LSTM) @@ -325,12 +335,19 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): time, num_layers, dtype, - rtol=2e-6, - atol=2e-6): + variable_seq_lengths, + rtol=3e-6, + atol=3e-6): with self.session(use_gpu=True) as sess: (outputs, cu_outputs, state_tuple, 
cu_state_tuple, inp_grad, cu_inp_grad, state_grad, cu_state_grad, wgrad, bgrad, cu_wgrad, cu_bgrad) = RunLSTM( - sess, num_units, input_size, batch_size, time, num_layers) + sess, + num_units, + input_size, + batch_size, + time, + num_layers, + variable_seq_lengths=variable_seq_lengths) self.assertAllClose(outputs, cu_outputs, rtol=rtol, atol=atol) for s, cu_s in zip(state_tuple, cu_state_tuple): @@ -341,20 +358,33 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): self.assertAllClose(bgrad, cu_bgrad, rtol=rtol, atol=atol) self.assertAllClose(wgrad, cu_wgrad, rtol=rtol, atol=atol) - @parameterized.named_parameters(*NAMED_RNN_TESTCASES) + @parameterized.named_parameters( + ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ + "variable_seq_lengths": [True, False], + })) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") - def test_training(self, num_units, input_size, batch_size, time, num_layers): + def test_training(self, num_units, input_size, batch_size, time, num_layers, + variable_seq_lengths): if not context.context().num_gpus(): self.skipTest("No GPUs found") - self._test_training_helper(num_units, input_size, batch_size, time, - num_layers, dtypes.float32) + self._test_training_helper( + num_units, + input_size, + batch_size, + time, + num_layers, + dtypes.float32, + variable_seq_lengths=variable_seq_lengths) - @parameterized.named_parameters(*NAMED_RNN_TESTCASES) + @parameterized.named_parameters( + ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ + "variable_seq_lengths": [True, False], + })) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") def test_training_fp16(self, num_units, input_size, batch_size, time, - num_layers): + num_layers, variable_seq_lengths): if not context.context().num_gpus(): self.skipTest("No GPUs found") self._test_training_helper( @@ -365,12 +395,17 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): num_layers, 
dtypes.float16, rtol=5e-3, - atol=5e-4) + atol=5e-4, + variable_seq_lengths=variable_seq_lengths) - @parameterized.named_parameters(*NAMED_RNN_TESTCASES) + @parameterized.named_parameters( + ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ + "variable_seq_lengths": [True, False], + })) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") - def test_inference(self, num_units, input_size, batch_size, time, num_layers): + def test_inference(self, num_units, input_size, batch_size, time, num_layers, + variable_seq_lengths): if not context.context().num_gpus(): self.skipTest("No GPUs found") with self.session(use_gpu=True) as sess: @@ -381,7 +416,8 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): batch_size, time, num_layers, - is_training=False) + is_training=False, + variable_seq_lengths=variable_seq_lengths) self.assertAllClose(outputs, cu_outputs) # h @@ -389,11 +425,14 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): # c self.assertAllClose(state_tuple.c, cu_state_tuple.c) - @parameterized.named_parameters(*NAMED_RNN_TESTCASES) + @parameterized.named_parameters( + ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ + "variable_seq_lengths": [True, False], + })) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") def test_inference_fp16(self, num_units, input_size, batch_size, time, - num_layers): + num_layers, variable_seq_lengths): if not context.context().num_gpus(): self.skipTest("No GPUs found") with self.session(use_gpu=True) as sess: @@ -405,7 +444,8 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): time, num_layers, is_training=False, - dtype=dtypes.float16) + dtype=dtypes.float16, + variable_seq_lengths=variable_seq_lengths) rtol, atol = 5e-3, 5e-4 self.assertAllClose(outputs, cu_outputs, rtol=rtol, atol=atol) @@ -416,11 +456,14 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): self.assertAllClose( state_tuple.c, 
cu_state_tuple.c, rtol=rtol, atol=atol) - @parameterized.named_parameters(*NAMED_RNN_TESTCASES) + @parameterized.named_parameters( + ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ + "variable_seq_lengths": [True, False], + })) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") def test_inference_with_dropout(self, num_units, input_size, batch_size, time, - num_layers): + num_layers, variable_seq_lengths): """Validates that dropout does not affect Cudnn Rnn inference.""" if not context.context().num_gpus(): self.skipTest("No GPUs found") @@ -436,7 +479,8 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): time, num_layers, is_training=False, - dropout=0.) + dropout=0., + variable_seq_lengths=variable_seq_lengths) with ops.Graph().as_default() as g: with self.session(use_gpu=True, graph=g) as sess: @@ -448,7 +492,8 @@ class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): time, num_layers, is_training=False, - dropout=1.) + dropout=1., + variable_seq_lengths=variable_seq_lengths) self.assertAllClose(cu_outputs, cu_outputs2) # h @@ -464,6 +509,7 @@ def RunGRU(sess, time, num_layers=1, is_training=True, + variable_seq_lengths=False, dropout=0., num_dirs=True, dtype=dtypes.float32): @@ -489,6 +535,13 @@ def RunGRU(sess, num_units).astype(dtype.as_numpy_dtype), dtype=dtype) + if variable_seq_lengths: + lengths_v = np.random.randint(low=1, high=time + 1, size=batch_size) + lengths_v[0] = time # make sure the max sequence has 'time' elems + lengths = ops.convert_to_tensor(lengths_v.astype(np.int32)) + else: + lengths = None + initializer = init_ops.random_uniform_initializer( -0.01, 0.01, dtype=dtype, seed=19980904) with variable_scope.variable_scope("test", initializer=initializer): @@ -521,6 +574,7 @@ def RunGRU(sess, outputs_op, h_op = rnn.dynamic_rnn( cell, inputs, + sequence_length=lengths, initial_state=initial_h_op, dtype=dtype, time_major=True, @@ -533,12 +587,14 @@ def RunGRU(sess, num_layers, 
num_units, input_size) opaque_params = format_converter.tf_canonical_to_opaque(ws + bs) + cu_initial_h_op = array_ops.expand_dims(initial_h_op, axis=0) cu_outputs_op, cu_h_op, _ = cudnn_rnn_ops._cudnn_rnn( inputs, cu_initial_h_op, array_ops.zeros_like(cu_initial_h_op), # not used opaque_params, + sequence_lengths=lengths, dropout=dropout, is_training=is_training, rnn_mode=cudnn_rnn_ops.CUDNN_GRU) @@ -615,12 +671,19 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): time, num_layers, dtype, - rtol=2e-6, - atol=2e-6): + variable_seq_lengths, + rtol=3e-6, + atol=3e-6): with self.session(use_gpu=True) as sess: - (outputs, cu_outputs, h, cu_h, inp_grad, cu_inp_grad, hgrad, - cu_hgrad, wgrad, bgrad, cu_wgrad, cu_bgrad) = RunGRU( - sess, num_units, input_size, batch_size, time, num_layers) + (outputs, cu_outputs, h, cu_h, inp_grad, cu_inp_grad, hgrad, cu_hgrad, + wgrad, bgrad, cu_wgrad, cu_bgrad) = RunGRU( + sess, + num_units, + input_size, + batch_size, + time, + num_layers, + variable_seq_lengths=variable_seq_lengths) self.assertAllClose(outputs, cu_outputs, rtol=rtol, atol=atol) self.assertAllClose(h, cu_h, rtol=rtol, atol=atol) @@ -631,20 +694,33 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): for wg, cu_wg in zip(wgrad, cu_wgrad): self.assertAllClose(wg, cu_wg, rtol=rtol, atol=atol) - @parameterized.named_parameters(*NAMED_RNN_TESTCASES) + @parameterized.named_parameters( + ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ + "variable_seq_lengths": [True, False], + })) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") - def test_training(self, num_units, input_size, batch_size, time, num_layers): + def test_training(self, num_units, input_size, batch_size, time, num_layers, + variable_seq_lengths): if not context.context().num_gpus(): self.skipTest("No GPUs found") - self._test_training_helper(num_units, input_size, batch_size, time, - num_layers, dtypes.float32) + self._test_training_helper( + 
num_units, + input_size, + batch_size, + time, + num_layers, + dtypes.float32, + variable_seq_lengths=variable_seq_lengths) - @parameterized.named_parameters(*NAMED_RNN_TESTCASES) + @parameterized.named_parameters( + ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ + "variable_seq_lengths": [True, False], + })) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") def test_training_fp16(self, num_units, input_size, batch_size, time, - num_layers): + num_layers, variable_seq_lengths): if not context.context().num_gpus(): self.skipTest("No GPUs found") self._test_training_helper( @@ -655,12 +731,17 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): num_layers, dtypes.float16, rtol=5e-3, - atol=5e-4) + atol=5e-4, + variable_seq_lengths=variable_seq_lengths) - @parameterized.named_parameters(*NAMED_RNN_TESTCASES) + @parameterized.named_parameters( + ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ + "variable_seq_lengths": [True, False], + })) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") - def test_inference(self, num_units, input_size, batch_size, time, num_layers): + def test_inference(self, num_units, input_size, batch_size, time, num_layers, + variable_seq_lengths): if not context.context().num_gpus(): self.skipTest("No GPUs found") with self.session(use_gpu=True) as sess: @@ -671,15 +752,19 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): batch_size, time, num_layers, - is_training=False) + is_training=False, + variable_seq_lengths=variable_seq_lengths) self.assertAllClose(outputs, cu_outputs) self.assertAllClose(h, cu_h) - @parameterized.named_parameters(*NAMED_RNN_TESTCASES) + @parameterized.named_parameters( + ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ + "variable_seq_lengths": [True, False], + })) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") def test_inference_fp16(self, num_units, input_size, 
batch_size, time, - num_layers): + num_layers, variable_seq_lengths): if not context.context().num_gpus(): self.skipTest("No GPUs found") with self.session(use_gpu=True) as sess: @@ -691,17 +776,21 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): time, num_layers, is_training=False, - dtype=dtypes.float16) + dtype=dtypes.float16, + variable_seq_lengths=variable_seq_lengths) rtol, atol = 5e-3, 5e-4 self.assertAllClose(outputs, cu_outputs, rtol=rtol, atol=atol) self.assertAllClose(h, cu_h, rtol=rtol, atol=atol) - @parameterized.named_parameters(*NAMED_RNN_TESTCASES) + @parameterized.named_parameters( + ExpandNamedTestCases(NAMED_RNN_TESTCASES, **{ + "variable_seq_lengths": [True, False], + })) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") def test_inference_with_dropout(self, num_units, input_size, batch_size, time, - num_layers): + num_layers, variable_seq_lengths): """Validates that dropout does not affect Cudnn Rnn inference.""" # Hand-picked dropouts are used below (0. and 1.) if not context.context().num_gpus(): @@ -717,7 +806,8 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): time, num_layers, is_training=False, - dropout=0.) + dropout=0., + variable_seq_lengths=variable_seq_lengths) with ops.Graph().as_default() as g: with self.session(use_gpu=True, graph=g) as sess: @@ -729,7 +819,8 @@ class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): time, num_layers, is_training=False, - dropout=1.) 
+ dropout=1., + variable_seq_lengths=variable_seq_lengths) self.assertAllClose(cu_outputs, cu_outputs2) self.assertAllClose(cu_h[0], cu_h2[0]) diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py index 7e1b4062ce435f3ab4216e90b4f5fcbab984c1dc..ca92c31236a7a3882415834eb32a994a120b6d2d 100644 --- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py @@ -1023,7 +1023,7 @@ class CudnnRNNTestCompatibleRNNCells(test_util.TensorFlowTestCase): outputs_v, output_state_v = sess.run( [outputs, output_state], feed_dict={cell_inputs: inference_input}) - self.assertAllClose(cudnn_outputs_v, outputs_v, atol=2e-5, rtol=2e-5) + self.assertAllClose(cudnn_outputs_v, outputs_v, atol=1e-4, rtol=2e-4) (cudnn_output_h_v,) = cudnn_output_states_v self.assertAllClose(cudnn_output_h_v, output_state_v, atol=2e-5, rtol=2e-5) diff --git a/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py index 8e25637ed91a1559b321ea96efbfaa2910f67158..86ad8ae8073714657c78badb1e0b4a6d8c8ed5f0 100644 --- a/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py +++ b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py @@ -374,7 +374,11 @@ class _CudnnRNN(base_layer.Layer): "This cell does not yet support object-based saving. File a feature " "request if this limitation bothers you.") - def call(self, inputs, initial_state=None, training=True): + def call(self, + inputs, + initial_state=None, + sequence_lengths=None, + training=True): """Runs the forward step for the RNN model. Args: @@ -382,6 +386,9 @@ class _CudnnRNN(base_layer.Layer): initial_state: a tuple of tensor(s) of shape `[num_layers * num_dirs, batch_size, num_units]`. If not provided, use zero initial states. The tuple size is 2 for LSTM and 1 for other RNNs. 
+ sequence_lengths: an int32 array representing the variable sequence + lengths in a batch. The size of the array has to equal the + batch_size. If not provided, the same sequence length will be assumed. training: whether this operation will be used in training or inference. Returns: output: a tensor of shape `[time_len, batch_size, num_dirs * num_units]`. @@ -411,7 +418,7 @@ class _CudnnRNN(base_layer.Layer): # For model that doesn't take input_c, replace with a dummy tensor. c = array_ops.constant([], dtype=dtype) outputs, (output_h, output_c) = self._forward(inputs, h, c, self.kernel, - training) + sequence_lengths, training) if self._rnn_mode == CUDNN_LSTM: return outputs, (output_h, output_c) else: @@ -475,7 +482,7 @@ class _CudnnRNN(base_layer.Layer): dropout=self._dropout, direction=self._direction) - def _forward(self, inputs, h, c, opaque_params, training): + def _forward(self, inputs, h, c, opaque_params, sequence_lengths, training): output, output_h, output_c = cudnn_rnn_ops._cudnn_rnn( # pylint:disable=protected-access inputs, h, @@ -483,6 +490,7 @@ class _CudnnRNN(base_layer.Layer): opaque_params, training, self._rnn_mode, + sequence_lengths=sequence_lengths, input_mode=self._input_mode, direction=self._direction, dropout=self._dropout, diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py index 1ce29b42d52ff67477161278ed11016c2e73041d..f36e8d5022bc7e3f8268a161089153e5510dffc6 100644 --- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py +++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py @@ -837,7 +837,7 @@ class CudnnLSTMSaveable(CudnnOpaqueParamsSaveable): checkpointable._track_checkpointable(bias, name="bias") # pylint: disable=protected-access assert len(biases) == len(weights) for cell_index, (bias, kernel) in enumerate(zip(biases, weights)): - cell = checkpointable_lib.Checkpointable() + cell = checkpointable_lib.AutoCheckpointable() 
checkpointable._track_checkpointable(cell, name="cell-%d" % cell_index) # pylint: disable=protected-access cell.bias = bias cell.kernel = kernel @@ -955,6 +955,7 @@ def _cudnn_rnn(inputs, params, is_training, rnn_mode, + sequence_lengths=None, input_mode=CUDNN_INPUT_LINEAR_MODE, direction=CUDNN_RNN_UNIDIRECTION, dropout=0., @@ -972,6 +973,10 @@ def _cudnn_rnn(inputs, params: the parameter buffer created for this model. is_training: whether this operation will be used in training or inference rnn_mode: one of ('lstm', 'gru', 'rnn_relu', 'rnn_tanh'). + sequence_lengths: an int32 array representing the variable sequence lengths + in a batch. The size of the array has to equal the batch_size. Default to + None, in which case sequences in the batch are assumed to have the same + length, which is inferred from inputs. input_mode: indicate whether there is a linear projection between the input and the actual computation before the first layer. It could be 'linear_input', 'skip_input' or 'auto_select'. @@ -1010,7 +1015,10 @@ def _cudnn_rnn(inputs, "seed2": seed2, "name": name } - if use_cudnn_v2 != "1": + if sequence_lengths is not None: + args["sequence_lengths"] = sequence_lengths + outputs, output_h, output_c, _, _ = gen_cudnn_rnn_ops.cudnn_rnnv3(**args) + elif use_cudnn_v2 != "1": outputs, output_h, output_c, _ = gen_cudnn_rnn_ops.cudnn_rnn(**args) else: outputs, output_h, output_c, _, _ = gen_cudnn_rnn_ops.cudnn_rnnv2(**args) @@ -1022,6 +1030,7 @@ def cudnn_lstm(inputs, input_c, params, is_training, + sequence_lengths=None, input_mode=CUDNN_INPUT_LINEAR_MODE, direction=CUDNN_RNN_UNIDIRECTION, dropout=0., @@ -1051,12 +1060,17 @@ def cudnn_lstm(inputs, dropout: whether to enable dropout. With it is 0, dropout is disabled. seed: the op seed used for initializing dropout. See `tf.set_random_seed` for behavior. + sequence_lengths: an int32 array representing the variable sequence lengths + in a batch. The size of the array has to equal the batch_size. 
Default to + None, in which case sequences in the batch are assumed to have the same + length, which is inferred from inputs. name: name of the operation. Returns: outputs, output_h, output_c """ return _cudnn_rnn(inputs, input_h, input_c, params, is_training, CUDNN_LSTM, - input_mode, direction, dropout, seed, name) + sequence_lengths, input_mode, direction, dropout, seed, + name) def _cudnn_rnn_no_input_c(inputs, @@ -1064,6 +1078,7 @@ def _cudnn_rnn_no_input_c(inputs, params, is_training, rnn_mode, + sequence_lengths=None, input_mode=CUDNN_INPUT_LINEAR_MODE, direction=CUDNN_RNN_UNIDIRECTION, dropout=0., @@ -1079,6 +1094,10 @@ def _cudnn_rnn_no_input_c(inputs, params: the parameter buffer created for this model. is_training: whether this operation will be used in training or inference rnn_mode: one of ('lstm', 'gru', 'rnn_relu', 'rnn_tanh'). + sequence_lengths: an int32 array representing the variable sequence lengths + in a batch. The size of the array has to equal the batch_size. Default to + None, in which case sequences in the batch are assumed to have the same + length, which is inferred from inputs. input_mode: indicate whether there is a linear projection between the input and the actual computation before the first layer. It could be 'linear_input', 'skip_input' or 'auto_select'. 
@@ -1098,8 +1117,8 @@ def _cudnn_rnn_no_input_c(inputs, """ input_c = array_ops.constant([], dtype=input_h.dtype) outputs, output_h, _ = _cudnn_rnn(inputs, input_h, input_c, params, - is_training, rnn_mode, input_mode, - direction, dropout, seed, name) + is_training, rnn_mode, sequence_lengths, + input_mode, direction, dropout, seed, name) return outputs, output_h @@ -1107,6 +1126,7 @@ def cudnn_gru(inputs, input_h, params, is_training, + sequence_lengths=None, input_mode=CUDNN_INPUT_LINEAR_MODE, direction=CUDNN_RNN_UNIDIRECTION, dropout=0., @@ -1129,6 +1149,10 @@ def cudnn_gru(inputs, 'skip_input' is only allowed when input_size == num_units; 'auto_select' implies 'skip_input' when input_size == num_units; otherwise, it implies 'linear_input'. + sequence_lengths: an int32 array representing the variable sequence lengths + in a batch. The size of the array has to equal the batch_size. Default to + None, in which case sequences in the batch are assumed to have the same + length, which is inferred from inputs. direction: the direction model that the model operates. Could be either 'unidirectional' or 'bidirectional' dropout: whether to enable dropout. With it is 0, dropout is disabled. @@ -1139,7 +1163,8 @@ def cudnn_gru(inputs, outputs, output_h """ return _cudnn_rnn_no_input_c(inputs, input_h, params, is_training, CUDNN_GRU, - input_mode, direction, dropout, seed, name) + sequence_lengths, input_mode, direction, dropout, + seed, name) def cudnn_rnn_relu(inputs, @@ -1150,6 +1175,7 @@ def cudnn_rnn_relu(inputs, direction=CUDNN_RNN_UNIDIRECTION, dropout=0., seed=0, + sequence_lengths=None, name=None): """Cudnn RNN Relu. @@ -1162,30 +1188,34 @@ def cudnn_rnn_relu(inputs, is_training: whether this operation will be used in training or inference input_mode: indicate whether there is a linear projection between the input and the actual computation before the first layer. It could be - 'linear_input', 'skip_input' or 'auto_select'. 
- 'linear_input' (default) always applies a linear projection of input - onto RNN hidden state. (standard RNN behavior). - 'skip_input' is only allowed when input_size == num_units; - 'auto_select' implies 'skip_input' when input_size == num_units; - otherwise, it implies 'linear_input'. + 'linear_input', 'skip_input' or 'auto_select'. 'linear_input' (default) + always applies a linear projection of input onto RNN hidden state. + (standard RNN behavior). 'skip_input' is only allowed when input_size == + num_units; 'auto_select' implies 'skip_input' when input_size == + num_units; otherwise, it implies 'linear_input'. direction: the direction model that the model operates. Could be either - 'unidirectional' or 'bidirectional' + 'unidirectional' or 'bidirectional' dropout: whether to enable dropout. With it is 0, dropout is disabled. seed: the op seed used for initializing dropout. See `tf.set_random_seed` - for behavior. + for behavior. + sequence_lengths: an int32 array representing the variable sequence lengths + in a batch. The size of the array has to equal the batch_size. If not + provided, the same sequence length will be assumed. name: name of the operation. + Returns: outputs, output_h """ return _cudnn_rnn_no_input_c(inputs, input_h, params, is_training, - CUDNN_RNN_RELU, input_mode, direction, dropout, - seed, name) + CUDNN_RNN_RELU, sequence_lengths, input_mode, + direction, dropout, seed, name) def cudnn_rnn_tanh(inputs, input_h, params, is_training, + sequence_lengths=None, input_mode=CUDNN_INPUT_LINEAR_MODE, direction=CUDNN_RNN_UNIDIRECTION, dropout=0., @@ -1208,6 +1238,10 @@ def cudnn_rnn_tanh(inputs, 'skip_input' is only allowed when input_size == num_units; 'auto_select' implies 'skip_input' when input_size == num_units; otherwise, it implies 'linear_input'. + sequence_lengths: an int32 array representing the variable sequence lengths + in a batch. The size of the array has to equal the batch_size. 
Default to + None, in which case sequences in the batch are assumed to have the same + length, which is inferred from inputs. direction: the direction model that the model operates. Could be either 'unidirectional' or 'bidirectional' dropout: whether to enable dropout. With it is 0, dropout is disabled. @@ -1218,8 +1252,8 @@ def cudnn_rnn_tanh(inputs, outputs, output_h """ return _cudnn_rnn_no_input_c(inputs, input_h, params, is_training, - CUDNN_RNN_TANH, input_mode, direction, dropout, - seed, name) + CUDNN_RNN_TANH, sequence_lengths, input_mode, + direction, dropout, seed, name) def cudnn_rnn_opaque_params_to_canonical(rnn_mode, @@ -1497,7 +1531,13 @@ class _CudnnRNN(object): input_mode=self._input_mode, direction=self._direction) - def __call__(self, input_data, input_h, input_c, params, is_training=True): + def __call__(self, + input_data, + input_h, + input_c, + params, + is_training=True, + sequence_lengths=None): """Runs the forward step for the RNN model. Args: @@ -1509,6 +1549,10 @@ class _CudnnRNN(object): A Tensor of the same shape as input_h. params: the parameter buffer created for this model. is_training: whether this operation will be used in training or inference. + sequence_lengths: an int32 array representing the variable sequence + lengths in a batch. The size of the array has to equal the batch_size. + Default to None, in which case sequences in the batch are assumed to + have the same length, which is inferred from inputs. Returns: output: the output sequence. output_h: the final state for h. 
@@ -1521,6 +1565,7 @@ class _CudnnRNN(object): params, is_training, self._rnn_mode, + sequence_lengths=sequence_lengths, input_mode=self._input_mode, direction=self._direction, dropout=self._dropout, @@ -1615,7 +1660,13 @@ class CudnnLSTM(_CudnnRNN): dropout=dropout, seed=seed) - def __call__(self, input_data, input_h, input_c, params, is_training=True): + def __call__(self, + input_data, + input_h, + input_c, + params, + sequence_lengths=None, + is_training=True): """Runs the forward step for the Cudnn LSTM model. Args: @@ -1626,6 +1677,10 @@ class CudnnLSTM(_CudnnRNN): input_c: the initial hidden state for c. A Tensor of the same shape as input_h. params: the parameter buffer created for this model. + sequence_lengths: an int32 array representing the variable sequence + lengths in a batch. The size of the array has to equal the batch_size. + Default to None, in which case sequences in the batch are assumed to + have the same length, which is inferred from inputs. is_training: whether this operation will be used in training or inference. Returns: output: the output sequence. @@ -1633,7 +1688,12 @@ class CudnnLSTM(_CudnnRNN): output_c: the final state for c. """ output, output_h, output_c = super(CudnnLSTM, self).__call__( - input_data, input_h, input_c, params, is_training=is_training) + input_data, + input_h, + input_c, + params, + sequence_lengths=sequence_lengths, + is_training=is_training) return (output, output_h, output_c) @@ -1687,7 +1747,12 @@ class _CudnnRNNNoInputC(_CudnnRNN): dropout=dropout, seed=seed) - def __call__(self, input_data, input_h, params, is_training=True): + def __call__(self, + input_data, + input_h, + params, + sequence_lengths=None, + is_training=True): """Runs the forward step for the Cudnn LSTM model. Args: @@ -1696,6 +1761,10 @@ class _CudnnRNNNoInputC(_CudnnRNN): input_h: the initial hidden state for h. A Tensor of shape [num_layers, batch_size, num_units]. params: the parameter buffer created for this model. 
+ sequence_lengths: an int32 array representing the variable sequence + lengths in a batch. The size of the array has to equal the batch_size. + Default to None, in which case sequences in the batch are assumed to + have the same length, which is inferred from inputs. is_training: whether this operation will be used in training or inference. Returns: output: the output sequence. @@ -1707,6 +1776,7 @@ class _CudnnRNNNoInputC(_CudnnRNN): params, is_training, self._rnn_mode, + sequence_lengths=sequence_lengths, input_mode=self._input_mode, direction=self._direction, dropout=self._dropout, diff --git a/tensorflow/contrib/data/python/kernel_tests/assert_element_shape_test.py b/tensorflow/contrib/data/python/kernel_tests/assert_element_shape_test.py index 0456463a1928cf226010670b90a5d574579e0411..6c5f8c6b00975b3fba041271309a93cecd9f5057 100644 --- a/tensorflow/contrib/data/python/kernel_tests/assert_element_shape_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/assert_element_shape_test.py @@ -46,7 +46,7 @@ class AssertElementShapeTest(test_base.DatasetTestBase): result = dataset.apply(batching.assert_element_shape(expected_shapes)) self.assertEqual(expected_shapes, result.output_shapes) - iterator = result.make_initializable_iterator() + iterator = dataset_ops.make_initializable_iterator(result) init_op = iterator.initializer get_next = iterator.get_next() with self.cached_session() as sess: @@ -88,7 +88,7 @@ class AssertElementShapeTest(test_base.DatasetTestBase): result = dataset.apply(batching.assert_element_shape(expected_shapes)) self.assertEqual(expected_shapes, result.output_shapes) - iterator = result.make_initializable_iterator() + iterator = dataset_ops.make_initializable_iterator(result) init_op = iterator.initializer get_next = iterator.get_next() with self.cached_session() as sess: @@ -115,9 +115,8 @@ class AssertElementShapeTest(test_base.DatasetTestBase): wrong_shapes = (tensor_shape.TensorShape(2), tensor_shape.TensorShape((3, 10))) - iterator = 
( - dataset.apply(batching.assert_element_shape(wrong_shapes)) - .make_initializable_iterator()) + iterator = dataset_ops.make_initializable_iterator( + dataset.apply(batching.assert_element_shape(wrong_shapes))) init_op = iterator.initializer get_next = iterator.get_next() with self.cached_session() as sess: @@ -142,7 +141,7 @@ class AssertElementShapeTest(test_base.DatasetTestBase): tensor_shape.TensorShape((3, 4))) self.assertEqual(actual_shapes, result.output_shapes) - iterator = result.make_initializable_iterator() + iterator = dataset_ops.make_initializable_iterator(result) init_op = iterator.initializer get_next = iterator.get_next() with self.cached_session() as sess: @@ -184,7 +183,7 @@ class AssertElementShapeTest(test_base.DatasetTestBase): result = dataset.apply(batching.assert_element_shape(expected_shapes)) self.assertEqual(expected_shapes, result.output_shapes) - iterator = result.make_initializable_iterator() + iterator = dataset_ops.make_initializable_iterator(result) init_op = iterator.initializer get_next = iterator.get_next() with self.cached_session() as sess: @@ -211,9 +210,8 @@ class AssertElementShapeTest(test_base.DatasetTestBase): wrong_shapes = (tensor_shape.TensorShape(2), tensor_shape.TensorShape((None, 10))) - iterator = ( - dataset.apply(batching.assert_element_shape(wrong_shapes)) - .make_initializable_iterator()) + iterator = dataset_ops.make_initializable_iterator( + dataset.apply(batching.assert_element_shape(wrong_shapes))) init_op = iterator.initializer get_next = iterator.get_next() with self.cached_session() as sess: diff --git a/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py index d2a72272db159755ac2d741bcdbce9ec646d928e..b9840b1ff1a3df5a05db0e64f436637220f49f80 100644 --- a/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py @@ -23,6 +23,7 @@ 
import shutil from tensorflow.contrib.data.python.ops import readers from tensorflow.python.data.kernel_tests import test_base +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -48,7 +49,7 @@ class LMDBDatasetTest(test_base.DatasetTestBase): num_repeats = 2 dataset = readers.LMDBDataset(filenames).repeat(num_repeats) - iterator = dataset.make_initializable_iterator() + iterator = dataset_ops.make_initializable_iterator(dataset) init_op = iterator.initializer get_next = iterator.get_next() diff --git a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py index c5a786232252432481566e3cde23e9310df172cc..2527706709fae8e459aca3489324d4db3c784be6 100644 --- a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py @@ -63,13 +63,13 @@ class SlideDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): # The pipeline is TensorSliceDataset -> MapDataset(square_3) -> # RepeatDataset(count) -> # _SlideDataset(window_size, window_shift, window_stride). - iterator = ( + iterator = dataset_ops.make_initializable_iterator( dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn) .repeat(count).apply( sliding.sliding_window_batch( window_size=window_size_t, window_shift=window_shift_t, - window_stride=window_stride_t)).make_initializable_iterator()) + window_stride=window_stride_t))) init_op = iterator.initializer get_next = iterator.get_next() @@ -127,13 +127,13 @@ class SlideDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): # The pipeline is TensorSliceDataset -> MapDataset(square_3) -> # RepeatDataset(count) -> _SlideDataset(window_size, stride, window_stride). 
- iterator = ( + iterator = dataset_ops.make_initializable_iterator( dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn) .repeat(count).apply( sliding.sliding_window_batch( window_size=window_size_t, stride=stride_t, - window_stride=window_stride_t)).make_initializable_iterator()) + window_stride=window_stride_t))) init_op = iterator.initializer get_next = iterator.get_next() @@ -173,12 +173,12 @@ class SlideDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): window_shift_t = array_ops.placeholder(dtypes.int64, shape=[]) window_stride_t = array_ops.placeholder(dtypes.int64, shape=[]) - iterator = ( + iterator = dataset_ops.make_initializable_iterator( dataset_ops.Dataset.range(10).map(lambda x: x).repeat(count_t).apply( sliding.sliding_window_batch( window_size=window_size_t, window_shift=window_shift_t, - window_stride=window_stride_t)).make_initializable_iterator()) + window_stride=window_stride_t))) init_op = iterator.initializer with self.cached_session() as sess: @@ -204,9 +204,9 @@ class SlideDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): return sparse_tensor.SparseTensorValue( indices=[[0]], values=(i * [1]), dense_shape=[1]) - iterator = dataset_ops.Dataset.range(10).map(_sparse).apply( - sliding.sliding_window_batch( - window_size=5, window_shift=3)).make_initializable_iterator() + iterator = dataset_ops.make_initializable_iterator( + dataset_ops.Dataset.range(10).map(_sparse).apply( + sliding.sliding_window_batch(window_size=5, window_shift=3))) init_op = iterator.initializer get_next = iterator.get_next() @@ -233,9 +233,9 @@ class SlideDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): values=array_ops.fill([math_ops.to_int32(i)], i), dense_shape=[i]) - iterator = dataset_ops.Dataset.range(10).map(_sparse).apply( - sliding.sliding_window_batch( - window_size=5, window_shift=3)).make_initializable_iterator() + iterator = dataset_ops.make_initializable_iterator( + 
dataset_ops.Dataset.range(10).map(_sparse).apply( + sliding.sliding_window_batch(window_size=5, window_shift=3))) init_op = iterator.initializer get_next = iterator.get_next() @@ -265,11 +265,10 @@ class SlideDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): return sparse_tensor.SparseTensorValue( indices=[[0]], values=(i * [1]), dense_shape=[1]) - iterator = ( + iterator = dataset_ops.make_initializable_iterator( dataset_ops.Dataset.range(10).map(_sparse).apply( sliding.sliding_window_batch(window_size=4, window_shift=2)).apply( - sliding.sliding_window_batch(window_size=3, window_shift=1)) - .make_initializable_iterator()) + sliding.sliding_window_batch(window_size=3, window_shift=1))) init_op = iterator.initializer get_next = iterator.get_next() @@ -305,11 +304,10 @@ class SlideDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): yield [4.0, 5.0, 6.0] yield [7.0, 8.0, 9.0, 10.0] - iterator = ( + iterator = dataset_ops.make_initializable_iterator( dataset_ops.Dataset.from_generator( generator, dtypes.float32, output_shapes=[None]).apply( - sliding.sliding_window_batch(window_size=3, window_shift=1)) - .make_initializable_iterator()) + sliding.sliding_window_batch(window_size=3, window_shift=1))) next_element = iterator.get_next() with self.cached_session() as sess: diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD index 34dc2379d0cb38f8f6962fa42efe21b793bc8d65..0fb406f1167053a128646c5c692986b0ce016f1e 100644 --- a/tensorflow/contrib/data/python/ops/BUILD +++ b/tensorflow/contrib/data/python/ops/BUILD @@ -188,8 +188,7 @@ py_library( "//tensorflow/python:framework_ops", "//tensorflow/python:function", "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/util:nest", - "//tensorflow/python/data/util:sparse", + "//tensorflow/python/data/util:structure", ], ) diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py index 
4601376dff47e161962e92678883039c4b88bab7..c6bf5215c9406d03d2704e46903b3aa57e7e68d9 100644 --- a/tensorflow/contrib/data/python/ops/readers.py +++ b/tensorflow/contrib/data/python/ops/readers.py @@ -21,10 +21,9 @@ from tensorflow.python.data.experimental.ops import optimization from tensorflow.python.data.experimental.ops import readers from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import readers as core_readers -from tensorflow.python.data.util import nest +from tensorflow.python.data.util import structure from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import gen_experimental_dataset_ops from tensorflow.python.util import deprecation @@ -355,7 +354,7 @@ def read_batch_features(file_pattern, shuffle=randomize_input, num_epochs=num_epochs, shuffle_buffer_size=capacity) - iterator = dataset.make_one_shot_iterator() + iterator = dataset_ops.make_one_shot_iterator(dataset) outputs = iterator.get_next() return outputs @@ -379,37 +378,25 @@ class LMDBDataset(dataset_ops.DatasetSource): (key value) pairs sequentially. For example: ```python + tf.enable_eager_execution() + dataset = tf.contrib.lmdb.LMDBDataset("/foo/bar.mdb") - iterator = dataset.make_one_shot_iterator() - next_element = iterator.get_next() + # Prints the (key, value) pairs inside a lmdb file. - while True: - try: - print(sess.run(next_element)) - except tf.errors.OutOfRangeError: - break + for key, value in dataset: + print(key, value) ``` Args: filenames: A `tf.string` tensor containing one or more filenames. 
""" - super(LMDBDataset, self).__init__() self._filenames = ops.convert_to_tensor( filenames, dtype=dtypes.string, name="filenames") - - def _as_variant_tensor(self): - return gen_experimental_dataset_ops.experimental_lmdb_dataset( - self._filenames, - output_types=nest.flatten(self.output_types), - output_shapes=nest.flatten(self.output_shapes)) - - @property - def output_classes(self): - return ops.Tensor, ops.Tensor - - @property - def output_shapes(self): - return (tensor_shape.TensorShape([]), tensor_shape.TensorShape([])) + variant_tensor = gen_experimental_dataset_ops.experimental_lmdb_dataset( + self._filenames, **dataset_ops.flat_structure(self)) + super(LMDBDataset, self).__init__(variant_tensor) @property - def output_types(self): - return dtypes.string, dtypes.string + def _element_structure(self): + return structure.NestedStructure( + (structure.TensorStructure(dtypes.string, []), + structure.TensorStructure(dtypes.string, []))) diff --git a/tensorflow/contrib/data/python/ops/sliding.py b/tensorflow/contrib/data/python/ops/sliding.py index bcc383587c54bd89502313f9328bc06c49046a87..6708e01d08135a132b797e317cd2a241c3428f40 100644 --- a/tensorflow/contrib/data/python/ops/sliding.py +++ b/tensorflow/contrib/data/python/ops/sliding.py @@ -18,11 +18,10 @@ from __future__ import division from __future__ import print_function from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import nest +from tensorflow.python.data.util import structure from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape -from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops from tensorflow.python.util import deprecation @@ -31,7 +30,6 @@ class _SlideDataset(dataset_ops.UnaryDataset): def __init__(self, input_dataset, window_size, window_shift, window_stride): """See `sliding_window_batch` for 
details.""" - super(_SlideDataset, self).__init__(input_dataset) self._input_dataset = input_dataset self._window_size = ops.convert_to_tensor( window_size, dtype=dtypes.int64, name="window_stride") @@ -40,29 +38,21 @@ class _SlideDataset(dataset_ops.UnaryDataset): self._window_shift = ops.convert_to_tensor( window_shift, dtype=dtypes.int64, name="window_shift") - def _as_variant_tensor(self): - return gen_dataset_ops.slide_dataset( - self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + input_structure = structure.convert_legacy_structure( + input_dataset.output_types, input_dataset.output_shapes, + input_dataset.output_classes) + self._structure = input_structure._batch(None) # pylint: disable=protected-access + variant_tensor = ged_ops.experimental_sliding_window_dataset( + self._input_dataset._variant_tensor, # pylint: disable=protected-access window_size=self._window_size, window_shift=self._window_shift, window_stride=self._window_stride, **dataset_ops.flat_structure(self)) + super(_SlideDataset, self).__init__(input_dataset, variant_tensor) @property - def output_classes(self): - return self._input_dataset.output_classes - - @property - def output_shapes(self): - input_shapes = self._input_dataset.output_shapes - return nest.pack_sequence_as(input_shapes, [ - tensor_shape.vector(None).concatenate(s) - for s in nest.flatten(self._input_dataset.output_shapes) - ]) - - @property - def output_types(self): - return self._input_dataset.output_types + def _element_structure(self): + return self._structure @deprecation.deprecated_args( diff --git a/tensorflow/contrib/distribute/README.md b/tensorflow/contrib/distribute/README.md index 8a8dc159ade6f2a4a9b5ec29055ea4848492b29f..dbcaf8185fb7a9d2bcf22376439c0ebd49accb1a 100644 --- a/tensorflow/contrib/distribute/README.md +++ b/tensorflow/contrib/distribute/README.md @@ -43,28 +43,19 @@ the workers. 
Let's see how to scale to multiple GPUs on one machine using `MirroredStrategy` with [tf.keras] (https://www.tensorflow.org/guide/keras). -Take a very simple model consisting of a single layer: +Let's define a simple input dataset for training this model. Note that currently we require using +[`tf.data.Dataset`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset) +with `DistributionStrategy`. ```python import tensorflow as tf from tensorflow import keras -inputs = tf.keras.layers.Input(shape=(1,)) -predictions = tf.keras.layers.Dense(1)(inputs) -model = tf.keras.models.Model(inputs=inputs, outputs=predictions) -``` - -Let's also define a simple input dataset for training this model. Note that currently we require using -[`tf.data.Dataset`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset) -with `DistributionStrategy`. - -```python features = tf.data.Dataset.from_tensors([1.]).repeat(10000).batch(10) labels = tf.data.Dataset.from_tensors([1.]).repeat(10000).batch(10) train_dataset = tf.data.Dataset.zip((features, labels)) ``` - To distribute this Keras model on multiple GPUs using `MirroredStrategy` we first instantiate a `MirroredStrategy` object. @@ -72,14 +63,17 @@ first instantiate a `MirroredStrategy` object. distribution = tf.contrib.distribute.MirroredStrategy() ``` -We then compile the Keras model and pass the `MirroredStrategy` object in the -`distribute` argument (apart from other usual arguments like `loss` and -`optimizer`). +Take a very simple model consisting of a single layer. We need to create and compile +the model under the distribution strategy scope. 
```python -model.compile(loss='mean_squared_error', - optimizer=tf.train.GradientDescentOptimizer(learning_rate=0.2), - distribute=distribution) +with distribution.scope(): + inputs = tf.keras.layers.Input(shape=(1,)) + predictions = tf.keras.layers.Dense(1)(inputs) + model = tf.keras.models.Model(inputs=inputs, outputs=predictions) + + model.compile(loss='mean_squared_error', + optimizer=tf.train.GradientDescentOptimizer(learning_rate=0.2)) ``` To train the model we call Keras `fit` API using the input dataset that we diff --git a/tensorflow/contrib/distribute/__init__.py b/tensorflow/contrib/distribute/__init__.py index 8ec73654e30e4967f318c558ba94301e84a206e4..59d76f5d1c817d7f2cc8ad285b9fb517fe994a81 100644 --- a/tensorflow/contrib/distribute/__init__.py +++ b/tensorflow/contrib/distribute/__init__.py @@ -30,12 +30,13 @@ from tensorflow.contrib.distribute.python.monitor import Monitor from tensorflow.contrib.distribute.python.one_device_strategy import OneDeviceStrategy from tensorflow.contrib.distribute.python.parameter_server_strategy import ParameterServerStrategy from tensorflow.contrib.distribute.python.step_fn import * +from tensorflow.contrib.distribute.python.tpu_strategy import initialize_tpu_system from tensorflow.contrib.distribute.python.tpu_strategy import TPUStrategy from tensorflow.python.distribute.cross_device_ops import * from tensorflow.python.distribute.distribute_config import DistributeConfig from tensorflow.python.distribute.distribute_coordinator import run_standard_tensorflow_server -from tensorflow.python.training.distribute import * -from tensorflow.python.training.distribution_strategy_context import * +from tensorflow.python.distribute.distribute_lib import * +from tensorflow.python.distribute.distribution_strategy_context import * from tensorflow.python.util.all_util import remove_undocumented @@ -58,11 +59,14 @@ _allowed_symbols = [ 'StandardSingleLossStep', 'ReplicaContext', 'TPUStrategy', + 'initialize_tpu_system', 
'get_cross_replica_context', 'get_distribution_strategy', 'get_loss_reduction', 'get_replica_context', + 'get_strategy', 'has_distribution_strategy', + 'has_strategy', 'in_cross_replica_context', 'require_replica_context', 'run_standard_tensorflow_server', diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD index 249258def3c4e52604b63764d8a7b5f238b45daa..509eb78128d062c7ea44730c2797b7c919cd0d69 100644 --- a/tensorflow/contrib/distribute/python/BUILD +++ b/tensorflow/contrib/distribute/python/BUILD @@ -1,5 +1,10 @@ # Implementation of a prototype TF distributed computation library. +load("//tensorflow/compiler/tests:build_defs.bzl", "tf_xla_py_test") +load("//tensorflow/core:platform/default/distribute.bzl", "distribute_py_test") +load("//tensorflow:tensorflow.bzl", "py_test") +load("//tensorflow:tensorflow.bzl", "cuda_py_test") + package( default_visibility = [ "//tensorflow:internal", @@ -10,11 +15,18 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) -load("//tensorflow:tensorflow.bzl", "py_test") -load("//tensorflow:tensorflow.bzl", "cuda_py_test") - -# TODO(priyag): Figure out testonly issues that are preventing us from -# including our tests in pip for now. 
+py_library( + name = "distribute_test_lib_pip", + visibility = ["//tensorflow:internal"], + deps = [ + ":combinations", + ":keras_correctness_test_lib", + ":keras_test_lib", + ":multi_worker_test_base", + ":single_loss_example", + ":strategy_test_lib", + ], +) cuda_py_test( name = "values_test", @@ -22,25 +34,36 @@ cuda_py_test( additional_deps = [ ":combinations", ":mirrored_strategy", - ":multi_worker_test_base", "@absl_py//absl/testing:parameterized", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:constant_op", - "//tensorflow/python:errors", "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", "//tensorflow/python:training", "//tensorflow/python:variable_scope", - "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/distribute:device_util", "//tensorflow/python/distribute:values", "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", "//tensorflow/python/estimator:estimator_py", ], - tags = [ - "no_pip", +) + +cuda_py_test( + name = "input_lib_test", + srcs = ["input_lib_test.py"], + additional_deps = [ + ":combinations", + ":mirrored_strategy", + ":multi_worker_test_base", + "@absl_py//absl/testing:parameterized", + "//tensorflow/python:errors", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/distribute:input_lib", + "//tensorflow/python/distribute:values", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:test", ], ) @@ -50,8 +73,8 @@ py_library( visibility = ["//tensorflow:internal"], deps = [ "//tensorflow/python/distribute:distribute_lib", + "//tensorflow/python/distribute:input_lib", "//tensorflow/python/distribute:mirrored_strategy", - "//tensorflow/python/distribute:values", ], ) @@ -60,18 +83,10 @@ py_library( srcs = ["parameter_server_strategy.py"], visibility = ["//tensorflow:internal"], deps = [ - ":mirrored_strategy", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:array_ops", - 
"//tensorflow/python:framework_ops", - "//tensorflow/python:resource_variable_ops", - "//tensorflow/python:training", - "//tensorflow/python:util", - "//tensorflow/python/distribute:cross_device_ops", - "//tensorflow/python/distribute:multi_worker_util", - "//tensorflow/python/distribute:reduce_util", + "//tensorflow/python/distribute:distribute_lib", + "//tensorflow/python/distribute:parameter_server_strategy", "//tensorflow/python/distribute:values", - "//tensorflow/python/eager:context", + "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib", ], ) @@ -104,7 +119,6 @@ cuda_py_test( ], tags = [ "multi_and_single_gpu", - "no_pip", ], ) @@ -113,15 +127,7 @@ py_library( srcs = ["one_device_strategy.py"], visibility = ["//tensorflow:internal"], deps = [ - "//tensorflow/python:array_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python/distribute:distribute_lib", - "//tensorflow/python/distribute:reduce_util", - "//tensorflow/python/distribute:values", - "//tensorflow/python/eager:context", - "@six_archive//:six", + "//tensorflow/python/distribute:one_device_strategy", ], ) @@ -130,28 +136,16 @@ py_library( srcs = ["collective_all_reduce_strategy.py"], visibility = ["//tensorflow:internal"], deps = [ - ":mirrored_strategy", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:collective_ops", - "//tensorflow/python:framework_ops", - "//tensorflow/python:training", - "//tensorflow/python/distribute:cross_device_ops", - "//tensorflow/python/distribute:cross_device_utils", - "//tensorflow/python/distribute:multi_worker_util", - "//tensorflow/python/distribute:values", - "//tensorflow/python/eager:context", + "//tensorflow/python/distribute:collective_all_reduce_strategy", + "//tensorflow/python/distribute:distribute_lib", + "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib", ], ) py_library( name = 
"strategy_test_lib", - testonly = 1, srcs = ["strategy_test_lib.py"], srcs_version = "PY2AND3", - tags = [ - "no_pip", - ], deps = [ "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", @@ -164,20 +158,18 @@ py_library( "//tensorflow/python/eager:backprop", "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", + "//third_party/py/numpy", ], ) py_library( name = "combinations", - testonly = 1, srcs = ["combinations.py"], srcs_version = "PY2AND3", - tags = [ - "no_pip", - ], deps = [ ":mirrored_strategy", ":one_device_strategy", + ":parameter_server_strategy", ":tpu_strategy", "//tensorflow/contrib/cluster_resolver:cluster_resolver_pip", "//tensorflow/contrib/optimizer_v2:training", @@ -186,6 +178,7 @@ py_library( "//tensorflow/python:util", "//tensorflow/python/distribute:distribute_lib", "//tensorflow/python/eager:context", + "//tensorflow/python/keras/optimizer_v2", "@absl_py//absl/testing:parameterized", ], ) @@ -193,9 +186,6 @@ py_library( py_test( name = "combinations_test", srcs = ["combinations_test.py"], - tags = [ - "no_pip", - ], deps = [ ":combinations", "//tensorflow/python/eager:test", @@ -206,13 +196,10 @@ py_test( name = "one_device_strategy_test", srcs = ["one_device_strategy_test.py"], srcs_version = "PY2AND3", - tags = [ - "no_pip", - ], deps = [ - ":one_device_strategy", ":strategy_test_lib", "//tensorflow/python:framework_test_lib", + "//tensorflow/python/distribute:one_device_strategy", "//tensorflow/python/eager:test", ], ) @@ -242,18 +229,13 @@ cuda_py_test( tags = [ "guitar", "multi_and_single_gpu", - "no_pip", ], ) py_library( name = "multi_worker_test_base", - testonly = 1, srcs = ["multi_worker_test_base.py"], srcs_version = "PY2AND3", - tags = [ - "no_pip", - ], deps = [ "//tensorflow/core:protos_all_py", "//tensorflow/python:client_testlib", @@ -288,6 +270,8 @@ py_library( "//tensorflow/python:framework_ops", "//tensorflow/python:tensor_util", "//tensorflow/python:util", + 
"//tensorflow/python/distribute:input_lib", + "//tensorflow/python/distribute:numpy_dataset", "//tensorflow/python/distribute:reduce_util", "//tensorflow/python/distribute:values", ], @@ -320,14 +304,16 @@ cuda_py_test( ], tags = [ "multi_and_single_gpu", - "no_pip", ], ) -py_library( - name = "minimize_loss_test_lib", - testonly = 1, +distribute_py_test( + name = "minimize_loss_test", srcs = ["minimize_loss_test.py"], + main = "minimize_loss_test.py", + tags = [ + "multi_and_single_gpu", + ], deps = [ ":combinations", ":mirrored_strategy", @@ -347,18 +333,6 @@ py_library( ], ) -cuda_py_test( - name = "minimize_loss_test", - srcs = ["minimize_loss_test.py"], - additional_deps = [ - ":minimize_loss_test_lib", - ], - tags = [ - "multi_and_single_gpu", - "no_pip", - ], -) - cuda_py_test( name = "moving_averages_test", srcs = ["moving_averages_test.py"], @@ -372,9 +346,6 @@ cuda_py_test( "//tensorflow/python:training", "//tensorflow/python:variables", ], - tags = [ - "no_pip", - ], ) cuda_py_test( @@ -392,7 +363,6 @@ cuda_py_test( ], tags = [ "multi_and_single_gpu", - "no_pip", ], ) @@ -415,7 +385,6 @@ cuda_py_test( tags = [ "multi_and_single_gpu", "no_oss", # http://b/119349471 - "no_pip", "tf_integration_test", ], ) @@ -426,10 +395,10 @@ cuda_py_test( additional_deps = [ ":keras_test_lib", ], + shard_count = 4, tags = [ "multi_and_single_gpu", "no_oss", # http://b/119349471 - "no_pip", "tf_integration_test", ], ) @@ -459,7 +428,6 @@ cuda_py_test( shard_count = 48, tags = [ "multi_and_single_gpu", - "no_pip", # TODO(b/118768923): Re-enable {a,m,t}san test. 
"noasan", "nomsan", @@ -481,10 +449,13 @@ py_library( ], ) -py_library( - name = "step_fn_test_lib", - testonly = 1, +distribute_py_test( + name = "step_fn_test", srcs = ["step_fn_test.py"], + main = "step_fn_test.py", + tags = [ + "multi_and_single_gpu", + ], deps = [ ":combinations", ":single_loss_example", @@ -497,18 +468,6 @@ py_library( ], ) -cuda_py_test( - name = "step_fn_test", - srcs = ["step_fn_test.py"], - additional_deps = [ - ":step_fn_test_lib", - ], - tags = [ - "multi_and_single_gpu", - "no_pip", - ], -) - py_library( name = "monitor", srcs = ["monitor.py"], @@ -525,10 +484,10 @@ cuda_py_test( additional_deps = [ ":combinations", ":monitor", - ":one_device_strategy", ":single_loss_example", "@absl_py//absl/testing:parameterized", "//third_party/py/numpy", + "//tensorflow/python/distribute:one_device_strategy", "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", "//tensorflow/python:framework_ops", @@ -536,7 +495,6 @@ cuda_py_test( ], tags = [ "multi_and_single_gpu", - "no_pip", ], ) @@ -553,9 +511,6 @@ cuda_py_test( "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", ], - tags = [ - "no_pip", - ], ) cuda_py_test( @@ -577,20 +532,22 @@ cuda_py_test( ], tags = [ "multi_and_single_gpu", - "no_pip", ], ) py_library( name = "keras_test_lib", - testonly = 1, - srcs = ["keras_test.py"], + srcs = [ + "keras_backward_compat_test.py", + "keras_test.py", + ], deps = [ ":combinations", "//tensorflow/contrib/distribute/python:mirrored_strategy", "//tensorflow/contrib/distribute/python:tpu_strategy", "//tensorflow/python:client_testlib", "//tensorflow/python:training", + "//tensorflow/python/eager:test", "//tensorflow/python/estimator:estimator_py", "//tensorflow/python/keras", "//third_party/py/numpy", @@ -598,46 +555,181 @@ py_library( ], ) -cuda_py_test( +distribute_py_test( name = "keras_test", srcs = ["keras_test.py"], - additional_deps = [ + full_precision = True, + main = "keras_test.py", + shard_count = 32, + tags = [ 
+ "multi_and_single_gpu", + "no_oss", # TODO(b/117919883): Fix python error. + "no_windows_gpu", + "notsan", + ], + deps = [ ":keras_test_lib", ], - shard_count = 16, +) + +# TODO(b/121200287): Remove this in 2.0 +distribute_py_test( + name = "keras_backward_compat_test", + srcs = ["keras_backward_compat_test.py"], + full_precision = True, + main = "keras_backward_compat_test.py", + shard_count = 31, tags = [ "multi_and_single_gpu", "no_oss", # TODO(b/117919883): Fix python error. - "no_pip", "no_windows_gpu", "notsan", ], + deps = [ + ":keras_test_lib", + ], ) py_library( - name = "metrics_v1_test_lib", - testonly = 1, - srcs = ["metrics_v1_test.py"], + name = "keras_correctness_test_lib", + srcs = [ + "keras_correctness_test_base.py", + "keras_dnn_correctness_test.py", + "keras_embedding_model_correctness_test.py", + "keras_image_model_correctness_test.py", + "keras_lstm_model_correctness_test.py", + "keras_stateful_lstm_model_correctness_test.py", + ], deps = [ ":combinations", - "//tensorflow/python:math_ops", - "//tensorflow/python:metrics", - "//tensorflow/python:variables", - "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/contrib/distribute/python:mirrored_strategy", + "//tensorflow/contrib/distribute/python:tpu_strategy", + "//tensorflow/python:client_testlib", + "//tensorflow/python:training", "//tensorflow/python/eager:test", + "//tensorflow/python/estimator:estimator_py", + "//tensorflow/python/keras", + "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", ], ) -cuda_py_test( - name = "metrics_v1_test", - srcs = ["metrics_v1_test.py"], - additional_deps = [ - ":metrics_v1_test_lib", +distribute_py_test( + name = "keras_dnn_correctness_test", + size = "medium", + srcs = ["keras_dnn_correctness_test.py"], + full_precision = True, + main = "keras_dnn_correctness_test.py", + # Shard count is set to an odd number to distribute tasks across + # shards more evenly. 
+ shard_count = 19, + tags = [ + "multi_and_single_gpu", + "no_oss", # TODO(b/117919883): Fix python error. + "no_windows_gpu", + "notsan", + ], + deps = [ + ":keras_correctness_test_lib", ], +) + +distribute_py_test( + name = "keras_image_model_correctness_test", + size = "medium", + srcs = ["keras_image_model_correctness_test.py"], + full_precision = True, + main = "keras_image_model_correctness_test.py", + # Shard count is set to an odd number to distribute tasks across + # shards more evenly. + shard_count = 31, tags = [ "multi_and_single_gpu", + "no_oss", # TODO(b/117919883): Fix python error. + "no_windows_gpu", + "notsan", + ], + deps = [ + ":keras_correctness_test_lib", + ], +) + +distribute_py_test( + name = "keras_embedding_model_correctness_test", + size = "medium", + srcs = ["keras_embedding_model_correctness_test.py"], + full_precision = True, + main = "keras_embedding_model_correctness_test.py", + # Shard count is set to an odd number to distribute tasks across + # shards more evenly. + shard_count = 31, + tags = [ + "multi_and_single_gpu", + "no_oss", # TODO(b/117919883): Fix python error. + "no_windows_gpu", + "notsan", + ], + deps = [ + ":keras_correctness_test_lib", + ], +) + +distribute_py_test( + name = "keras_lstm_model_correctness_test", + size = "medium", + srcs = ["keras_lstm_model_correctness_test.py"], + full_precision = True, + main = "keras_lstm_model_correctness_test.py", + # Shard count is set to an odd number to distribute tasks across + # shards more evenly. + shard_count = 31, + tags = [ + "multi_and_single_gpu", + "no_oss", # TODO(b/117919883): Fix python error. 
+ "no_windows_gpu", + "notsan", + ], + deps = [ + ":keras_correctness_test_lib", + ], +) + +distribute_py_test( + name = "keras_stateful_lstm_model_correctness_test", + size = "medium", + srcs = ["keras_stateful_lstm_model_correctness_test.py"], + full_precision = True, + main = "keras_stateful_lstm_model_correctness_test.py", + # Shard count is set to an odd number to distribute tasks across + # shards more evenly. + shard_count = 31, + tags = [ + "multi_and_single_gpu", + "no_oss", # TODO(b/117919883): Fix python error. "no_pip", + "no_windows_gpu", + "notsan", + ], + deps = [ + ":keras_correctness_test_lib", + ], +) + +distribute_py_test( + name = "metrics_v1_test", + srcs = ["metrics_v1_test.py"], + main = "metrics_v1_test.py", + tags = [ + "multi_and_single_gpu", + ], + deps = [ + ":combinations", + "//tensorflow/python:math_ops", + "//tensorflow/python:metrics", + "//tensorflow/python:variables", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/eager:test", + "@absl_py//absl/testing:parameterized", ], ) @@ -655,7 +747,6 @@ cuda_py_test( ], tags = [ "multi_and_single_gpu", - "no_pip", ], ) @@ -666,7 +757,6 @@ cuda_py_test( additional_deps = [ ":combinations", "//tensorflow/python:client_testlib", - "//tensorflow/python:checkpoint_utils_test", "//tensorflow/python:framework_ops", "//tensorflow/python:training", "//tensorflow/python:variable_scope", @@ -674,6 +764,25 @@ cuda_py_test( ], tags = [ "multi_and_single_gpu", - "no_pip", + ], +) + +tf_xla_py_test( + name = "checkpointing_test", + srcs = ["checkpointing_test.py"], + disabled_backends = [ + # Only makes sense on TPUs + "cpu", + "gpu", + "cpu_ondemand", + ], + tags = [ + "no_oss", + ], + deps = [ + ":tpu_strategy", + "//tensorflow/compiler/tests:xla_test", + "//tensorflow/python/eager:test", + "//tensorflow/python/training/checkpointable:util", ], ) diff --git a/tensorflow/contrib/distribute/python/checkpoint_utils_test.py b/tensorflow/contrib/distribute/python/checkpoint_utils_test.py 
index 31bd0e996a247a2fc01405fb3b8172a40853d698..7ee50f03155636a487020d0a9178107a06775588 100644 --- a/tensorflow/contrib/distribute/python/checkpoint_utils_test.py +++ b/tensorflow/contrib/distribute/python/checkpoint_utils_test.py @@ -25,6 +25,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os from absl.testing import parameterized from tensorflow.contrib.distribute.python import combinations @@ -33,7 +34,23 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.training import checkpoint_utils -from tensorflow.python.training import checkpoint_utils_test +from tensorflow.python.training import saver as saver_lib + + +def _create_checkpoints(sess, checkpoint_dir): + checkpoint_prefix = os.path.join(checkpoint_dir, "model") + checkpoint_state_name = "checkpoint" + v1 = variable_scope.get_variable("var1", [1, 10]) + v2 = variable_scope.get_variable("var2", [10, 10]) + sess.run(variables.global_variables_initializer()) + v1_value, v2_value = sess.run([v1, v2]) + saver = saver_lib.Saver() + saver.save( + sess, + checkpoint_prefix, + global_step=0, + latest_filename=checkpoint_state_name) + return v1_value, v2_value class CheckpointUtilsWithDistributionStrategyTest( @@ -51,8 +68,7 @@ class CheckpointUtilsWithDistributionStrategyTest( def testInitFromCheckpoint(self, distribution, in_replica_mode): checkpoint_dir = self.get_temp_dir() with self.cached_session() as session: - v1_value, v2_value, _, _ = checkpoint_utils_test._create_checkpoints( - session, checkpoint_dir) + v1_value, v2_value = _create_checkpoints(session, checkpoint_dir) def init_and_verify(g): v1 = variable_scope.get_variable("new_var1", [1, 10]) @@ -71,7 +87,7 @@ class CheckpointUtilsWithDistributionStrategyTest( with ops.Graph().as_default() as g, distribution.scope(): if in_replica_mode: - 
distribution.call_for_each_replica(init_and_verify, args=[g]) + distribution.extended.call_for_each_replica(init_and_verify, args=[g]) else: init_and_verify(g) diff --git a/tensorflow/contrib/distribute/python/checkpointing_test.py b/tensorflow/contrib/distribute/python/checkpointing_test.py new file mode 100644 index 0000000000000000000000000000000000000000..aa5b9f57b8a5bc12ee94399ec1fc5a55177a5b5d --- /dev/null +++ b/tensorflow/contrib/distribute/python/checkpointing_test.py @@ -0,0 +1,95 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools +import os + +from tensorflow.compiler.tests import xla_test +from tensorflow.contrib.distribute.python import tpu_strategy +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.keras.engine import training +from tensorflow.python.keras.layers import core +from tensorflow.python.platform import test +from tensorflow.python.training import adam as adam_v1 +from tensorflow.python.training import checkpoint_management +from tensorflow.python.training import training_util +from tensorflow.python.training.checkpointable import tracking +from tensorflow.python.training.checkpointable import util as checkpointable_utils + + +class NonLayerCheckpointable(tracking.AutoCheckpointable): + + def __init__(self): + super(NonLayerCheckpointable, self).__init__() + self.a_variable = checkpointable_utils.add_variable( + self, name="a_variable", shape=[]) + + +class Subclassed(training.Model): + """A concrete Model for testing.""" + + def __init__(self): + super(Subclassed, self).__init__() + self._named_dense = core.Dense(1, use_bias=True) + self._second = core.Dense(1, use_bias=False) + # We can still track Checkpointables which aren't Layers. 
+ self._non_layer = NonLayerCheckpointable() + + def call(self, values): + ret = self._second(self._named_dense(values)) + return ret + + +class TrainingCheckpointTests(xla_test.XLATestCase): + + def testEagerTPUDistributionStrategy(self): + self.skipTest("b/121387144") + num_training_steps = 10 + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + + def _train_fn(optimizer, model): + input_value = constant_op.constant([[3.]]) + optimizer.minimize( + functools.partial(model, input_value), + global_step=root.optimizer_step) + + for training_continuation in range(3): + strategy = tpu_strategy.TPUStrategy() + with strategy.scope(): + model = Subclassed() + optimizer = adam_v1.AdamOptimizer(0.001) + root = checkpointable_utils.Checkpoint( + optimizer=optimizer, model=model, + optimizer_step=training_util.get_or_create_global_step()) + root.restore(checkpoint_management.latest_checkpoint( + checkpoint_directory)) + + for _ in range(num_training_steps): + strategy.extended.call_for_each_replica( + functools.partial(_train_fn, optimizer, model)) + root.save(file_prefix=checkpoint_prefix) + self.assertEqual((training_continuation + 1) * num_training_steps, + root.optimizer_step.numpy()) + + +if __name__ == "__main__": + ops.enable_eager_execution() + test.main() diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py index 6e9f9facd0a209146d1ad8d101f0b8c41d77752a..19741627980c34d8c281f7aed6f1464d4a03393e 100644 --- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py +++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py @@ -18,27 +18,18 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import copy - -from tensorflow.contrib.distribute.python import mirrored_strategy -from tensorflow.core.protobuf import 
rewriter_config_pb2 -from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib -from tensorflow.python.distribute import cross_device_utils -from tensorflow.python.distribute import device_util +from tensorflow.python.distribute import collective_all_reduce_strategy from tensorflow.python.distribute import distribute_lib -from tensorflow.python.distribute import multi_worker_util -from tensorflow.python.distribute import values -from tensorflow.python.eager import context -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import collective_ops -from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver +from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver # TODO(yuefengz): support in-graph replication. class CollectiveAllReduceStrategy(distribute_lib.DistributionStrategy): """Distribution strategy that uses collective ops for all-reduce. + *** contrib version *** + It is similar to the MirroredStrategy but it uses collective ops for reduction. 
@@ -61,280 +52,19 @@ class CollectiveAllReduceStrategy(distribute_lib.DistributionStrategy): CollectiveAllReduceExtended(self, num_gpus_per_worker)) -class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): +class CollectiveAllReduceExtended( + collective_all_reduce_strategy.CollectiveAllReduceExtended): """Implementation of CollectiveAllReduceStrategy.""" def __init__(self, container_strategy, num_gpus_per_worker): - distribute_lib.DistributionStrategyExtended.__init__( - self, container_strategy) - self._num_gpus_per_worker = num_gpus_per_worker - self._initialize_local_worker(container_strategy, num_gpus_per_worker) - - def _initialize_local_worker(self, container_strategy, num_gpus_per_worker): - """Initializes the object for local training.""" - self._is_chief = True - self._num_workers = 1 - - if num_gpus_per_worker: - local_devices = [ - "/device:GPU:%d" % i for i in range(num_gpus_per_worker) - ] - else: - local_devices = ["/device:CPU:0"] - self._worker_device = device_util.canonicalize("/device:CPU:0") - - self._collective_keys = cross_device_utils.CollectiveKeys() - super(CollectiveAllReduceExtended, self).__init__( - container_strategy, - devices=local_devices, - cross_device_ops=cross_device_ops_lib.CollectiveAllReduce( - num_workers=1, - num_gpus_per_worker=num_gpus_per_worker, - collective_keys=self._collective_keys)) - - self._cluster_spec = None - self._task_type = None - self._task_id = None - - logging.info("CollectiveAllReduceStrategy with local_devices = %r", - local_devices) - - def _initialize_multi_worker(self, container_strategy, num_gpus_per_worker, - cluster_spec, task_type, task_id): - """Initializes the object for multi-worker training.""" - if task_type is None or task_id is None: - raise ValueError("When `cluster_spec` is given, you must also specify " - "`task_type` and `task_id`") - if task_type not in ["chief", "worker"]: - raise ValueError( - "Unrecognized task_type: %r, valid task types are: \"chief\", " - 
"\"worker\"." % task_type) - cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec) - self._num_workers = multi_worker_util.worker_count(cluster_spec, task_type) - if not self._num_workers: - raise ValueError("No `worker` or `chief` tasks can be found in " - "`cluster_spec`.") - - self._is_chief = multi_worker_util.is_chief(cluster_spec, task_type, - task_id) - - self._worker_device = "/job:%s/task:%d" % (task_type, task_id) - if num_gpus_per_worker: - local_devices = [ - "%s/device:GPU:%d" % (self._worker_device, i) - for i in range(num_gpus_per_worker) - ] - else: - local_devices = [self._worker_device] - - self._collective_keys = cross_device_utils.CollectiveKeys() + # Use TFConfigClusterResolver to parse TF_CONFIG. We don't want to change + # the constructor's interface to allow customized cluster resolver. Use + # SimpleClusterResolver to override num_accelerators. + tfconfig = TFConfigClusterResolver() + cluster_resolver = SimpleClusterResolver( + cluster_spec=tfconfig.cluster_spec(), + task_type=tfconfig.task_type, + task_id=tfconfig.task_id, + num_accelerators=num_gpus_per_worker) super(CollectiveAllReduceExtended, self).__init__( - container_strategy, - devices=local_devices, - cross_device_ops=cross_device_ops_lib.CollectiveAllReduce( - num_workers=self._num_workers, - num_gpus_per_worker=num_gpus_per_worker, - collective_keys=self._collective_keys)) - - # Add a default device so that ops without specified devices will not end up - # on other workers. 
- self._default_device = "/job:%s/task:%d" % (task_type, task_id) - - self._cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec) - self._task_type = task_type - self._task_id = task_id - - logging.info( - "Multi-worker CollectiveAllReduceStrategy with " - "cluster_spec = %r, task_type = %r, task_id = %r, " - "num_workers = %r, local_devices = %r", cluster_spec.as_dict(), - task_type, task_id, self._num_workers, local_devices) - - def _create_variable(self, next_creator, *args, **kwargs): - colocate_with = kwargs.pop("colocate_with", None) - devices = self._get_devices_from(colocate_with) - group_size = len(devices) * self._num_workers - group_key = self._collective_keys.get_group_key(self._devices) - - def _real_mirrored_creator(devices, *args, **kwargs): - """Creates one MirroredVariable on the current worker.""" - index = {} - unique_var_name = ops.get_default_graph().unique_name( - kwargs["name"], mark_as_used=False).rstrip("/") - collective_instance_key = self._collective_keys.get_instance_key( - key_id=unique_var_name) - if "initial_value" not in kwargs: - raise ValueError("Initial value must be specified.") - initial_value = kwargs["initial_value"] - if callable(initial_value): - initial_value_fn = initial_value - else: - initial_value_fn = lambda: initial_value - - for i, d in enumerate(devices): - with ops.device(d): - if i > 0: - # Give replicas meaningful distinct names: - var0name = index[devices[0]].name.split(":")[0] - # We append a / to variable names created on replicas with id > 0 to - # ensure that we ignore the name scope and instead use the given - # name as the absolute name of the variable. - kwargs["name"] = "%s/replica_%d/" % (var0name, i) - - # The initial value fn makes sure variables all initialized to - # same values. The first device of the chief worker will send their - # variable values to other devices and other workers. 
- def _overridden_initial_value_fn(device=d, index=i): # pylint: disable=g-missing-docstring - with ops.device(device): - initial_value = initial_value_fn() - assert not callable(initial_value) - initial_value = ops.convert_to_tensor(initial_value) - - if self._is_chief and index == 0: - bcast_send = collective_ops.broadcast_send( - initial_value, initial_value.shape, initial_value.dtype, - group_size, group_key, collective_instance_key) - with ops.control_dependencies([bcast_send]): - return array_ops.identity(initial_value) - else: - return collective_ops.broadcast_recv( - initial_value.shape, initial_value.dtype, group_size, - group_key, collective_instance_key) - - kwargs["initial_value"] = _overridden_initial_value_fn - - with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT): - v = next_creator(*args, **kwargs) - - if i == 0: - actual_var_name = v.name.split(":")[0] - assert unique_var_name == actual_var_name, "%r vs %r" % ( - unique_var_name, actual_var_name) - assert not isinstance(v, values.DistributedVariable) - index[d] = v - return index - - # pylint: disable=protected-access - return mirrored_strategy._create_mirrored_variable( - devices, _real_mirrored_creator, *args, **kwargs) - - def _distribute_dataset(self, dataset_fn): - """Distributes the dataset to each local GPU.""" - # TODO(yuefengz): shard the dataset. 
- return values.PerReplicaDataset( - self._call_dataset_fn(dataset_fn), self._devices, True) - - def _make_dataset_iterator(self, dataset): - worker_device_pairs = [(self._worker_device, self._devices)] - return values.DatasetIterator(dataset, worker_device_pairs, - self._num_replicas_in_sync) - - def _make_input_fn_iterator( - self, - input_fn, - replication_mode=distribute_lib.InputReplicationMode.PER_WORKER): - """Distributes the dataset to each local GPU.""" - if self._cluster_spec is None: - input_pipeline_id = 0 - else: - input_pipeline_id = multi_worker_util.id_in_cluster( - self._cluster_spec, self._task_type, self._task_id) - input_context = distribute_lib.InputContext( - num_input_pipelines=self._num_workers, - input_pipeline_id=input_pipeline_id, - num_replicas_in_sync=self._num_replicas_in_sync) - - return values.InputFunctionIterator( - input_fn, [(self._worker_device, self._devices)], [input_context]) - - def _configure(self, - session_config=None, - cluster_spec=None, - task_type=None, - task_id=None): - """Configures the object. - - Args: - session_config: a `tf.ConfigProto` - cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the - cluster configurations. - task_type: the current task type, such as "worker". - task_id: the current task id. - - Raises: - ValueError: if `task_type` is not in the `cluster_spec`. - """ - if not self._cluster_spec and cluster_spec: - # If a `cluster_spec` is already passed in, do nothing here. - # TODO(yuefengz): check `cluster_spec` is the same if this object has - # already been initialized with a `cluster_spec`. - self._initialize_multi_worker( - self._container_strategy(), self._num_gpus_per_worker, cluster_spec, - task_type, task_id) - - if session_config: - session_config.CopyFrom(self._update_config_proto(session_config)) - - def _update_config_proto(self, config_proto): - updated_config = copy.deepcopy(config_proto) - # Enable the scoped allocator optimization for CollectiveOps. 
This - # optimization converts many small all-reduces into fewer larger - # all-reduces. - rewrite_options = updated_config.graph_options.rewrite_options - rewrite_options.scoped_allocator_optimization = ( - rewriter_config_pb2.RewriterConfig.ON) - # We turn on ScopedAllocator only for CollectiveReduce op, i.e. enable_op = - # ["CollectiveReduce"]. Since we can't assign to a repeated proto field, we - # clear and then append. - del rewrite_options.scoped_allocator_opts.enable_op[:] - rewrite_options.scoped_allocator_opts.enable_op.append("CollectiveReduce") - - if not self._cluster_spec: - return updated_config - - assert self._task_type - assert self._task_id is not None - - # Collective group leader is needed for collective ops to coordinate - # workers. - if "chief" in self._cluster_spec.jobs: - updated_config.experimental.collective_group_leader = ( - "/job:chief/replica:0/task:0") - else: - if "worker" not in self._cluster_spec.jobs: - raise ValueError( - "You must have `chief` or `worker` jobs in the `cluster_spec`.") - updated_config.experimental.collective_group_leader = ( - "/job:worker/replica:0/task:0") - - # The device filters prevent communication between workers. - del updated_config.device_filters[:] - updated_config.device_filters.append( - "/job:%s/task:%d" % (self._task_type, self._task_id)) - - return updated_config - - @property - def experimental_between_graph(self): - return True - - @property - def experimental_should_init(self): - return True - - @property - def should_checkpoint(self): - return self._is_chief - - @property - def should_save_summary(self): - return self._is_chief - - @property - def _num_replicas_in_sync(self): - return len(self._devices) * self._num_workers - - # TODO(priyag): Delete this once all strategies use global batch size. 
- @property - def _global_batch_size(self): - return False + container_strategy, cluster_resolver=cluster_resolver) diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py index eba3585a55375ee1db561a459e079256c53a85cc..acbe4677b401cbea4fd0ec415415f25c920e68e4 100644 --- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py +++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py @@ -29,9 +29,13 @@ from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python import keras from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import collective_all_reduce_strategy as core_collective_all_reduce_strategy from tensorflow.python.distribute import cross_device_utils +from tensorflow.python.distribute import distribute_lib +from tensorflow.python.distribute import multi_worker_util from tensorflow.python.distribute import reduce_util from tensorflow.python.distribute import values +from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -49,6 +53,55 @@ from tensorflow.python.ops.losses import losses from tensorflow.python.platform import test from tensorflow.python.training import adam from tensorflow.python.training import training_util +from tensorflow.python.training.server_lib import ClusterSpec + + +class MockCollectiveAllReduceStrategy(distribute_lib.DistributionStrategy): + """Mock the strategy to allow cluster resolver as an argument.""" + + def __init__(self, cluster_resolver): + super(MockCollectiveAllReduceStrategy, self).__init__( + core_collective_all_reduce_strategy.CollectiveAllReduceExtended( + self, cluster_resolver=cluster_resolver)) + 
+ +def create_test_objects(cluster_spec=None, + task_type=None, + task_id=None, + num_gpus=None, + use_core_strategy=False): + sess_config = config_pb2.ConfigProto() + if num_gpus is None: + num_gpus = context.num_gpus() + if use_core_strategy: + if cluster_spec and task_type and task_id is not None: + cluster_resolver = SimpleClusterResolver( + cluster_spec=multi_worker_util.normalize_cluster_spec(cluster_spec), + task_type=task_type, + task_id=task_id, + num_accelerators=num_gpus) + target = 'grpc://' + cluster_spec[task_type][task_id] + else: + cluster_resolver = SimpleClusterResolver( + ClusterSpec({}), num_accelerators=num_gpus) + target = '' + + strategy = MockCollectiveAllReduceStrategy(cluster_resolver) + sess_config = strategy.update_config_proto(sess_config) + else: + strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy( + num_gpus_per_worker=num_gpus) + if task_type and task_id is not None: + strategy.configure( + session_config=sess_config, + cluster_spec=cluster_spec, + task_type=task_type, + task_id=task_id) + target = 'grpc://' + cluster_spec[task_type][task_id] + else: + target = '' + + return strategy, target, sess_config class CollectiveAllReduceStrategyTestBase( @@ -64,16 +117,18 @@ class CollectiveAllReduceStrategyTestBase( CollectiveAllReduceStrategyTestBase.collective_key_base += 100000 super(CollectiveAllReduceStrategyTestBase, self).setUp() - def _get_test_object(self, task_type, task_id, num_gpus=0): - distribution = collective_all_reduce_strategy.CollectiveAllReduceStrategy( - num_gpus_per_worker=num_gpus) - session_config = config_pb2.ConfigProto() - if task_type and task_id is not None: - distribution.configure( - session_config=session_config, - cluster_spec=self._cluster_spec, - task_type=task_type, - task_id=task_id) + def _get_test_object(self, + task_type, + task_id, + num_gpus=0, + use_core_strategy=False): + strategy, target, session_config = create_test_objects( + cluster_spec=self._cluster_spec, + 
task_type=task_type, + task_id=task_id, + num_gpus=num_gpus, + use_core_strategy=use_core_strategy) + collective_keys = cross_device_utils.CollectiveKeys( group_key_start=10 * num_gpus + CollectiveAllReduceStrategyTestBase.collective_key_base, @@ -81,15 +136,16 @@ class CollectiveAllReduceStrategyTestBase( CollectiveAllReduceStrategyTestBase.collective_key_base, instance_key_with_id_start=num_gpus * 10000 + CollectiveAllReduceStrategyTestBase.collective_key_base) - distribution.extended._collective_keys = collective_keys - distribution.extended._cross_device_ops._collective_keys = collective_keys - if task_type and task_id is not None: - return distribution, 'grpc://' + self._cluster_spec[task_type][ - task_id], session_config - else: - return distribution, '', session_config + strategy.extended._collective_keys = collective_keys + strategy.extended._cross_device_ops._collective_keys = (collective_keys) - def _test_minimize_loss_graph(self, task_type, task_id, num_gpus): + return strategy, target, session_config + + def _test_minimize_loss_graph(self, + task_type, + task_id, + num_gpus, + use_core_strategy=False): d, master_target, config = self._get_test_object(task_type, task_id, num_gpus) with ops.Graph().as_default(), \ @@ -122,20 +178,20 @@ class CollectiveAllReduceStrategyTestBase( def step(): """Perform one optimization step.""" # Run forward & backward to get gradients, variables list. - g_v = d.call_for_each_replica(grad_fn, args=[one]) + g_v = d.extended.call_for_each_replica(grad_fn, args=[one]) # Update the variables using the gradients and the update() function. before_list = [] after_list = [] for g, v in g_v: - fetched = d.read_var(v) + fetched = d.extended.read_var(v) before_list.append(fetched) with ops.control_dependencies([fetched]): # TODO(yuefengz): support non-Mirrored variable as destinations. 
g = d.extended.reduce_to( reduce_util.ReduceOp.SUM, g, destinations=v) with ops.control_dependencies( - d.update(v, update, g, grouped=False)): - after_list.append(d.read_var(v)) + d.extended.update(v, update, args=(g,), group=False)): + after_list.append(d.extended.read_var(v)) return before_list, after_list before_out, after_out = step() @@ -157,7 +213,11 @@ class CollectiveAllReduceStrategyTestBase( self.assertLess(error_after, error_before) return error_after < error_before - def _test_complex_model(self, task_type, task_id, num_gpus): + def _test_complex_model(self, + task_type, + task_id, + num_gpus, + use_core_strategy=False): d, master_target, config = self._get_test_object(task_type, task_id, num_gpus) @@ -191,6 +251,7 @@ class CollectiveAllReduceStrategyTestBase( image = random_ops.random_uniform([2, 28, 28]) label = random_ops.random_uniform([2, 1], maxval=10, dtype=dtypes.int32) logits = model(image, training=True) + # TODO(yuefengz): make loss a callable for eager mode. loss = losses.sparse_softmax_cross_entropy(labels=label, logits=logits) optimizer = adam.AdamOptimizer(learning_rate=1e-4) train_op = optimizer.minimize(loss, @@ -201,14 +262,18 @@ class CollectiveAllReduceStrategyTestBase( self.cached_session(config=config, target=master_target) as sess: with d.scope(): - train_op = d.call_for_each_replica(model_fn) + train_op = d.extended.call_for_each_replica(model_fn) train_op = d.group(d.unwrap(train_op)) sess.run(variables.global_variables_initializer()) sess.run(train_op) return True - def _test_variable_initialization(self, task_type, task_id, num_gpus): + def _test_variable_initialization(self, + task_type, + task_id, + num_gpus, + use_core_strategy=False): distribution, master_target, config = self._get_test_object( task_type, task_id, num_gpus) with ops.Graph().as_default(), \ @@ -224,7 +289,7 @@ class CollectiveAllReduceStrategyTestBase( 1.0, 10.0, dtype=dtypes.float32)) return array_ops.identity(x) - x = 
distribution.call_for_each_replica(model_fn) + x = distribution.extended.call_for_each_replica(model_fn) reduced_x = distribution.reduce(reduce_util.ReduceOp.MEAN, x) x = distribution.unwrap(x)[0] @@ -237,8 +302,14 @@ class CollectiveAllReduceStrategyTestBase( reduced_x_value))) return np.allclose(x_value, reduced_x_value, atol=1e-5) - def _test_input_fn_iterator(self, task_type, task_id, num_gpus, input_fn, - expected_values): + def _test_input_fn_iterator(self, + task_type, + task_id, + num_gpus, + input_fn, + expected_values, + test_reinitialize=True, + use_core_strategy=False): distribution, master_target, config = self._get_test_object( task_type, task_id, num_gpus) devices = distribution.extended.worker_devices @@ -251,22 +322,24 @@ class CollectiveAllReduceStrategyTestBase( for expected_value in expected_values: next_element = iterator.get_next() - computed_value = sess.run( - [values.select_device(d, next_element) for d in devices]) + computed_value = sess.run([values.select_replica(r, next_element) + for r in range(len(devices))]) self.assertEqual(expected_value, computed_value) with self.assertRaises(errors.OutOfRangeError): next_element = iterator.get_next() - sess.run([values.select_device(d, next_element) for d in devices]) + sess.run([values.select_replica(r, next_element) + for r in range(len(devices))]) # After re-initializing the iterator, should be able to iterate again. 
- sess.run(iterator.initialize()) + if test_reinitialize: + sess.run(iterator.initialize()) - for expected_value in expected_values: - next_element = iterator.get_next() - computed_value = sess.run( - [values.select_device(d, next_element) for d in devices]) - self.assertEqual(expected_value, computed_value) + for expected_value in expected_values: + next_element = iterator.get_next() + computed_value = sess.run([values.select_replica(r, next_element) + for r in range(len(devices))]) + self.assertEqual(expected_value, computed_value) class DistributedCollectiveAllReduceStrategyTest( @@ -280,71 +353,114 @@ class DistributedCollectiveAllReduceStrategyTest( cls._cluster_spec = multi_worker_test_base.create_in_process_cluster( num_workers=3, num_ps=0) - def test_num_replicas_in_sync(self): - distribution = collective_all_reduce_strategy.CollectiveAllReduceStrategy( - num_gpus_per_worker=2) - distribution.configure(cluster_spec=self._cluster_spec, task_type='worker', - task_id=0) + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def test_num_replicas_in_sync(self, use_core_strategy): + distribution, _, _ = create_test_objects( + cluster_spec=self._cluster_spec, + task_type='worker', + task_id=0, + num_gpus=2, + use_core_strategy=use_core_strategy) num_workers = len(self._cluster_spec.get('chief', []) + self._cluster_spec.get('worker', [])) self.assertEqual(2 * num_workers, distribution.num_replicas_in_sync) @combinations.generate( - combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1)) - def testMinimizeLossGraph(self, num_gpus): - self._run_between_graph_clients(self._test_minimize_loss_graph, - self._cluster_spec, num_gpus) + combinations.combine( + mode=['graph'], + num_gpus=[0, 1, 2], + required_gpus=1, + use_core_strategy=[True, False])) + def testMinimizeLossGraph(self, num_gpus, use_core_strategy): + self._run_between_graph_clients( + self._test_minimize_loss_graph, + self._cluster_spec, + 
num_gpus, + use_core_strategy=use_core_strategy) @combinations.generate( - combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1)) - def testVariableInitialization(self, num_gpus): + combinations.combine( + mode=['graph'], + num_gpus=[0, 1, 2], + required_gpus=1, + use_core_strategy=[True, False])) + def testVariableInitialization(self, num_gpus, use_core_strategy): if context.num_gpus() < num_gpus: self.skipTest('Not enough GPUs') self._run_between_graph_clients( self._test_variable_initialization, self._cluster_spec, - num_gpus=num_gpus) + num_gpus=num_gpus, + use_core_strategy=use_core_strategy) @combinations.generate( - combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1)) - def testComplexModel(self, num_gpus): + combinations.combine( + mode=['graph'], + num_gpus=[0, 1, 2], + required_gpus=1, + use_core_strategy=[True, False])) + def testComplexModel(self, num_gpus, use_core_strategy): if context.num_gpus() < num_gpus: self.skipTest('Not enough GPUs') self._run_between_graph_clients( - self._test_complex_model, self._cluster_spec, num_gpus=num_gpus) + self._test_complex_model, + self._cluster_spec, + num_gpus=num_gpus, + use_core_strategy=use_core_strategy) # TODO(yuefengz): Update how we use num_gpus and required_gpus @combinations.generate( - combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1)) - def testMakeInputFnIterator(self, num_gpus): + combinations.combine( + mode=['graph'], + num_gpus=[0, 1, 2], + required_gpus=1, + use_dataset=[True, False], + use_core_strategy=[True, False])) + def testMakeInputFnIterator(self, num_gpus, use_dataset, use_core_strategy): if context.num_gpus() < num_gpus: self.skipTest('Not enough GPUs') - dataset_fn = lambda: dataset_ops.Dataset.range(100) + if use_dataset: + fn = lambda: dataset_ops.Dataset.range(100) + else: + def fn(): + dataset = dataset_ops.Dataset.range(100) + it = dataset.make_one_shot_iterator() + return it.get_next # We use CPU as the device when 
num_gpus = 0 devices_per_worker = max(1, num_gpus) expected_values = [[i+j for j in range(devices_per_worker)] for i in range(0, 100, devices_per_worker)] input_fn = self._input_fn_to_test_input_context( - dataset_fn, + fn, expected_num_replicas_in_sync=3*devices_per_worker, expected_num_input_pipelines=3, expected_input_pipeline_id=1) # because task_id = 1 - self._test_input_fn_iterator('worker', 1, num_gpus, - input_fn, expected_values) + self._test_input_fn_iterator( + 'worker', + 1, + num_gpus, + input_fn, + expected_values, + test_reinitialize=use_dataset, + use_core_strategy=use_core_strategy) - def testUpdateConfigProto(self): - distribution = collective_all_reduce_strategy.CollectiveAllReduceStrategy( - num_gpus_per_worker=2) - distribution.configure( - cluster_spec=self._cluster_spec, task_type='worker', task_id=1) + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def testUpdateConfigProto(self, use_core_strategy): + strategy, _, _ = self._get_test_object( + task_type='worker', + task_id=1, + num_gpus=2, + use_core_strategy=use_core_strategy) config_proto = config_pb2.ConfigProto(device_filters=['to_be_overridden']) rewrite_options = config_proto.graph_options.rewrite_options rewrite_options.scoped_allocator_opts.enable_op.append('to_be_removed') - new_config = distribution.update_config_proto(config_proto) + new_config = strategy.update_config_proto(config_proto) # Verify group leader self.assertEqual('/job:worker/replica:0/task:0', @@ -395,36 +511,136 @@ class DistributedCollectiveAllReduceStrategyTestWithChief( self._test_complex_model, self._cluster_spec, num_gpus=num_gpus) -class LocalCollectiveAllReduceStrategy(CollectiveAllReduceStrategyTestBase, - strategy_test_lib.DistributionTestBase, - parameterized.TestCase): +class LocalCollectiveAllReduceStrategy( + CollectiveAllReduceStrategyTestBase, + strategy_test_lib.DistributionTestBase, + strategy_test_lib.TwoDeviceDistributionTestBase, + 
parameterized.TestCase): - def testMinimizeLossGraph(self, num_gpus=2): + @combinations.generate( + combinations.combine( + mode=['graph', 'eager'], + num_gpus=[2, 4], + required_gpus=2, + use_core_strategy=[True, False])) + def testMinimizeLoss(self, num_gpus, use_core_strategy): # Collective ops doesn't support strategy with one device. if context.num_gpus() < num_gpus: self.skipTest('Not enough GPUs') - self._test_minimize_loss_graph(None, None, num_gpus) + if context.executing_eagerly(): + strategy, _, _ = self._get_test_object( + None, None, num_gpus, use_core_strategy=use_core_strategy) + self._test_minimize_loss_eager(strategy) + else: + self._test_minimize_loss_graph( + None, None, num_gpus, use_core_strategy=use_core_strategy) - def testComplexModel(self, num_gpus=2): - # Collective ops doesn't support strategy with one device. + @combinations.generate( + combinations.combine( + mode=['graph'], + num_gpus=[2, 4], + required_gpus=2, + use_core_strategy=[True, False])) + def testComplexModel(self, num_gpus, use_core_strategy): if context.num_gpus() < num_gpus: self.skipTest('Not enough GPUs') - self._test_complex_model(None, None, num_gpus) + self._test_complex_model( + None, None, num_gpus, use_core_strategy=use_core_strategy) - def testMakeInputFnIterator(self, num_gpus=2): - # Collective ops doesn't support strategy with one device. 
- if context.num_gpus() < num_gpus: - self.skipTest('Not enough GPUs') - dataset_fn = lambda: dataset_ops.Dataset.range(10) - expected_values = [[i, i+1] for i in range(0, 10, 2)] + @combinations.generate( + combinations.combine( + mode=['graph', 'eager'], + required_gpus=2, + use_dataset=[True, False], + use_core_strategy=[True, False])) + def testMakeInputFnIterator(self, use_dataset, use_core_strategy): + num_gpus = 2 + if use_dataset: + fn = lambda: dataset_ops.Dataset.range(5 * num_gpus) + else: + def fn(): + dataset = dataset_ops.Dataset.range(5 * num_gpus) + it = dataset.make_one_shot_iterator() + return it.get_next + expected_values = [range(i, i + num_gpus) for i in range(0, 10, num_gpus)] input_fn = self._input_fn_to_test_input_context( - dataset_fn, + fn, expected_num_replicas_in_sync=num_gpus, expected_num_input_pipelines=1, expected_input_pipeline_id=0) - self._test_input_fn_iterator(None, None, num_gpus, - input_fn, expected_values) + self._test_input_fn_iterator( + None, + None, + num_gpus, + input_fn, + expected_values, + test_reinitialize=use_dataset, + use_core_strategy=use_core_strategy) + + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def testAllReduceSum(self, use_core_strategy): + if context.num_gpus() < 2: self.skipTest('Not enough GPUs') + distribution, target, config = self._get_test_object( + None, None, num_gpus=2, use_core_strategy=use_core_strategy) + with self.cached_session(config=config, target=target): + self._test_all_reduce_sum(distribution) + + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def testAllReduceSumGradients(self, use_core_strategy): + if context.num_gpus() < 2: self.skipTest('Not enough GPUs') + distribution, target, config = self._get_test_object( + None, None, num_gpus=2, use_core_strategy=use_core_strategy) + with self.cached_session(config=config, target=target): + 
self._test_all_reduce_sum_gradients(distribution) + + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def testAllReduceSumGradientTape(self, use_core_strategy): + if context.num_gpus() < 2: self.skipTest('Not enough GPUs') + distribution, target, config = self._get_test_object( + None, None, num_gpus=2, use_core_strategy=use_core_strategy) + with self.cached_session(config=config, target=target): + self._test_all_reduce_sum_gradient_tape(distribution) + + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def testAllReduceMean(self, use_core_strategy): + if context.num_gpus() < 2: self.skipTest('Not enough GPUs') + distribution, target, config = self._get_test_object( + None, None, num_gpus=2, use_core_strategy=use_core_strategy) + with self.cached_session(config=config, target=target): + self._test_all_reduce_mean(distribution) + + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def testAllReduceMeanGradients(self, use_core_strategy): + if context.num_gpus() < 2: self.skipTest('Not enough GPUs') + distribution, target, config = self._get_test_object( + None, None, num_gpus=2, use_core_strategy=use_core_strategy) + with self.cached_session(config=config, target=target): + self._test_all_reduce_mean_gradients(distribution) + + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def testAllReduceMeanGradientTape(self, use_core_strategy): + if context.num_gpus() < 2: self.skipTest('Not enough GPUs') + distribution, target, config = self._get_test_object( + None, None, num_gpus=2, use_core_strategy=use_core_strategy) + with self.cached_session(config=config, target=target): + self._test_all_reduce_mean_gradient_tape(distribution) + + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def testNumpyIterator(self, 
use_core_strategy): + num_gpus = 2 + if context.num_gpus() < num_gpus: + self.skipTest('Not enough GPUs') + strategy, _, _ = self._get_test_object( + None, None, num_gpus=num_gpus, use_core_strategy=use_core_strategy) + self._test_numpy_iterator(strategy) if __name__ == '__main__': diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py index c5ce29a43632918be555db865891fdbb5d22e941..798a1591c73c4f4f3f37b015d20ec31c40aaa939 100644 --- a/tensorflow/contrib/distribute/python/combinations.py +++ b/tensorflow/contrib/distribute/python/combinations.py @@ -46,16 +46,22 @@ import unittest from absl.testing import parameterized import six -from tensorflow.contrib.cluster_resolver import TPUClusterResolver +from tensorflow.contrib import cluster_resolver from tensorflow.contrib.distribute.python import mirrored_strategy as mirrored_lib from tensorflow.contrib.distribute.python import one_device_strategy as one_device_lib +from tensorflow.contrib.distribute.python import parameter_server_strategy from tensorflow.contrib.distribute.python import tpu_strategy as tpu_lib from tensorflow.contrib.optimizer_v2 import adagrad as adagrad_v2 from tensorflow.contrib.optimizer_v2 import adam as adam_v2 from tensorflow.contrib.optimizer_v2 import gradient_descent as gradient_descent_v2 +from tensorflow.contrib.tpu.python.tpu import device_assignment as device_assignment_lib from tensorflow.python.distribute import distribution_strategy_context from tensorflow.python.eager import context from tensorflow.python.framework import ops +from tensorflow.python.keras.optimizer_v2 import adagrad as adagrad_keras_v2 +from tensorflow.python.keras.optimizer_v2 import adam as adam_keras_v2 +from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras_v2 +from tensorflow.python.keras.optimizer_v2 import rmsprop as rmsprop_keras_v2 from tensorflow.python.training import adagrad from 
tensorflow.python.training import adam from tensorflow.python.training import gradient_descent @@ -192,7 +198,7 @@ def _augment_with_special_arguments(test_method): kwargs_to_pass[arg] = kwargs[arg] if mode == "eager": - with ops.Graph().as_default(), context.eager_mode(): + with context.eager_mode(): if distribution: kwargs_to_pass["distribution"] = distribution.strategy test_method(**kwargs_to_pass) @@ -226,7 +232,7 @@ def combine(**kwargs): if not kwargs: return [OrderedDict()] - sort_by_key = lambda k: k[0][0] + sort_by_key = lambda k: k[0] kwargs = OrderedDict(sorted(kwargs.items(), key=sort_by_key)) first = list(kwargs.items())[0] @@ -321,22 +327,46 @@ class NamedDistribution(object): return self._required_tpu +def _get_tpu_strategy_creator(steps_per_run, use_single_core=False, **kwargs): + def _create_tpu_strategy(): + resolver = cluster_resolver.TPUClusterResolver("") + topology = tpu_lib.initialize_tpu_system(resolver) + device_assignment = None + if use_single_core: + device_assignment = device_assignment_lib.DeviceAssignment( + topology, core_assignment=device_assignment_lib. 
+ SINGLE_CORE_ASSIGNMENT) + + strategy = tpu_lib.TPUStrategy(resolver, steps_per_run=steps_per_run, + device_assignment=device_assignment, + **kwargs) + return strategy + return _create_tpu_strategy + + # pylint: disable=g-long-lambda default_strategy = NamedDistribution( "Default", - distribution_strategy_context._get_default_distribution_strategy, # pylint: disable=protected-access + distribution_strategy_context._get_default_strategy, # pylint: disable=protected-access required_gpus=None) one_device_strategy = NamedDistribution( "OneDeviceCPU", lambda: one_device_lib.OneDeviceStrategy("/cpu:0"), required_gpus=None) tpu_strategy = NamedDistribution( - "TPU", lambda: tpu_lib.TPUStrategy( - TPUClusterResolver(""), steps_per_run=2), + "TPU", _get_tpu_strategy_creator(steps_per_run=2), required_tpu=True) tpu_strategy_one_step = NamedDistribution( - "TPUOneStep", lambda: tpu_lib.TPUStrategy( - TPUClusterResolver(""), steps_per_run=1), + "TPUOneStep", _get_tpu_strategy_creator(steps_per_run=1), + required_tpu=True) +tpu_strategy_one_core = NamedDistribution( + "TPUOneCore", _get_tpu_strategy_creator( + steps_per_run=2, use_single_core=True), required_tpu=True) +tpu_strategy_one_step_one_core = NamedDistribution( + "TPUOneStepOneCore", _get_tpu_strategy_creator( + steps_per_run=1, use_single_core=True), + required_tpu=True) + mirrored_strategy_with_one_cpu = NamedDistribution( "Mirrored1CPU", lambda: mirrored_lib.MirroredStrategy(["/cpu:0"])) @@ -367,6 +397,11 @@ core_mirrored_strategy_with_two_gpus = NamedDistribution( "CoreMirrored2GPUs", lambda: mirrored_lib.CoreMirroredStrategy(["/gpu:0", "/gpu:1"]), required_gpus=2) +parameter_server_strategy_with_two_gpus = NamedDistribution( + "ParameterServer2GPUs", + lambda: parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=2), + required_gpus=2) gradient_descent_optimizer_v1_fn = NamedObject( @@ -386,10 +421,20 @@ gradient_descent_optimizer_v2_fn = NamedObject( adagrad_optimizer_v2_fn = NamedObject( 
"AdagradV2", lambda: adagrad_v2.AdagradOptimizer(0.001)) adam_optimizer_v2_fn = NamedObject( - "AdamV2", lambda: adam_v2.AdamOptimizer(0.001, epsilon=1)) + "AdamV2", lambda: adam_v2.AdamOptimizer(0.001, epsilon=1.0)) optimizers_v2 = [gradient_descent_optimizer_v2_fn, adagrad_optimizer_v2_fn] +gradient_descent_optimizer_keras_v2_fn = NamedObject( + "GradientDescentKerasV2", + lambda: gradient_descent_keras_v2.SGD(0.2)) +adagrad_optimizer_keras_v2_fn = NamedObject( + "AdagradKerasV2", lambda: adagrad_keras_v2.Adagrad(0.001)) +adam_optimizer_keras_v2_fn = NamedObject( + "AdamKerasV2", lambda: adam_keras_v2.Adam(0.001, epsilon=1.0)) +rmsprop_optimizer_keras_v2_fn = NamedObject( + "RmsPropKerasV2", lambda: rmsprop_keras_v2.RMSprop(0.001)) + graph_and_eager_modes = ["graph", "eager"] diff --git a/tensorflow/contrib/distribute/python/combinations_test.py b/tensorflow/contrib/distribute/python/combinations_test.py index 86aa48cea889c6c2ce169b18bcabb6d08890fbed..9f3deadbec98c4f66061ca29b4d29a74b8de40b1 100644 --- a/tensorflow/contrib/distribute/python/combinations_test.py +++ b/tensorflow/contrib/distribute/python/combinations_test.py @@ -42,6 +42,14 @@ class TestingCombinationsTest(test.TestCase): "b": 3 }], combinations.combine(a=[1, 2], b=[2, 3])) + def test_arguments_sorted(self): + self.assertEqual([ + OrderedDict([("aa", 1), ("ab", 2)]), + OrderedDict([("aa", 1), ("ab", 3)]), + OrderedDict([("aa", 2), ("ab", 2)]), + OrderedDict([("aa", 2), ("ab", 3)]) + ], combinations.combine(ab=[2, 3], aa=[1, 2])) + def test_combine_single_parameter(self): self.assertEqual([{ "a": 1, diff --git a/tensorflow/contrib/distribute/python/cross_device_ops_test.py b/tensorflow/contrib/distribute/python/cross_device_ops_test.py index 3602cc92094ff607187f19e9e1c0ebde45aa6787..54cce2988383fcf5e063726948fbbf62c7094ce5 100644 --- a/tensorflow/contrib/distribute/python/cross_device_ops_test.py +++ b/tensorflow/contrib/distribute/python/cross_device_ops_test.py @@ -40,8 +40,16 @@ from 
tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops +def _get_devices(devices): + if isinstance(devices, (tuple, list)): + return tuple(device_util.resolve(d) for d in devices) + elif isinstance(devices, value_lib.DistributedValues): + return devices.devices + return (device_util.resolve(devices),) + + def _make_per_replica(values, devices, regroup=False): - devices = cross_device_ops_lib.get_devices_from(devices) + devices = _get_devices(devices) assert len(values) == len(devices) # We simulate the result of regroup called on PerReplica which strips the @@ -51,12 +59,12 @@ def _make_per_replica(values, devices, regroup=False): placed_v = array_ops.identity(values[0]) return placed_v - index = {} + index = [] for d, v in zip(devices, values): with ops.device(d): placed_v = array_ops.identity(v) - index[d] = placed_v - return value_lib.PerReplica(index) + index.append(placed_v) + return value_lib.PerReplica(value_lib.ReplicaDeviceMap(devices), index) # pylint: disable=g-doc-args,g-doc-return-or-yield @@ -66,9 +74,9 @@ def _fake_mirrored(value, devices): All components of the returned Mirrored have the same objects, which is not true in reality. 
""" - devices = cross_device_ops_lib.get_devices_from(devices) - return value_lib.Mirrored( - {d: v for d, v in zip(devices, [value] * len(devices))}) + devices = _get_devices(devices) + return value_lib.Mirrored(value_lib.ReplicaDeviceMap(devices), + [value] * len(devices)) def _make_indexed_slices(values, indices, dense_shape, device): @@ -81,9 +89,9 @@ def _make_indexed_slices(values, indices, dense_shape, device): def _make_mirrored_indexed_slices(devices, values, indices, dense_shape): - return value_lib.Mirrored({ - d: _make_indexed_slices(values, indices, dense_shape, d) for d in devices - }) + values = [_make_indexed_slices(values, indices, dense_shape, d) + for d in devices] + return value_lib.Mirrored(value_lib.ReplicaDeviceMap(devices), values) _cpu_device = "/device:CPU:0" @@ -107,16 +115,16 @@ class CrossDeviceOpsTestBase(test.TestCase, parameterized.TestCase): else: self.assertEqual(type(left), type(right)) self.assertEqual(set(left.devices), set(right.devices)) - if isinstance(list(left._index.values())[0], ops.IndexedSlices): - for (d, v) in left._index.items(): - self._assert_indexed_slices_equal(v, right._index[d]) + if isinstance(left.values[0], ops.IndexedSlices): + for d in left.devices: + self._assert_indexed_slices_equal(left.get(d), right.get(d)) elif context.executing_eagerly(): - self.assertEqual([v.numpy() for v in left._index.values()], - list(right._index.values())) + self.assertEqual([v.numpy() for v in left.values], + list(right.values)) else: with self.cached_session() as sess: self.assertEqual( - sess.run(list(left._index.values())), list(right._index.values())) + sess.run(list(left.values)), list(right.values)) def _testReductionAndBroadcast(self, cross_device_ops, distribution): devices = distribution.extended.worker_devices @@ -280,7 +288,8 @@ class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase): devices = ["/cpu:0", "/gpu:0"] t0 = _make_indexed_slices([[1., 2.]], [1], [5, 2], devices[0]) t1 = _make_indexed_slices([[3., 
4.], [5., 6.]], [1, 3], [5, 2], devices[1]) - per_replica = value_lib.PerReplica({devices[0]: t0, devices[1]: t1}) + per_replica = value_lib.PerReplica( + value_lib.ReplicaDeviceMap(devices), (t0, t1)) result = cross_device_ops_lib._simple_reduce( per_replica, devices[0], math_ops.add_n, reduce_util.ReduceOp.SUM) @@ -314,7 +323,8 @@ class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase): t0 = _make_indexed_slices([[1., 2.]], [1], dense_shape, devices[0]) t1 = _make_indexed_slices( [[3., 4.], [5., 6.]], [1, 3], dense_shape, devices[1]) - per_replica = value_lib.PerReplica({devices[0]: t0, devices[1]: t1}) + per_replica = value_lib.PerReplica( + value_lib.ReplicaDeviceMap(devices), (t0, t1)) if batch_reduce: result = cross_device_ops_instance.batch_reduce( @@ -392,18 +402,16 @@ class MultiWorkerCrossDeviceOpsTest(multi_worker_test_base.MultiWorkerTestBase, # pylint: disable=g-long-lambda combinations.NamedDistribution( "CoreMirroredCPU", - lambda: mirrored_strategy.CoreMirroredStrategy( - num_gpus_per_worker=0), + lambda: mirrored_strategy.CoreMirroredStrategy(["/device:CPU:0"]), required_gpus=0), combinations.NamedDistribution( "CoreMirrored1GPU", - lambda: mirrored_strategy.CoreMirroredStrategy( - num_gpus_per_worker=1), + lambda: mirrored_strategy.CoreMirroredStrategy(["/device:GPU:0"]), required_gpus=1), combinations.NamedDistribution( "CoreMirrored2GPUs", lambda: mirrored_strategy.CoreMirroredStrategy( - num_gpus_per_worker=2), + ["/device:GPU:0", "/device:GPU:1"]), required_gpus=2), ], mode=["graph"]) @@ -476,8 +484,8 @@ class MultiWorkerCollectiveAllReduceTest( run_options.experimental.collective_graph_key = 6 left_values = np.array( - sess.run(list(left._index.values()), options=run_options)).flatten() - right_values = np.array(list(right._index.values())).flatten() + sess.run(list(left.values), options=run_options)).flatten() + right_values = np.array(list(right.values)).flatten() self.assertEqual(len(left_values), len(right_values)) for l, r in 
zip(left_values, right_values): self.assertEqual(l, r) @@ -498,7 +506,7 @@ class MultiWorkerCollectiveAllReduceTest( # Collective ops doesn't support scalar tensors, so we have to construct # 1-d tensors. values = [constant_op.constant([float(d)]) for d in range(len(devices))] - per_replica = _make_per_replica(values, devices, regroup=True) + per_replica = _make_per_replica(values, devices) mean = np.array([(len(devices) - 1.) / 2.]) values_2 = [constant_op.constant([d + 1.0]) for d in range(len(devices))] diff --git a/tensorflow/contrib/distribute/python/cross_device_utils_test.py b/tensorflow/contrib/distribute/python/cross_device_utils_test.py index 2303a31677afbd12a0b8e7eea3ecf7c7736c46ad..275aac2eeca575e927878d1ece63ce37ed38e8a0 100644 --- a/tensorflow/contrib/distribute/python/cross_device_utils_test.py +++ b/tensorflow/contrib/distribute/python/cross_device_utils_test.py @@ -103,7 +103,8 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase): constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) t1 = math_ops._as_indexed_slices( constant_op.constant([[0., 0.], [5, 6], [7., 8.]])) - per_replica = value_lib.PerReplica({"/gpu:0": t0, "/cpu:0": t1}) + device_map = value_lib.ReplicaDeviceMap(("/gpu:0", "/cpu:0")) + per_replica = value_lib.PerReplica(device_map, (t0, t1)) self.assertTrue(cross_device_utils.contains_indexed_slices(per_replica)) @combinations.generate(combinations.combine( diff --git a/tensorflow/contrib/distribute/python/estimator_training_test.py b/tensorflow/contrib/distribute/python/estimator_training_test.py index 8c0f59420215f6dfcfd1a17565ba6d1c337696c5..3f55a8a1c8b88d1b8e4031547fa3fbe519983630 100644 --- a/tensorflow/contrib/distribute/python/estimator_training_test.py +++ b/tensorflow/contrib/distribute/python/estimator_training_test.py @@ -248,6 +248,12 @@ class DistributeCoordinatorIntegrationTest( ]) self.assertAllEqual((BATCH_SIZE, LABEL_DIMENSION), predicted_proba.shape) + def _get_strategy_object(self, strategy_cls): + if 
strategy_cls == mirrored_strategy.CoreMirroredStrategy: + return strategy_cls(mirrored_strategy.all_local_devices()) + else: + return strategy_cls(num_gpus_per_worker=context.num_gpus()) + @combinations.generate( combinations.combine( mode=["graph"], @@ -266,12 +272,10 @@ class DistributeCoordinatorIntegrationTest( required_gpus=[0, 1])) def test_complete_flow_standalone_client(self, train_distribute_cls, eval_distribute_cls): - train_distribute = train_distribute_cls( - num_gpus_per_worker=context.num_gpus()) + train_distribute = self._get_strategy_object(train_distribute_cls) if eval_distribute_cls: - eval_distribute = eval_distribute_cls( - num_gpus_per_worker=context.num_gpus()) + eval_distribute = self._get_strategy_object(eval_distribute_cls) else: eval_distribute = None @@ -298,12 +302,10 @@ class DistributeCoordinatorIntegrationTest( required_gpus=[0, 1])) def test_estimator_standalone_client(self, train_distribute_cls, eval_distribute_cls): - train_distribute = train_distribute_cls( - num_gpus_per_worker=context.num_gpus()) + train_distribute = self._get_strategy_object(train_distribute_cls) if eval_distribute_cls: - eval_distribute = eval_distribute_cls( - num_gpus_per_worker=context.num_gpus()) + eval_distribute = self._get_strategy_object(eval_distribute_cls) else: eval_distribute = None @@ -347,16 +349,14 @@ class DistributeCoordinatorIntegrationTest( required_gpus=[0, 1])) def test_complete_flow_indepedent_worker_between_graph( self, train_distribute_cls, eval_distribute_cls): - train_distribute = train_distribute_cls( - num_gpus_per_worker=context.num_gpus()) - if (context.num_gpus() < 2 and eval_distribute_cls == collective_all_reduce_strategy.CollectiveAllReduceStrategy): self.skipTest("`CollectiveAllReduceStrategy` needs at least two towers.") + train_distribute = self._get_strategy_object(train_distribute_cls) + if eval_distribute_cls: - eval_distribute = eval_distribute_cls( - num_gpus_per_worker=context.num_gpus()) + eval_distribute = 
self._get_strategy_object(eval_distribute_cls) else: eval_distribute = None @@ -375,11 +375,13 @@ class DistributeCoordinatorIntegrationTest( threads = self.run_multiple_tasks_in_threads(self._independent_worker_fn, cluster_spec, train_distribute, eval_distribute) + threads_to_join = [] for task_type, ts in threads.items(): if task_type == PS: continue for t in ts: - t.join() + threads_to_join.append(t) + self.join_independent_workers(threads_to_join) estimator = self._get_estimator(train_distribute, eval_distribute) self._inspect_train_and_eval_events(estimator) @@ -399,12 +401,10 @@ class DistributeCoordinatorIntegrationTest( required_gpus=[0, 1])) def test_complete_flow_indepedent_worker_in_graph(self, train_distribute_cls, eval_distribute_cls): - train_distribute = train_distribute_cls( - num_gpus_per_worker=context.num_gpus()) + train_distribute = self._get_strategy_object(train_distribute_cls) if eval_distribute_cls: - eval_distribute = eval_distribute_cls( - num_gpus_per_worker=context.num_gpus()) + eval_distribute = self._get_strategy_object(eval_distribute_cls) else: eval_distribute = None @@ -415,8 +415,7 @@ class DistributeCoordinatorIntegrationTest( threads = self.run_multiple_tasks_in_threads(self._independent_worker_fn, cluster_spec, train_distribute, eval_distribute) - threads[WORKER][0].join() - threads[EVALUATOR][0].join() + self.join_independent_workers([threads[WORKER][0], threads[EVALUATOR][0]]) estimator = self._get_estimator(train_distribute, eval_distribute) self._inspect_train_and_eval_events(estimator) @@ -453,7 +452,7 @@ class RunConfigTest(test.TestCase): run_config_lib.RunConfig( experimental_distribute=DistributeConfig( train_distribute=mirrored_strategy.CoreMirroredStrategy( - num_gpus_per_worker=2))) + ["/device:GPU:0", "/device:GPU:1"]))) def test_should_run_distribute_coordinator(self): """Tests that should_run_distribute_coordinator return a correct value.""" @@ -477,11 +476,11 @@ class RunConfigTest(test.TestCase): 
config_with_train_distribute = run_config_lib.RunConfig( experimental_distribute=DistributeConfig( train_distribute=mirrored_strategy.CoreMirroredStrategy( - num_gpus_per_worker=2))) + ["/device:GPU:0", "/device:GPU:1"]))) config_with_eval_distribute = run_config_lib.RunConfig( experimental_distribute=DistributeConfig( eval_distribute=mirrored_strategy.CoreMirroredStrategy( - num_gpus_per_worker=2))) + ["/device:GPU:0", "/device:GPU:1"]))) self.assertTrue( dc_training.should_run_distribute_coordinator( config_with_train_distribute)) @@ -495,7 +494,7 @@ class RunConfigTest(test.TestCase): config = run_config_lib.RunConfig( experimental_distribute=DistributeConfig( train_distribute=mirrored_strategy.CoreMirroredStrategy( - num_gpus_per_worker=2))) + ["/device:GPU:0", "/device:GPU:1"]))) self.assertFalse(dc_training.should_run_distribute_coordinator(config)) def test_init_run_config_duplicate_distribute(self): diff --git a/tensorflow/contrib/distribute/python/examples/BUILD b/tensorflow/contrib/distribute/python/examples/BUILD index 84b106545e1326fddd3ed299462534af982dc102..5f89df5824a8d03198987a6fa3d21e2330deedf0 100644 --- a/tensorflow/contrib/distribute/python/examples/BUILD +++ b/tensorflow/contrib/distribute/python/examples/BUILD @@ -31,6 +31,12 @@ py_binary( py_binary( name = "keras_mnist", + srcs = ["keras_mnist.py"], + deps = [":keras_mnist_lib"], +) + +py_library( + name = "keras_mnist_lib", srcs = [ "keras_mnist.py", ], @@ -39,3 +45,14 @@ py_binary( "//third_party/py/numpy", ], ) + +py_binary( + name = "mnist_eager_multigpu", + srcs = [ + "mnist_eager_multigpu.py", + ], + deps = [ + "//tensorflow:tensorflow_py", + "//third_party/py/numpy", + ], +) diff --git a/tensorflow/contrib/distribute/python/examples/keras_mnist.py b/tensorflow/contrib/distribute/python/examples/keras_mnist.py index 8b6487252df54dc18cc0763fb1c58a190faad88a..1ce91ecaf22a80a53124c8f00fac05c6b4711ed9 100644 --- a/tensorflow/contrib/distribute/python/examples/keras_mnist.py +++ 
b/tensorflow/contrib/distribute/python/examples/keras_mnist.py @@ -20,6 +20,10 @@ from __future__ import print_function import tensorflow as tf +from tensorflow.python.distribute import mirrored_strategy +from tensorflow.python.keras.optimizer_v2 import rmsprop + + NUM_CLASSES = 10 @@ -105,22 +109,21 @@ def main(_): tf.enable_eager_execution() train_ds, eval_ds, input_shape = get_input_datasets() - model = get_model(input_shape) # Instantiate the MirroredStrategy object. If we don't specify `num_gpus` or # the `devices` argument then all the GPUs available on the machine are used. - strategy = tf.contrib.distribute.MirroredStrategy(['/gpu:0', '/cpu:0']) - - # TODO(priyag): Use RMSPropOptimizer when it works with eager mode. - optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001) - - # Compile the model by passing the distribution strategy object to the - # `distribute` argument. `fit`, `evaluate` and `predict` will be distributed - # based on the strategy instantiated. - model.compile(loss=tf.keras.losses.categorical_crossentropy, - optimizer=optimizer, - metrics=['accuracy'], - distribute=strategy) + # TODO(priyag): Use `tf.distribute.MirroredStrategy` once available. + strategy = mirrored_strategy.MirroredStrategy(['/gpu:0', '/cpu:0']) + + # Create and compile the model under Distribution strategy scope. + # `fit`, `evaluate` and `predict` will be distributed based on the strategy + # model was compiled with. + with strategy.scope(): + model = get_model(input_shape) + optimizer = rmsprop.RMSProp(learning_rate=0.001) + model.compile(loss=tf.keras.losses.categorical_crossentropy, + optimizer=optimizer, + metrics=['accuracy']) # Train the model with the train dataset. 
model.fit(x=train_ds, epochs=20, steps_per_epoch=468) diff --git a/tensorflow/contrib/distribute/python/examples/mnist_eager_multigpu.py b/tensorflow/contrib/distribute/python/examples/mnist_eager_multigpu.py new file mode 100644 index 0000000000000000000000000000000000000000..c045a5586b9dad371d8c505f9cac4b792dd157fd --- /dev/null +++ b/tensorflow/contrib/distribute/python/examples/mnist_eager_multigpu.py @@ -0,0 +1,169 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Run MNIST on multiple GPUs on using MirroredStrategy with eager execution. + +By default, runs on all available GPUs, or CPU if no GPUs are available. + +NOTE: Currently, this takes more time than when running MNIST in eager without +MirroredStrategy because of a number overheads. Therefore, this is just a +proof of concept right now and cannot be used to actually scale up training. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl import app +from absl import flags +import numpy as np +import tensorflow.compat.v2 as tf + +flags.DEFINE_integer("num_gpus", None, "How many GPUs should we run on?" 
+ "Defaults to all available GPUs, otherwise CPU.") +flags.DEFINE_integer("batch_size", 64, + "What should be the size of each batch?") +flags.DEFINE_integer("num_epochs", 10, "How many epochs to run?") +flags.DEFINE_float("learning_rate", 0.01, "Learning Rate") +flags.DEFINE_float("momentum", 0.5, "SGD momentum") +flags.DEFINE_boolean("use_function", False, + "Should we wrap the step in a tf.function.") + +FLAGS = flags.FLAGS +NUM_TRAIN_IMAGES = 60000 + + +def create_model(): + max_pool = tf.keras.layers.MaxPooling2D((2, 2), (2, 2), padding="same") + # The model consists of a sequential chain of layers, so tf.keras.Sequential + # (a subclass of tf.keras.Model) makes for a compact description. + return tf.keras.Sequential([ + tf.keras.layers.Reshape( + target_shape=[28, 28, 1], + input_shape=(28, 28,)), + tf.keras.layers.Conv2D(2, 5, padding="same", activation=tf.nn.relu), + max_pool, + tf.keras.layers.Conv2D(4, 5, padding="same", activation=tf.nn.relu), + max_pool, + tf.keras.layers.Flatten(), + tf.keras.layers.Dense(32, activation=tf.nn.relu), + tf.keras.layers.Dropout(0.4), + tf.keras.layers.Dense(10)]) + + +def compute_loss(logits, labels): + loss = tf.reduce_sum( + tf.nn.sparse_softmax_cross_entropy_with_logits( + logits=logits, labels=labels)) + # Scale loss by global batch size. + return loss * (1. / FLAGS.batch_size) + + +def mnist_datasets(): + (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data() + # Numpy defaults to dtype=float64; TF defaults to float32. Stick with float32. + x_train, x_test = x_train / np.float32(255), x_test / np.float32(255) + y_train, y_test = y_train.astype(np.int64), y_test.astype(np.int64) + # TODO(priyag): `strategy.make_numpy_iterator` can be used directly instead of + # converting to datasets. 
+ train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)) + test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test)) + return train_dataset, test_dataset + + +def main(unused_argv): + """Run a CNN model on MNIST data to demonstrate DistributedStrategies.""" + + tf.enable_v2_behavior() + + num_gpus = FLAGS.num_gpus + if num_gpus is None: + devices = None + elif num_gpus == 0: + devices = ["/device:CPU:0"] + else: + devices = ["/device:GPU:{}".format(i) for i in range(num_gpus)] + strategy = tf.distribute.MirroredStrategy(devices) + + with strategy.scope(): + train_ds, test_ds = mnist_datasets() + train_ds = train_ds.shuffle(NUM_TRAIN_IMAGES).batch(FLAGS.batch_size) + test_ds = test_ds.batch(FLAGS.batch_size) + + model = create_model() + optimizer = tf.keras.optimizers.SGD(FLAGS.learning_rate, FLAGS.momentum) + training_loss = tf.keras.metrics.Mean("training_loss", dtype=tf.float32) + training_accuracy = tf.keras.metrics.SparseCategoricalAccuracy( + "training_accuracy", dtype=tf.float32) + test_loss = tf.keras.metrics.Mean("test_loss", dtype=tf.float32) + test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy( + "test_accuracy", dtype=tf.float32) + + def train_step(inputs): + images, labels = inputs + with tf.GradientTape() as tape: + logits = model(images, training=True) + loss = compute_loss(logits, labels) + grads = tape.gradient(loss, model.variables) + optimizer.apply_gradients(zip(grads, model.variables)) + training_loss.update_state(loss) + training_accuracy.update_state(labels, logits) + + def test_step(inputs): + images, labels = inputs + logits = model(images, training=False) + loss = compute_loss(logits, labels) + test_loss.update_state(loss) + test_accuracy.update_state(labels, logits) + + train_iterator = strategy.make_dataset_iterator(train_ds) + test_iterator = strategy.make_dataset_iterator(test_ds) + + for epoch in range(0, FLAGS.num_epochs): + # TODO(b/123315763): Create the tf.function outside this loop once we are + # 
able to initialize iterator in eager mode. + dist_train = lambda it: strategy.experimental_run(train_step, it) + dist_test = lambda it: strategy.experimental_run(test_step, it) + if FLAGS.use_function: + dist_train = tf.function(dist_train) + dist_test = tf.function(dist_test) + + # Train + print("Starting epoch {}".format(epoch)) + train_iterator.initialize() + while True: + try: + dist_train(train_iterator) + except tf.errors.OutOfRangeError: + break + print("Training loss: {:0.4f}, accuracy: {:0.2f}%".format( + training_loss.result(), training_accuracy.result() * 100)) + training_loss.reset_states() + training_accuracy.reset_states() + + # Test + test_iterator.initialize() + while True: + try: + dist_test(test_iterator) + except tf.errors.OutOfRangeError: + break + print("Test loss: {:0.4f}, accuracy: {:0.2f}%".format( + test_loss.result(), test_accuracy.result() * 100)) + test_loss.reset_states() + test_accuracy.reset_states() + + +if __name__ == "__main__": + app.run(main) diff --git a/tensorflow/contrib/distribute/python/input_lib_test.py b/tensorflow/contrib/distribute/python/input_lib_test.py new file mode 100644 index 0000000000000000000000000000000000000000..10a58316ec5b3d9d968a88c5c39ff70c277daa65 --- /dev/null +++ b/tensorflow/contrib/distribute/python/input_lib_test.py @@ -0,0 +1,246 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for the input_lib library.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python import multi_worker_test_base +from tensorflow.python.data.experimental.ops import batching +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import distribute_lib +from tensorflow.python.distribute import input_lib +from tensorflow.python.distribute import values +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.framework import errors +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.util import nest + + +class InputIteratorTestBase(test.TestCase): + + def _test_iterator(self, input_type, dataset_fn, worker_device_pairs, + expected_values, sess=None, split_batch_by=None): + devices = nest.flatten([ds for _, ds in worker_device_pairs]) + device_map = values.ReplicaDeviceMap(devices) + input_workers = input_lib.InputWorkers(device_map, worker_device_pairs) + + if input_type == "input_fn": + input_contexts = [ + distribute_lib.InputContext() for _ in worker_device_pairs] + input_fn = lambda _: dataset_fn() + iterator = input_lib.InputFunctionIterator( + input_fn, input_workers, input_contexts) + else: + iterator = input_lib.DatasetIterator( + dataset_fn(), input_workers, split_batch_by) + + evaluate = lambda x: sess.run(x) if sess else self.evaluate(x) + + evaluate(control_flow_ops.group(iterator.initialize())) + + for expected_value in expected_values: + next_element = iterator.get_next() + computed_value = evaluate( + [values.select_replica(r, next_element) for r in range(len(devices))]) + self.assertAllEqual(expected_value, computed_value) + + with 
self.assertRaises(errors.OutOfRangeError): + next_element = iterator.get_next() + evaluate([values.select_replica(r, next_element) + for r in range(len(devices))]) + + # After re-initializing the iterator, should be able to iterate again. + evaluate(control_flow_ops.group(iterator.initialize())) + + for expected_value in expected_values: + next_element = iterator.get_next() + computed_value = evaluate( + [values.select_replica(r, next_element) for r in range(len(devices))]) + self.assertAllEqual(expected_value, computed_value) + + +class InputIteratorSingleWorkerTest(InputIteratorTestBase, + parameterized.TestCase): + + @combinations.generate(combinations.combine( + mode=["graph", "eager"], + input_type=["input_fn", "dataset"])) + def testOneDeviceCPU(self, input_type): + worker_device_pairs = [("", ["/device:CPU:0"])] + dataset_fn = lambda: dataset_ops.Dataset.range(10) + + expected_values = [[i] for i in range(10)] + + self._test_iterator(input_type, dataset_fn, worker_device_pairs, + expected_values) + + @combinations.generate(combinations.combine( + mode=["graph", "eager"], + input_type=["input_fn", "dataset"], + required_gpus=1)) + def testTwoDevicesOneGPUOneCPU(self, input_type): + worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] + dataset_fn = lambda: dataset_ops.Dataset.range(10) + + expected_values = [[i, i+1] for i in range(0, 10, 2)] + + self._test_iterator(input_type, dataset_fn, worker_device_pairs, + expected_values) + + @combinations.generate(combinations.combine( + mode=["graph", "eager"], + input_type=["input_fn", "dataset"], + required_gpus=1)) + def testTupleDataset(self, input_type): + worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] + def dataset_fn(): + dataset1 = dataset_ops.Dataset.range(10) + dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2) + return dataset_ops.Dataset.zip((dataset1, dataset2)) + + expected_values = [[(i, i**2), (i+1, (i+1)**2)] for i in range(0, 10, 2)] + + 
self._test_iterator(input_type, dataset_fn, worker_device_pairs, + expected_values) + + @combinations.generate(combinations.combine( + mode=["graph", "eager"], + input_type=["input_fn", "dataset"], + required_gpus=1)) + def testUnevenDatasetBatches(self, input_type): + worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] + dataset_fn = lambda: dataset_ops.Dataset.range(11) + + expected_values = [[i, i+1] for i in range(0, 10, 2)] + self._test_iterator(input_type, dataset_fn, worker_device_pairs, + expected_values) + + @combinations.generate(combinations.combine( + mode=["graph", "eager"], + input_type=["dataset"], + split_batch_by=[None, 2], + required_gpus=1)) + def testBatchSplitting(self, input_type, split_batch_by): + worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] + batch_size = 10 + dataset_fn = lambda: dataset_ops.Dataset.range(100).batch(batch_size) + + updated_batch_size = ( + batch_size // split_batch_by if split_batch_by else batch_size) + expected_values = [[range(i, i+updated_batch_size), + range(i+updated_batch_size, i+2*updated_batch_size)] + for i in range(0, 100, updated_batch_size*2)] + + self._test_iterator(input_type, dataset_fn, worker_device_pairs, + expected_values, sess=None, + split_batch_by=split_batch_by) + + +class InputIteratorMultiWorkerTest( + multi_worker_test_base.MultiWorkerTestBase, InputIteratorTestBase, + parameterized.TestCase): + + def _cpu_devices(self): + return [ + ("/job:worker/replica:0/task:0", + ["/job:worker/replica:0/task:0/device:CPU:0"]), + ("/job:worker/replica:0/task:1", + ["/job:worker/replica:0/task:1/device:CPU:0"])] + + def _cpu_and_one_gpu_devices(self): + return [ + ("/job:worker/replica:0/task:0", [ + "/job:worker/replica:0/task:0/device:GPU:0", + "/job:worker/replica:0/task:0/device:CPU:0" + ]), + ("/job:worker/replica:0/task:1", [ + "/job:worker/replica:0/task:1/device:GPU:0", + "/job:worker/replica:0/task:1/device:CPU:0" + ]) + ] + + 
@combinations.generate(combinations.combine( + mode=["graph"], + input_type=["input_fn", "dataset"])) + def testOneDevicePerWorker(self, input_type): + worker_devices = self._cpu_devices() + with context.graph_mode(), self.cached_session() as sess: + dataset_fn = lambda: dataset_ops.Dataset.range(4) + self._test_iterator(input_type, dataset_fn, worker_devices, + [[0, 0], [1, 1], [2, 2], [3, 3]], sess) + + @combinations.generate(combinations.combine( + mode=["graph"], + input_type=["input_fn", "dataset"], + required_gpus=1)) + def testTwoDevicesPerWorker(self, input_type): + worker_devices = self._cpu_and_one_gpu_devices() + with context.graph_mode(), self.cached_session() as sess: + dataset_fn = lambda: dataset_ops.Dataset.range(4) + self._test_iterator(input_type, dataset_fn, worker_devices, + [[0, 1, 0, 1], [2, 3, 2, 3]], sess) + + @combinations.generate(combinations.combine( + mode=["graph"], + input_type=["input_fn", "dataset"])) + def testTupleDataset(self, input_type): + worker_devices = self._cpu_devices() + with context.graph_mode(), self.cached_session() as sess: + def dataset_fn(): + dataset1 = dataset_ops.Dataset.range(4) + dataset2 = dataset_ops.Dataset.range(4).map(lambda x: x**2) + return dataset_ops.Dataset.zip((dataset1, dataset2)) + + expected_values = [[(i, i**2), (i, i**2)] for i in range(0, 4)] + self._test_iterator(input_type, dataset_fn, worker_devices, + expected_values, sess) + + +class SplitDatasetBatchTest(test.TestCase): + + def testBatchDataset(self): + dataset = dataset_ops.Dataset.range(100).batch(20) + split_batch_by = 2 + result_dataset = input_lib._split_dataset_batch(dataset, split_batch_by) + expected_values = [range(i, i+10) for i in range(0, 100, 10)] + result = [self.evaluate(el) for el in result_dataset] + self.assertAllEqual(expected_values, result) + + def testMapAndBatchDataset(self): + dataset = dataset_ops.Dataset.range(100) + dataset = dataset.apply(batching.map_and_batch(lambda x: x, 20)) + split_batch_by = 2 + 
result_dataset = input_lib._split_dataset_batch(dataset, split_batch_by) + expected_values = [range(i, i+10) for i in range(0, 100, 10)] + result = [self.evaluate(el) for el in result_dataset] + self.assertAllEqual(expected_values, result) + + def testPrefetchDataset(self): + dataset = dataset_ops.Dataset.range(100).batch(20).prefetch(1) + split_batch_by = 2 + result_dataset = input_lib._split_dataset_batch(dataset, split_batch_by) + expected_values = [range(i, i+10) for i in range(0, 100, 10)] + result = [self.evaluate(el) for el in result_dataset] + self.assertAllEqual(expected_values, result) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/keras_backward_compat_test.py b/tensorflow/contrib/distribute/python/keras_backward_compat_test.py new file mode 100644 index 0000000000000000000000000000000000000000..9a581e7141af4a6625246539bc48835e6a920887 --- /dev/null +++ b/tensorflow/contrib/distribute/python/keras_backward_compat_test.py @@ -0,0 +1,1083 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for tf.keras models using DistributionStrategy.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy as np + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python import mirrored_strategy +from tensorflow.contrib.distribute.python import tpu_strategy +from tensorflow.python import keras +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import distribute_lib +from tensorflow.python.eager import test +from tensorflow.python.framework import random_seed +from tensorflow.python.keras import testing_utils +from tensorflow.python.keras.engine import distributed_training_utils +from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras +from tensorflow.python.ops.parsing_ops import gen_parsing_ops +from tensorflow.python.training import gradient_descent +from tensorflow.python.training import rmsprop +from tensorflow.python.training.mode_keys import ModeKeys + +_RANDOM_SEED = 1337 +_TRAIN_SIZE = 200 +_INPUT_SIZE = (10,) +_NUM_CLASS = 2 + + +# TODO(anjalisridhar): Add a decorator that will allow us to run these tests as +# part of the tf.keras unit tests suite. 
def simple_sequential_model():
  """Builds an uncompiled Sequential model: Dense(16), Dropout, softmax."""
  net = keras.models.Sequential()
  hidden = keras.layers.Dense(16, activation='relu', input_shape=_INPUT_SIZE)
  net.add(hidden)
  net.add(keras.layers.Dropout(0.1))
  classifier = keras.layers.Dense(_NUM_CLASS, activation='softmax')
  net.add(classifier)
  return net


def simple_functional_model():
  """Builds the same layer stack with the functional API (uncompiled)."""
  inputs = keras.layers.Input(shape=_INPUT_SIZE)
  hidden = keras.layers.Dense(16, activation='relu')(inputs)
  hidden = keras.layers.Dropout(0.1)(hidden)
  probs = keras.layers.Dense(_NUM_CLASS, activation='softmax')(hidden)
  return keras.models.Model(inputs=[inputs], outputs=[probs])


def multi_inputs_multi_outputs_model():
  """Builds and compiles a model with three named inputs and two outputs.

  `input_a` and `input_b` go through one shared Dense layer. The string input
  `input_m` is converted to numbers and multiplied with the projection of
  `input_a`; the product is concatenated with the projection of `input_b` and
  fed to two softmax heads (`dense_2` with 3 units, `dense_3` with 2 units).
  """
  in_a = keras.layers.Input(shape=(16,), name='input_a')
  in_b = keras.layers.Input(shape=(16,), name='input_b')
  in_m = keras.layers.Input(shape=(8,), dtype='string', name='input_m')
  shared_dense = keras.layers.Dense(8, name='dense_1')

  proj_a = shared_dense(in_a)
  # Convert the string input to numbers.
  parsed_m = keras.layers.Lambda(gen_parsing_ops.string_to_number)(in_m)
  scaled = keras.layers.Lambda(lambda k: k[0] * k[1])([parsed_m, proj_a])
  proj_b = shared_dense(in_b)
  merged = keras.layers.concatenate([scaled, proj_b], name='merge')
  out_c = keras.layers.Dense(3, activation='softmax', name='dense_2')(merged)
  out_d = keras.layers.Dense(2, activation='softmax', name='dense_3')(merged)
  model = keras.models.Model(inputs=[in_a, in_b, in_m], outputs=[out_c, out_d])
  per_output_metrics = {
      'dense_2': 'categorical_accuracy',
      'dense_3': 'categorical_accuracy'
  }
  model.compile(
      loss='categorical_crossentropy',
      optimizer=gradient_descent.GradientDescentOptimizer(0.001),
      metrics=per_output_metrics)
  return model


def get_ds_train_input_fn():
  """Returns the training split as (x, one-hot y) batches of 32."""
  np.random.seed(_RANDOM_SEED)
  train_split, _ = testing_utils.get_test_data(
      train_samples=_TRAIN_SIZE,
      test_samples=50,
      input_shape=_INPUT_SIZE,
      num_classes=_NUM_CLASS)
  features, labels = train_split
  one_hot_labels = keras.utils.to_categorical(labels)

  slices = dataset_ops.Dataset.from_tensor_slices((features, one_hot_labels))
  return slices.batch(32)


def get_ds_test_input_fn():
  """Returns the test split as (x, one-hot y) batches of 32."""
  np.random.seed(_RANDOM_SEED)
  _, test_split = testing_utils.get_test_data(
      train_samples=_TRAIN_SIZE,
      test_samples=50,
      input_shape=_INPUT_SIZE,
      num_classes=_NUM_CLASS)
  features, labels = test_split
  one_hot_labels = keras.utils.to_categorical(labels)

  slices = dataset_ops.Dataset.from_tensor_slices((features, one_hot_labels))
  return slices.batch(32)


def get_multi_inputs_multi_outputs_data():
  """Returns `(train_data, test_data)` dicts keyed by input/output name."""

  def make_splits(input_shape, num_classes):
    # Every call uses the same sample counts and the same seed.
    return testing_utils.get_test_data(
        train_samples=_TRAIN_SIZE,
        test_samples=50,
        input_shape=input_shape,
        num_classes=num_classes,
        random_seed=_RANDOM_SEED)

  (a_train, c_train), (a_test, c_test) = make_splits((16,), 3)
  (b_train, d_train), (b_test, d_test) = make_splits((16,), 2)
  (m_train, _), (m_test, _) = make_splits((8,), 2)

  train_data = {
      'input_a': a_train,
      'input_b': b_train,
      'input_m': m_train,
      'output_c': keras.utils.to_categorical(c_train),
      'output_d': keras.utils.to_categorical(d_train)
  }
  test_data = {
      'input_a': a_test,
      'input_b': b_test,
      'input_m': m_test,
      'output_c': keras.utils.to_categorical(c_test),
      'output_d': keras.utils.to_categorical(d_test)
  }

  return (train_data, test_data)


def batch_wrapper(dataset, batch_size, distribution, repeat=None):
  """Optionally repeats `dataset`, then batches it for `distribution`.

  Args:
    dataset: The dataset to batch.
    batch_size: Number of elements per batch.
    distribution: The strategy in use; for a `TPUStrategy` the remainder batch
      is dropped.
    repeat: Repeat count applied before batching; skipped when falsy.

  Returns:
    The batched dataset.
  """
  repeated = dataset.repeat(repeat) if repeat else dataset
  if not isinstance(distribution, tpu_strategy.TPUStrategy):
    return repeated.batch(batch_size)
  # TPUs currently require fully defined input shapes, drop_remainder ensures
  # the input will have fully defined shapes.
  return repeated.batch(batch_size, drop_remainder=True)


def get_model():
  """Builds an uncompiled model: 3 input features into a single Dense(4)."""
  inputs = keras.layers.Input(shape=(3,), name='input')
  outputs = keras.layers.Dense(4, name='dense')(inputs)
  return keras.Model(inputs, outputs)


def get_dataset(distribution):
  """Returns all-zero (inputs, targets), repeated 100 times, batched by 10."""
  zero_inputs = np.zeros((10, 3), dtype=np.float32)
  zero_targets = np.zeros((10, 4), dtype=np.float32)
  slices = dataset_ops.Dataset.from_tensor_slices((zero_inputs, zero_targets))
  return batch_wrapper(slices.repeat(100), 10, distribution)


def get_predict_dataset(distribution):
  """Returns all-zero inputs only, repeated 100 times, batched by 10."""
  zero_inputs = np.zeros((10, 3), dtype=np.float32)
  slices = dataset_ops.Dataset.from_tensor_slices(zero_inputs)
  return batch_wrapper(slices.repeat(100), 10, distribution)


def multi_input_output_model():
  """Builds an uncompiled model with two inputs and two 7-wide outputs."""
  in_a = keras.layers.Input(shape=(3,), name='input_a')
  in_b = keras.layers.Input(shape=(5,), name='input_b')
  # TODO(anjalisridhar): Change the output dimension of the second Dense layer
  # once the iterator output validation issue has been fixed.
  first_dense = keras.layers.Dense(7, name='dense_1')
  second_dense = keras.layers.Dense(7, name='dense_2')
  out_c = first_dense(in_a)
  out_d = second_dense(in_b)
  out_e = keras.layers.Dropout(0.5, name='dropout')(out_c)
  return keras.models.Model([in_a, in_b], [out_d, out_e])


def get_correctness_test_inputs(use_numpy, use_validation_data,
                                with_distribution,
                                x_train, y_train, x_predict):
  """Builds the keyword arguments for `fit`, `evaluate` and `predict`.

  Args:
    use_numpy: If true the data is passed as numpy arrays, otherwise it is
      wrapped in batched datasets.
    use_validation_data: If true, validation data (built from the training
      arrays) is attached to the `fit` arguments and no separate evaluation
      arguments are produced.
    with_distribution: The distribution strategy in use, or a falsy value.
    x_train: Training features.
    y_train: Training targets.
    x_predict: Features to predict on.

  Returns:
    A `(training_inputs, eval_inputs, predict_inputs)` tuple of keyword
    dictionaries; `eval_inputs` is None when `use_validation_data` is true.
  """
  epochs = 2
  global_batch_size = 64
  # TODO(b/118776054): Use global batch size for Keras/DS support.
  use_per_core_batch_size = (
      with_distribution and
      not distributed_training_utils.global_batch_size_supported(
          with_distribution))
  batch_size = global_batch_size
  if use_per_core_batch_size:
    batch_size = global_batch_size // with_distribution.num_replicas_in_sync

  if use_numpy:
    training_inputs = {
        'batch_size': batch_size,
        'x': x_train,
        'y': y_train,
        'epochs': epochs,
        'shuffle': False,
    }
    if use_validation_data:
      training_inputs['validation_data'] = (x_train, y_train)
      eval_inputs = None
    else:
      eval_inputs = {
          'batch_size': batch_size,
          'x': x_train,
          'y': y_train,
      }
    predict_inputs = {
        'x': np.array(x_predict, dtype=np.float32),
    }
    return training_inputs, eval_inputs, predict_inputs

  # For dataset inputs, we do not pass batch_size to
  # keras.fit/evaluate/predict. The batch size is part of the dataset.
  train_slices = dataset_ops.Dataset.from_tensor_slices((x_train, y_train))
  train_dataset = batch_wrapper(
      train_slices, batch_size, with_distribution, repeat=epochs)

  training_inputs = {
      'batch_size': None,
      'x': train_dataset,
      'y': None,
      'epochs': epochs,
      'shuffle': False,
      'steps_per_epoch': len(x_train) // global_batch_size,
  }
  if use_validation_data:
    eval_inputs = None  # No separate evaluate() arguments in this mode.
    validation_slices = dataset_ops.Dataset.from_tensor_slices(
        (x_train, y_train))
    training_inputs['validation_data'] = batch_wrapper(
        validation_slices, batch_size, with_distribution)
    training_inputs['validation_steps'] = 5
  else:
    eval_inputs = {
        'batch_size': None,
        'x': train_dataset,
        'y': None,
        'steps': 20,
    }

  # All of `x_predict` is predicted in a single step.
  predict_batch_size = len(x_predict)
  if use_per_core_batch_size:
    predict_batch_size //= with_distribution.num_replicas_in_sync
  predict_slices = dataset_ops.Dataset.from_tensor_slices(x_predict)
  predict_dataset = batch_wrapper(
      predict_slices, predict_batch_size, with_distribution)
  predict_inputs = {
      'steps': 1,
      'x': predict_dataset,
  }

  return training_inputs, eval_inputs, predict_inputs


strategies_minus_tpu = [
    combinations.default_strategy,
    combinations.one_device_strategy,
    combinations.mirrored_strategy_with_gpu_and_cpu,
    combinations.mirrored_strategy_with_two_gpus,
    combinations.core_mirrored_strategy_with_gpu_and_cpu,
    combinations.core_mirrored_strategy_with_two_gpus]

tpu_strategies = [
    combinations.tpu_strategy,  # steps_per_run=2
    combinations.tpu_strategy_one_step]


def strategy_minus_tpu_combinations():
  """Non-TPU strategies, each in graph and eager mode."""
  return combinations.combine(
      distribution=strategies_minus_tpu,
      mode=['graph', 'eager'])


def tpu_strategy_combinations():
  """TPU strategies, graph mode only."""
  return combinations.combine(
      distribution=tpu_strategies,
      mode=['graph'])


def all_strategy_combinations():
  """Non-TPU combinations followed by the TPU combinations."""
  non_tpu = strategy_minus_tpu_combinations()
  return non_tpu + tpu_strategy_combinations()


def strategy_and_optimizer_combinations():
  """Crosses every strategy combination with v1 and Keras-v2 optimizers."""
  strategies = all_strategy_combinations()
  optimizer_fns = combinations.combine(optimizer=[
      combinations.adagrad_optimizer_v1_fn,
      combinations.adagrad_optimizer_keras_v2_fn,
      combinations.adam_optimizer_v1_fn,
      combinations.adam_optimizer_keras_v2_fn,
      combinations.gradient_descent_optimizer_v1_fn,
      combinations.gradient_descent_optimizer_keras_v2_fn,
      combinations.rmsprop_optimizer_v1_fn,
      combinations.rmsprop_optimizer_keras_v2_fn
  ])
  return combinations.times(strategies, optimizer_fns)


def strategy_and_input_combinations():
  """Crosses strategies with input kinds (numpy/dataset, validation on/off).

  Graph mode covers every input kind; eager mode (non-TPU only) covers dataset
  input without validation data.
  """

  def graph_mode_inputs():
    return combinations.combine(mode=['graph'],
                                use_numpy=[True, False],
                                use_validation_data=[True, False])

  non_tpu = combinations.times(
      combinations.combine(distribution=strategies_minus_tpu),
      graph_mode_inputs()
      + combinations.combine(mode=['eager'],
                             use_numpy=[False],
                             use_validation_data=[False]))
  tpu = combinations.times(
      combinations.combine(distribution=tpu_strategies),
      graph_mode_inputs())
  return non_tpu + tpu


def strategy_for_numpy_input_combinations():
  """All strategies (non-TPU then TPU), graph mode only."""
  return combinations.combine(
      distribution=strategies_minus_tpu + tpu_strategies,
      mode=['graph'])


class TestDistributionStrategyWithNumpyArrays(test.TestCase,
                                              parameterized.TestCase):
  """fit/evaluate/predict with numpy inputs under a distribution strategy."""

  @combinations.generate(strategy_for_numpy_input_combinations())
  def test_calling_model_with_numpy_arrays(self, distribution):
    """Runs fit, evaluate and predict on plain numpy arrays."""
    with self.cached_session():
      model = get_model()
      model.compile(
          gradient_descent.GradientDescentOptimizer(0.001),
          'mse',
          metrics=['mae'],
          distribute=distribution)

      features = np.zeros((64, 3), dtype=np.float32)
      labels = np.zeros((64, 4), dtype=np.float32)

      # Call fit with validation data
      model.fit(features, labels, epochs=1, batch_size=2, verbose=0,
                validation_data=(features, labels))

      # TODO(anjalisridhar): We need tests for when the batch size and steps are
      # smaller and results in a 0 batch_size and steps value.
      model.evaluate(features, labels)
      # with steps
      model.evaluate(features, labels, steps=2)
      # with batch_size
      model.evaluate(features, labels, batch_size=8)

      model.predict(features)
      # with steps
      model.predict(features, steps=2)
      # with batch_size
      model.predict(features, batch_size=8)

  @combinations.generate(strategy_for_numpy_input_combinations())
  def test_calling_model_with_nested_numpy_arrays(self, distribution):
    """Runs fit, evaluate and predict on lists of numpy arrays."""

    def random_float32(shape):
      return np.asarray(np.random.random(shape), dtype=np.float32)

    with self.cached_session():
      model = multi_input_output_model()
      model.compile(
          gradient_descent.GradientDescentOptimizer(learning_rate=0.001),
          'mse',
          distribute=distribution)

      features = [random_float32((64, 3)), random_float32((64, 5))]
      labels = [random_float32((64, 7)), random_float32((64, 7))]

      # Call fit (no validation data is passed here).
      model.fit(features, labels, epochs=1, batch_size=8, verbose=0)

      # TODO(anjalisridhar): We need tests for when the batch size and steps are
      # smaller and results in a 0 batch_size and steps value.
      model.evaluate(features, labels)
      # with steps
      model.evaluate(features, labels, steps=2)
      # with batch_size
      model.evaluate(features, labels, batch_size=8)

      model.predict(features)
      # with steps
      model.predict(features, steps=2)
      # with batch_size
      model.predict(features, batch_size=8)

  @combinations.generate(combinations.combine(
      distribution=strategies_minus_tpu, mode=['graph']))
  def test_numpy_with_sample_weights(self, distribution):
    """Fits on numpy arrays with per-sample weights."""
    model = get_model()
    model.compile(
        rmsprop.RMSPropOptimizer(learning_rate=0.001),
        'mse',
        distribute=distribution)

    features = np.zeros((20, 3), np.float32)
    labels = np.zeros((20, 4), np.float32)
    weights = np.ones((20), np.float32)

    model.fit(features, labels, sample_weight=weights, epochs=1,
              steps_per_epoch=2, verbose=1)

  @combinations.generate(strategy_for_numpy_input_combinations())
  def test_flatten_predict_outputs(self, distribution):
    """Checks `predict` returns one array per output covering all samples."""
    with self.cached_session():
      model = multi_input_output_model()
      model.compile(
          gradient_descent.GradientDescentOptimizer(learning_rate=0.001),
          'mse',
          distribute=distribution)

      # We take 6 input samples with each input having a dimension of 3 or 5.
      features = [
          np.asarray(np.random.random((6, 3)), dtype=np.float32),
          np.asarray(np.random.random((6, 5)), dtype=np.float32),
      ]

      predictions = model.predict(features, steps=1)
      # `predict` returns a list that is equal in length to the number of model
      # outputs. In this test our model has two outputs and each element of
      # `predictions` corresponds to all the samples of one of the model
      # outputs.
      self.assertLen(predictions, 2)
      # Each of the output samples have a dimension of 7. We should process all
      # the available input samples(6).
      for per_output in predictions[:2]:
        self.assertAllEqual([6, 7], per_output.shape)


class TestDistributionStrategyWithDatasets(test.TestCase,
                                           parameterized.TestCase):
  """fit/evaluate/predict with dataset inputs under a distribution strategy."""

  @combinations.generate(all_strategy_combinations())
  def test_calling_model_on_same_dataset(self, distribution):
    """Fits twice using one dataset for both training and validation."""
    with self.cached_session():
      model = get_model()
      model.compile(
          gradient_descent.GradientDescentOptimizer(0.001),
          'mse',
          metrics=['mae', keras.metrics.CategoricalAccuracy()],
          distribute=distribution)

      shared_dataset = get_dataset(distribution)

      # Call fit with validation data
      for _ in range(2):
        model.fit(shared_dataset, epochs=1, steps_per_epoch=2, verbose=0,
                  validation_data=shared_dataset, validation_steps=2)
      model.predict(get_predict_dataset(distribution), steps=2)

  @combinations.generate(all_strategy_combinations())
  def test_model_interleaved_eval_same_as_direct_eval(self, distribution):
    """Validation inside `fit` must match `evaluate` called per epoch."""
    with self.cached_session():
      manual_model = get_model()
      manual_model.compile(
          gradient_descent.GradientDescentOptimizer(0.001),
          loss='mse',
          metrics=['mae', keras.metrics.CategoricalAccuracy()],
          distribute=distribution)

      interleaved_model = get_model()
      interleaved_model.set_weights(manual_model.get_weights())
      interleaved_model.compile(
          gradient_descent.GradientDescentOptimizer(0.001),
          loss='mse',
          metrics=['mae', keras.metrics.CategoricalAccuracy()],
          distribute=distribution)

      dataset = get_dataset(distribution)

      # Call fit with validation interleaved
      interleaved_history = interleaved_model.fit(
          dataset, epochs=2, steps_per_epoch=2, verbose=1,
          validation_data=dataset, validation_steps=2, shuffle=False).history

      # Manually control the validation running after each epoch.
      manual_evals = []
      for _ in range(2):
        manual_model.fit(
            dataset, epochs=1, steps_per_epoch=2, verbose=1, shuffle=False)
        manual_evals.append(manual_model.evaluate(dataset, steps=2))

      # Position in each `evaluate` result -> matching validation history key.
      history_keys = ('val_loss',
                      'val_mean_absolute_error',
                      'val_categorical_accuracy')
      for position, key in enumerate(history_keys):
        self.assertEqual(interleaved_history[key],
                         [result[position] for result in manual_evals])

  # TODO(priyag): Enable this test for TPU. Currently tuples/dict don't work
  # as clone_model's input_tensors argument only seems to accept list and not
  # tuples or dict.

  @combinations.generate(combinations.combine(
      distribution=[
          combinations.mirrored_strategy_with_gpu_and_cpu,
          combinations.core_mirrored_strategy_with_gpu_and_cpu],
      mode=['graph', 'eager']))
  def test_fit_with_tuple_and_dict_dataset_inputs(self, distribution):
    """Fits on datasets whose features are a tuple and then a dict."""
    with self.cached_session():
      model = multi_input_output_model()
      model.compile(
          gradient_descent.GradientDescentOptimizer(learning_rate=0.001),
          'mse',
          metrics=['mae', keras.metrics.CategoricalAccuracy()],
          distribute=distribution)

      feat_a = np.random.random((10, 3))
      feat_b = np.random.random((10, 5))
      label_d = np.random.random((10, 7))
      label_e = np.random.random((10, 7))

      # Test with tuples
      tuple_slices = dataset_ops.Dataset.from_tensor_slices((
          (feat_a, feat_b), (label_d, label_e)))
      tuple_dataset = tuple_slices.repeat(100).batch(10)

      model.fit(tuple_dataset, epochs=1, steps_per_epoch=2, verbose=1)

      # Test with dict
      dict_slices = dataset_ops.Dataset.from_tensor_slices((
          {'input_a': feat_a, 'input_b': feat_b},
          (label_d, label_e)))
      dict_dataset = dict_slices.repeat(100).batch(10)

      model.fit(dict_dataset, epochs=1, steps_per_epoch=2, verbose=1)

  @combinations.generate(all_strategy_combinations())
  def test_fit_eval_and_predict_methods_on_dataset(self, distribution):
    """Smoke-tests fit, evaluate and predict on dataset inputs."""
    with self.cached_session():
      model = get_model()
      model.compile(
          gradient_descent.GradientDescentOptimizer(0.001),
          'mse',
          metrics=['mae', keras.metrics.CategoricalAccuracy()],
          distribute=distribution)

      train_eval_dataset = get_dataset(distribution)

      model.fit(train_eval_dataset, epochs=1, steps_per_epoch=2, verbose=1)
      model.evaluate(train_eval_dataset, steps=2, verbose=1)
      model.predict(get_predict_dataset(distribution), steps=2)

  @combinations.generate(strategy_and_optimizer_combinations())
  def test_fit_eval_and_predict_with_optimizer(self, distribution, optimizer):
    """Smoke-tests fit, evaluate and predict for each optimizer factory."""
    with self.cached_session():
      model = get_model()
      model.compile(optimizer(), 'mse', distribute=distribution)

      train_eval_dataset = get_dataset(distribution)

      model.fit(train_eval_dataset, epochs=1, steps_per_epoch=2, verbose=1)
      model.evaluate(train_eval_dataset, steps=2, verbose=1)
      model.predict(get_predict_dataset(distribution), steps=2)

  @combinations.generate(strategy_minus_tpu_combinations())
  def test_dataset_with_sample_weights(self, distribution):
    """Runs on a dataset yielding (inputs, targets, sample_weights)."""
    model = get_model()
    model.compile(
        rmsprop.RMSPropOptimizer(learning_rate=0.001),
        'mse',
        distribute=distribution)

    features = np.zeros((10, 3), np.float32)
    labels = np.zeros((10, 4), np.float32)
    weights = np.ones((10), np.float32)
    slices = dataset_ops.Dataset.from_tensor_slices(
        (features, labels, weights))
    weighted_dataset = slices.repeat().batch(10)

    model.fit(weighted_dataset, epochs=1, steps_per_epoch=2, verbose=1)
    model.evaluate(weighted_dataset, steps=2, verbose=1)
    model.predict(weighted_dataset, steps=2)

  @combinations.generate(combinations.combine(
      distribution=[
          combinations.mirrored_strategy_with_gpu_and_cpu,
          combinations.core_mirrored_strategy_with_gpu_and_cpu],
      mode=['graph', 'eager']))
  # TODO(b/120943676, b/120957836): Re-enable once the validation code is
  # restored.
  def DISABLED_test_dataset_wrong_input_shape(self, distribution):
    """Expects a ValueError when the feature width does not match the model."""
    with self.cached_session():
      model = get_model()
      model.compile(
          rmsprop.RMSPropOptimizer(learning_rate=0.001),
          'mse',
          distribute=distribution)

      # Wrong input shape
      bad_features = np.zeros((10, 5), dtype=np.float32)
      labels = np.zeros((10, 4), dtype=np.float32)
      slices = dataset_ops.Dataset.from_tensor_slices((bad_features, labels))
      bad_dataset = slices.repeat(100).batch(10)

      with self.assertRaisesRegexp(ValueError,
                                   'expected input to have shape'):
        model.fit(bad_dataset, epochs=1, steps_per_epoch=2, verbose=0)

  @combinations.generate(combinations.combine(
      distribution=[combinations.mirrored_strategy_with_gpu_and_cpu],
      mode=['graph', 'eager']))
  # TODO(b/120943676, b/120957836): Re-enable once the validation code is
  # restored.
  def DISABLED_test_dataset_no_batch_input_validation(self, distribution):
    """Expects a ValueError when the dataset was never batched."""
    with self.cached_session():
      model = get_model()
      model.compile(
          rmsprop.RMSPropOptimizer(learning_rate=0.001),
          'mse',
          distribute=distribution)

      # User forgets to batch the dataset
      features = np.zeros((10, 3), dtype=np.float32)
      labels = np.zeros((10, 4), dtype=np.float32)
      slices = dataset_ops.Dataset.from_tensor_slices((features, labels))
      unbatched_dataset = slices.repeat(100)

      with self.assertRaisesRegexp(ValueError, 'expected input to have shape'):
        model.fit(unbatched_dataset, epochs=1, steps_per_epoch=2, verbose=0)

  @combinations.generate(combinations.combine(
      distribution=[combinations.tpu_strategy_one_step],
      mode=['graph']))
  def test_dataset_input_shape_fully_defined(self, distribution):
    """Expects a ValueError on TPU when the batch dimension is unknown."""
    with self.cached_session():
      model = get_model()
      model.compile(
          rmsprop.RMSPropOptimizer(learning_rate=0.001),
          'mse',
          distribute=distribution)

      base_dataset = get_dataset(distribution)
      # Input shapes are not fully known. Batch dimension is unknown as we are
      # not using the drop_remainder argument.
      rebatched_dataset = base_dataset.repeat(100).batch(10)

      with self.assertRaisesRegexp(ValueError, 'requires fully defined shapes'):
        model.fit(rebatched_dataset, epochs=1, steps_per_epoch=2, verbose=0)

  @combinations.generate(combinations.combine(
      distribution=[
          combinations.mirrored_strategy_with_gpu_and_cpu,
          combinations.mirrored_strategy_with_two_gpus,
          combinations.core_mirrored_strategy_with_gpu_and_cpu,
          combinations.core_mirrored_strategy_with_two_gpus],
      mode=['graph', 'eager']))
  def test_learning_phase_value(self, distribution):
    """Uses a near-total Dropout to compare fit and predict outputs."""
    # TODO(anjalisridhar): Modify this test to use Lambdas since we can compare
    # meaningful values. Currently we don't pass the learning phase if the
    # Lambda layer uses the learning phase.
    with self.cached_session():
      net_in = keras.layers.Input(shape=(1,), name='input')
      dense_out = keras.layers.Dense(1, kernel_initializer='ones')(net_in)
      dropped = keras.layers.Dropout(0.9999)(dense_out)
      model = keras.Model(net_in, dropped)
      initial_weights = model.get_weights()

      model.compile(
          gradient_descent.GradientDescentOptimizer(0.005),
          'mse',
          metrics=['acc'],
          distribute=distribution)

      if isinstance(distribution, mirrored_strategy.CoreMirroredStrategy):
        # CoreMirroredStrategy uses global batch size.
        batch_size = 8 * distribution.num_replicas_in_sync
      else:
        batch_size = 8

      features = np.ones((10, 1), dtype=np.float32)
      labels = np.ones((10, 1), dtype=np.float32)
      train_slices = dataset_ops.Dataset.from_tensor_slices((features, labels))
      train_dataset = train_slices.repeat().batch(batch_size)
      hist = model.fit(train_dataset, epochs=1, steps_per_epoch=20, verbose=1)
      self.assertAlmostEqual(hist.history['acc'][0], 0, 0)

      model.set_weights(initial_weights)
      # TODO(psv/anjalisridhar): Enable these lines after we fix b/117431185.
      # evaluate_output = model.evaluate(dataset, steps=20)
      # self.assertAlmostEqual(evaluate_output[1], 1, 0)

      features = np.ones((10, 1), dtype=np.float32)
      predict_slices = dataset_ops.Dataset.from_tensor_slices(features)

      predict_dataset = predict_slices.repeat().batch(batch_size)
      predicted = model.predict(predict_dataset, steps=10)
      # `predict` runs for 10 steps
      expected = np.ones((160, 1), dtype=np.float32)
      self.assertArrayNear(predicted, expected, 1e-1)

  @combinations.generate(strategy_minus_tpu_combinations())
  def testOptimizerWithCallbacks(self, distribution):
    """A LearningRateScheduler must set the rate on every grouped model."""
    with self.cached_session():
      model = get_model()
      model.compile(
          gradient_descent_keras.SGD(0.01), 'mse', distribute=distribution)

      dataset = get_dataset(distribution)

      def constant_rate(_):
        return 0.001

      scheduler = keras.callbacks.LearningRateScheduler(constant_rate)
      model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0,
                callbacks=[scheduler])
      train_models = distribution.unwrap(
          distributed_training_utils.get_distributed_model(
              model, ModeKeys.TRAIN))
      with distribution.scope():
        for replica_model in train_models:
          current_rate = keras.backend.get_value(replica_model.optimizer.lr)
          self.assertAllClose(0.001, current_rate, atol=1e-05, rtol=1e-05)


class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase):
  """Argument combinations that must be rejected with a ValueError."""

  @combinations.generate(combinations.combine(
      distribution=[
          combinations.mirrored_strategy_with_gpu_and_cpu,
          combinations.core_mirrored_strategy_with_gpu_and_cpu],
      mode=['graph', 'eager']))
  def test_unsupported_features(self, distribution):
    """Checks the errors for unsupported arguments with dataset input."""
    with self.cached_session():
      model = get_model()
      model.compile(
          gradient_descent.GradientDescentOptimizer(0.001),
          'mse',
          metrics=['mae'],
          distribute=distribution)

      finite_dataset = get_dataset(distribution)

      # Test with validation split
      with self.assertRaisesRegexp(
          ValueError, '`validation_split` argument is not '
          'supported when input `x` is a dataset or a '
          'dataset iterator.+'):
        model.fit(finite_dataset,
                  epochs=1, steps_per_epoch=2, verbose=0,
                  validation_split=0.5, validation_steps=2)

      # Test with sample weight.
      weights = np.random.random((10,))
      with self.assertRaisesRegexp(
          ValueError, '`sample_weight` argument is not supported when input '
          '`x` is a dataset or a dataset iterator.'):
        model.fit(
            finite_dataset,
            epochs=1,
            steps_per_epoch=2,
            verbose=0,
            sample_weight=weights)

      # Test with not specifying the `steps` argument for dataset with
      # infinite cardinality.
      infinite_dataset = finite_dataset.repeat()
      with self.assertRaisesRegexp(ValueError, 'When passing an infinitely '
                                   'repeating dataset, you must specify the '
                                   '`steps_per_epoch` argument'):
        model.fit(infinite_dataset, epochs=1, verbose=0)
      with self.assertRaisesRegexp(ValueError, 'When passing an infinitely '
                                   'repeating dataset, you must specify the '
                                   '`steps` argument'):
        model.evaluate(infinite_dataset, verbose=0)

      with self.assertRaisesRegexp(ValueError, 'When passing an infinitely '
                                   'repeating dataset, you must specify the '
                                   '`steps` argument'):
        model.predict(infinite_dataset, verbose=0)

  @combinations.generate(combinations.combine(
      distribution=[
          combinations.mirrored_strategy_with_gpu_and_cpu,
          combinations.core_mirrored_strategy_with_gpu_and_cpu],
      mode=['graph', 'eager']))
  def test_calling_with_unsupported_predefined_callbacks(self, distribution):
    """Learning-rate callbacks must be rejected with the optimizer used here."""
    with self.cached_session():
      model = get_model()
      model.compile(
          gradient_descent.GradientDescentOptimizer(0.001),
          'mse',
          metrics=['mae'],
          distribute=distribution)

      dataset = get_dataset(distribution)

      def constant_rate(_):
        return 0.001

      with self.assertRaisesRegexp(ValueError,
                                   'You must specify a Keras Optimizer V2 when '
                                   'using'):
        model.fit(
            dataset, epochs=1, steps_per_epoch=2, verbose=0,
            callbacks=[keras.callbacks.LearningRateScheduler(constant_rate)])

      with self.assertRaisesRegexp(ValueError,
                                   'You must specify a Keras Optimizer V2 when '
                                   'using'):
        model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0,
                  callbacks=[keras.callbacks.ReduceLROnPlateau()])


class TestDistributionStrategyWithLossMasking(test.TestCase,
                                              parameterized.TestCase):
  """Loss computation with a Masking layer under a distribution strategy."""

  # TODO(priyag): Enable all strategies for this test. Currently it does not
  # work for TPU due to some invalid datatype.
  @combinations.generate(combinations.combine(
      distribution=[
          combinations.mirrored_strategy_with_gpu_and_cpu,
          combinations.core_mirrored_strategy_with_gpu_and_cpu],
      mode=['graph', 'eager']))
  def test_masking(self, distribution):
    """The first-epoch loss must be exactly 0 for this masked input."""
    with self.cached_session():
      np.random.seed(1337)
      features = np.array([[[1], [1]], [[0], [0]]])
      model = keras.models.Sequential()
      model.add(keras.layers.Masking(mask_value=0, input_shape=(2, 1)))
      per_step_dense = keras.layers.TimeDistributed(
          keras.layers.Dense(1, kernel_initializer='one'))
      model.add(per_step_dense)
      model.compile(loss='mse',
                    optimizer=gradient_descent.GradientDescentOptimizer(0.01),
                    distribute=distribution)
      labels = np.array([[[1], [1]], [[1], [1]]])
      slices = dataset_ops.Dataset.from_tensor_slices((features, labels))
      dataset = slices.repeat(100).batch(10)
      hist = model.fit(x=dataset, epochs=1, steps_per_epoch=2)
      self.assertEqual(hist.history['loss'][0], 0)


class TestDistributionStrategyWithNormalizationLayer(
    test.TestCase, parameterized.TestCase):
  """BatchNormalization behaviour under a distribution strategy."""

  @combinations.generate(all_strategy_combinations())
  def test_batchnorm_correctness(self, distribution):
    """After training, de-scaled predictions have roughly mean 0 and std 1."""
    with self.cached_session():
      model = keras.models.Sequential()
      norm = keras.layers.BatchNormalization(input_shape=(10,), momentum=0.8)
      model.add(norm)
      model.compile(loss='mse',
                    optimizer=gradient_descent.GradientDescentOptimizer(0.01),
                    distribute=distribution)

      # Samples with mean 5.0 and standard deviation 10.0.
      samples = np.random.normal(loc=5.0, scale=10.0, size=(1000, 10))
      samples = samples.astype('float32')
      train_slices = dataset_ops.Dataset.from_tensor_slices((samples, samples))
      train_dataset = batch_wrapper(train_slices.repeat(100), 32, distribution)

      predict_slices = dataset_ops.Dataset.from_tensor_slices(samples)
      predict_dataset = batch_wrapper(
          predict_slices.repeat(100), 32, distribution)

      model.fit(train_dataset, epochs=4, verbose=0, steps_per_epoch=10)
      out = model.predict(predict_dataset, steps=2)
      # Undo the layer's offset (beta) and scale (gamma).
      out -= keras.backend.eval(norm.beta)
      out /= keras.backend.eval(norm.gamma)
      np.testing.assert_allclose(out.mean(), 0.0, atol=1e-1)
      np.testing.assert_allclose(out.std(), 1.0, atol=1e-1)


class TestDistributionStrategyCorrectness(test.TestCase,
                                          parameterized.TestCase):

  @combinations.generate(all_strategy_combinations())
  def test_metric_correctness(self, distribution):
    """An identity model on 0/1 data must report binary accuracy 1.0."""
    with self.cached_session():
      keras.backend.set_image_data_format('channels_last')
      num_samples = 10000

      bits = np.random.randint(0, 2, num_samples)
      bits = np.reshape(bits, (num_samples, 1))
      x_train = bits.astype('float32')
      y_train = bits.astype('float32')

      # Create identity model.
      model = keras.Sequential()
      identity = keras.layers.Dense(
          1, input_shape=(1,), kernel_initializer='ones')
      model.add(identity)
      model.compile(
          loss=keras.losses.mean_squared_error,
          optimizer=gradient_descent.GradientDescentOptimizer(0.5),
          metrics=[keras.metrics.BinaryAccuracy()],
          distribute=distribution)

      batch_size = 64
      if not distributed_training_utils.global_batch_size_supported(
          distribution):
        batch_size //= distribution.num_replicas_in_sync
      train_slices = dataset_ops.Dataset.from_tensor_slices((x_train, y_train))
      train_dataset = batch_wrapper(train_slices, batch_size, distribution)

      history = model.fit(x=train_dataset, epochs=2, steps_per_epoch=10)
      self.assertEqual(history.history['binary_accuracy'], [1.0, 1.0])

  @combinations.generate(all_strategy_combinations())
  def test_eval_metrics_correctness(self, distribution):
    """Both accuracy metrics are 1 for all-ones and 0 for all-zeros labels."""
    with self.cached_session():
      model = keras.Sequential()
      model.add(
          keras.layers.Dense(
              3, activation='relu', input_dim=4, kernel_initializer='ones'))
      model.add(
          keras.layers.Dense(
              1, activation='sigmoid', kernel_initializer='ones'))
      model.compile(
          loss='mae',
          metrics=['accuracy', keras.metrics.BinaryAccuracy()],
          optimizer=gradient_descent.GradientDescentOptimizer(0.001),
          distribute=distribution)

      # verify correctness of stateful and stateless metrics.
      features = np.ones((100, 4)).astype('float32')
      labels = np.ones((100, 1)).astype('float32')
      slices = dataset_ops.Dataset.from_tensor_slices((features, labels))
      dataset = batch_wrapper(slices.repeat(), 4, distribution)
      outs = model.evaluate(dataset, steps=10)
      self.assertEqual(outs[1], 1.)
      self.assertEqual(outs[2], 1.)

      labels = np.zeros((100, 1)).astype('float32')
      slices = dataset_ops.Dataset.from_tensor_slices((features, labels))
      dataset = batch_wrapper(slices.repeat(), 4, distribution)
      outs = model.evaluate(dataset, steps=10)
      self.assertEqual(outs[1], 0.)
      self.assertEqual(outs[2], 0.)
+ + @combinations.generate(strategy_and_input_combinations()) + def test_correctness(self, distribution, use_numpy, use_validation_data): + with self.cached_session(): + default_tolerance = 1e-5 + tol_table = {} + + if isinstance(distribution, ( + mirrored_strategy.MirroredStrategy, + mirrored_strategy.CoreMirroredStrategy, + distribute_lib._DefaultDistributionStrategy)): # pylint: disable=protected-access + # TODO(b/119257215): Weights are not exactly the same, so use larger + # tolerance for now. Predict should be related to weights. + tol_table = { + 'weights_1': 1e-4, + 'weights_2': 1e-4, + 'predict_result_1': 1e-4, + } + + keras.backend.set_image_data_format('channels_last') + np.random.seed(_RANDOM_SEED) + random_seed.set_random_seed(_RANDOM_SEED) + + # Train, eval, and predict datasets are created with the same input numpy + # arrays. + # TODO(xiejw): Change this back to 10000, once we support final partial + # batch. + num_samples = 9984 + x_train = np.random.rand(num_samples, 1) + y_train = 3 * x_train + x_train = x_train.astype('float32') + y_train = y_train.astype('float32') + x_predict = [[1.], [2.], [3.], [4.]] + + # The model is built once and the initial weights are saved. + # This is used to initialize the model for both the distribution and + # non-distribution run. In addition, we add few non-linear layers to make + # it non-trivial. + def _create_model(): + model = keras.Sequential() + model.add(keras.layers.Dense(10, activation='relu', input_shape=(1,))) + model.add(keras.layers.Dense(10, activation='relu')) + model.add(keras.layers.Dense(10, activation='relu')) + model.add(keras.layers.Dense(1)) + return model + + model = _create_model() + initial_weights = model.get_weights() + del model # avoid accident usage. + + def fit_eval_and_predict(with_distribution=None): + model = _create_model() + # We have initialized the model to the same weight for the distribution + # and non-distribution run. 
+ model.set_weights(initial_weights) + model.compile( + loss=keras.losses.mean_squared_error, + optimizer=gradient_descent_keras.SGD(0.5), + metrics=['mse'], + distribute=with_distribution) + + training_inputs, eval_inputs, predict_inputs = ( + get_correctness_test_inputs(use_numpy, use_validation_data, + with_distribution, + x_train, y_train, x_predict)) + + result = {} + result['training_history_1'] = model.fit(**training_inputs).history + + if eval_inputs is not None: + result['eval_result_1'] = model.evaluate(**eval_inputs) + + result['weights_1'] = model.get_weights() + result['predict_result_1'] = model.predict(**predict_inputs) + + # Train and eval again to mimic user's flow. + + result['training_history_2'] = model.fit(**training_inputs).history + + if eval_inputs is not None: + result['eval_result_2'] = model.evaluate(**eval_inputs) + + result['weights_2'] = model.get_weights() + + return result + + results_with_ds = fit_eval_and_predict(with_distribution=distribution) + results_without_ds = fit_eval_and_predict(with_distribution=None) + + # Verify that the weights, training history, eval results, predict outputs + # are the same within some limits of tolerance. + for key in results_with_ds: + if (key.startswith('training_history') and + isinstance(distribution, tpu_strategy.TPUStrategy) and + distribution.extended.steps_per_run > 1): + # TODO(b/119894254): Enable this test for all cases once the + # underlying bug is fixed. 
+ continue + + tolerance = tol_table.get(key, default_tolerance) + + self.assertAllClose( + results_with_ds[key], + results_without_ds[key], + atol=tolerance, + rtol=tolerance, + msg='Fail to assert {}.'.format(key)) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/distribute/python/keras_correctness_test_base.py b/tensorflow/contrib/distribute/python/keras_correctness_test_base.py new file mode 100644 index 0000000000000000000000000000000000000000..0fb7a18c40484ce01a5acfd6b191de464cfd9840 --- /dev/null +++ b/tensorflow/contrib/distribute/python/keras_correctness_test_base.py @@ -0,0 +1,487 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Correctness tests for tf.keras using DistributionStrategy.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy as np +import six + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python import mirrored_strategy +from tensorflow.contrib.distribute.python import tpu_strategy +from tensorflow.python import keras +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import distribute_lib +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.framework import random_seed +from tensorflow.python.keras.engine import distributed_training_utils + +_RANDOM_SEED = 1337 +_EVAL_STEPS = 20 +_GLOBAL_BATCH_SIZE = 64 + +# Note: Please make sure the tests in this file are also covered in +# keras_backward_compat_test for features that are supported with both APIs. 
+ + +all_strategies = [ + combinations.default_strategy, + combinations.one_device_strategy, + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_two_gpus, + combinations.tpu_strategy, # steps_per_run=2 + combinations.tpu_strategy_one_step, +] + + +def eager_mode_test_configuration(): + return combinations.combine(mode='eager', + use_numpy=False, + use_validation_data=False) + + +def graph_mode_test_configuration(): + return combinations.combine(mode='graph', + use_numpy=[True, False], + use_validation_data=[True, False]) + + +def all_strategy_and_input_config_combinations(): + return ( + combinations.times( + combinations.combine(distribution=all_strategies), + eager_mode_test_configuration() + graph_mode_test_configuration())) + + +def strategies_for_embedding_models(): + """Returns distribution strategies to test for embedding models. + + Since embedding models take longer to train, we disregard OneDeviceStrategy + and DefaultStrategy in order to prevent testing timeouts. 
+ """ + + return [s for s in all_strategies if s.required_tpu or s.required_gpus] + + +def test_combinations_for_embedding_model(): + return ( + combinations.times( + combinations.combine(distribution= + strategies_for_embedding_models()), + (graph_mode_test_configuration() + + eager_mode_test_configuration()))) + + +def test_combinations_with_tpu_strategies(): + tpu_strategies = [combinations.tpu_strategy, + combinations.tpu_strategy_one_step] + + return ( + combinations.times( + combinations.combine(distribution=tpu_strategies), + graph_mode_test_configuration())) + + +class MaybeDistributionScope(object): + """Provides a context allowing no distribution strategy.""" + + def __init__(self, distribution): + self._distribution = distribution + self._scope = None + + def __enter__(self): + if self._distribution: + self._scope = self._distribution.scope() + self._scope.__enter__() + + def __exit__(self, exc_type, value, traceback): + if self._distribution: + self._scope.__exit__(exc_type, value, traceback) + self._scope = None + + +def batch_wrapper(dataset, batch_size, distribution, repeat=None): + if repeat: + dataset = dataset.repeat(repeat) + # TPUs currently require fully defined input shapes, drop_remainder ensures + # the input will have fully defined shapes. + if isinstance(distribution, tpu_strategy.TPUStrategy): + return dataset.batch(batch_size, drop_remainder=True) + else: + return dataset.batch(batch_size) + + +def get_batch_size(global_batch_size, distribution): + batch_size = global_batch_size + # TODO(b/118776054): Use global batch size for Keras/DS support. 
+ use_per_core_batch_size = ( + distribution and + not distributed_training_utils.global_batch_size_supported(distribution)) + if use_per_core_batch_size: + batch_size //= distribution.num_replicas_in_sync + return batch_size + + +def get_data_size(data): + """Gets the size of data in list, tuple, dict, or a numpy array.""" + assert isinstance(data, (np.ndarray, list, dict, tuple)) + + if isinstance(data, np.ndarray): + return len(data) + + if isinstance(data, (list, tuple)): + return len(data[0]) + + return len(six.next(six.itervalues(data))) + + +def get_correctness_test_inputs(use_numpy, use_validation_data, + with_distribution, x_train, y_train, x_predict): + """Generates the inputs for correctness check when enable Keras with DS.""" + training_epochs = 2 + global_batch_size = _GLOBAL_BATCH_SIZE + batch_size = get_batch_size(global_batch_size, with_distribution) + + if use_numpy: + training_inputs = { + 'batch_size': batch_size, + 'x': x_train, + 'y': y_train, + 'epochs': training_epochs, + 'shuffle': False, + } + + if use_validation_data: + eval_inputs = None + training_inputs['validation_data'] = (x_train, y_train) + else: + eval_inputs = { + 'batch_size': batch_size, + 'x': x_train, + 'y': y_train, + } + predict_inputs = { + 'x': x_predict + } + else: + training_data_size = get_data_size(x_train) + if training_data_size < _GLOBAL_BATCH_SIZE * _EVAL_STEPS: + # Currently, we cannot detect the size of a dataset. So, the eval steps is + # hard coded. + raise ValueError('x_train must have at least ' + '_GLOBAL_BATCH_SIZE * _EVAL_STEPS samples') + # For dataset inputs, we do not pass batch_size to + # keras.fit/evaluate/predict. The batch size is part of the dataset. 
+ train_dataset = dataset_ops.Dataset.from_tensor_slices((x_train, y_train)) + x = batch_wrapper(train_dataset, batch_size, with_distribution, + repeat=training_epochs) + + training_inputs = { + 'batch_size': None, + 'x': x, + 'y': None, + 'epochs': training_epochs, + 'shuffle': False, + 'steps_per_epoch': training_data_size // global_batch_size, + } + if use_validation_data: + eval_inputs = None # Remove the eval_inputs + eval_dataset = dataset_ops.Dataset.from_tensor_slices((x_train, y_train)) + x = batch_wrapper(eval_dataset, batch_size, with_distribution) + training_inputs['validation_data'] = x + training_inputs['validation_steps'] = 5 + else: + eval_inputs = { + 'batch_size': None, + 'x': x, + 'y': None, + 'steps': _EVAL_STEPS, + } + + predict_batch_size = get_batch_size(get_data_size(x_predict), + with_distribution) + predict_dataset = dataset_ops.Dataset.from_tensor_slices(x_predict) + predict_dataset = batch_wrapper(predict_dataset, predict_batch_size, + with_distribution) + predict_inputs = { + 'steps': 1, + 'x': predict_dataset, + } + + return training_inputs, eval_inputs, predict_inputs + + +def fit_eval_and_predict(initial_weights, input_fn, model_fn, + distribution=None, is_stateful_model=False): + """Generates results for fit/predict/evaluate for given model.""" + model = model_fn(initial_weights=initial_weights, distribution=distribution) + training_inputs, eval_inputs, predict_inputs = input_fn(distribution) + + result = {} + result['training_history_1'] = model.fit(**training_inputs).history + + if eval_inputs is not None: + result['eval_result_1'] = model.evaluate(**eval_inputs) + + result['weights_1'] = model.get_weights() + + if predict_inputs is not None: + # Check correctness of the result of predict() invoked + # multiple times -- as for stateful models, result of + # predict may differ for each batch. 
+ predict_length = 1 + if is_stateful_model: + predict_length = 3 + for i in range(predict_length): + result_key = 'predict_result_{}'.format(i) + result[result_key] = model.predict(**predict_inputs) + + # Train and eval again to mimic user's flow. + + result['training_history_2'] = model.fit(**training_inputs).history + + if eval_inputs is not None: + result['eval_result_2'] = model.evaluate(**eval_inputs) + + result['weights_2'] = model.get_weights() + + return result + + +def compare_results(results_with_ds, results_without_ds, distribution, + testcase): + """Compares results of model compiled with/without distribution strategy.""" + + default_tolerance = 1e-5 + relaxed_tolerance = 1e-4 + + def _get_compare_result_tolerance(key): + """Returns tolerance to compare results.""" + # TODO(b/119257215): For MirroredStrategy, weights are not exactly the same, + # so use larger tolerance for now. Predict should be related to weights. + if (isinstance(distribution, ( + mirrored_strategy.MirroredStrategy, + mirrored_strategy.CoreMirroredStrategy, + distribute_lib._DefaultDistributionStrategy)) and # pylint: disable=protected-access + key.startswith(('weights_1', 'weights_2', 'predict_result'))): + return relaxed_tolerance + + return default_tolerance + + for key in results_with_ds: + if (key.startswith('training_history') and + isinstance(distribution, tpu_strategy.TPUStrategy) and + distribution.extended.steps_per_run > 1): + # TODO(b/119894254): Enable this test for all cases once the + # underlying bug is fixed. 
+ continue + + tolerance = _get_compare_result_tolerance(key) + testcase.assertAllClose( + results_with_ds[key], + results_without_ds[key], + atol=tolerance, + rtol=tolerance, + msg='Fail to assert {}.'.format(key)) + + +def should_skip_tpu_with_eager(distribution): + return (context.executing_eagerly() and + isinstance(distribution, tpu_strategy.TPUStrategy)) + + +class LearningRateBatchScheduler(keras.callbacks.Callback): + """Scheduler that dynamically sets the learning rate of model.""" + + def __init__(self, update_freq=None): + self._update_freq = update_freq + + def on_batch_begin(self, batch, logs=None): + if self._update_freq and batch % self._update_freq != 0: + return + + # To avoid divergence, limit the value range. + lr = 0.001 * (batch % 10) + keras.backend.set_value(self.model.optimizer.lr, lr) + + +class TestDistributionStrategyCorrectnessBase(test.TestCase, + parameterized.TestCase): + """Model agnostic testing infra to test correctness of Keras models.""" + + def set_up_test_config(self, use_numpy=False, + use_validation_data=False, + with_batch_norm=False): + self.use_numpy = use_numpy + self.use_validation_data = use_validation_data + self.with_batch_norm = with_batch_norm + + keras.backend.set_image_data_format('channels_last') + np.random.seed(_RANDOM_SEED) + random_seed.set_random_seed(_RANDOM_SEED) + + def get_data(self): + num_samples = 10000 + x_train = np.random.randint(0, 2, num_samples) + x_train = np.reshape(x_train, (num_samples, 1)) + y_train = x_train + return (x_train.astype('float32'), y_train.astype('float32'), None) + + def get_model(self, distribution=None): + raise NotImplementedError + + def skip_unsupported_test_configuration(self, distribution): + if should_skip_tpu_with_eager(distribution): + self.skipTest('TPUStrategy does not support eager mode now.') + + if context.executing_eagerly() and self.use_numpy: + self.skipTest('Numpy as inputs is not supported with strategy in eager.') + + if context.executing_eagerly() and 
self.use_validation_data: + self.skipTest('TODO(hongjunchoi): Add test logic for using validation ' + 'data for eager execution.') + return + + def run_correctness_test(self, + distribution, + use_numpy, + use_validation_data, + with_batch_norm=False, + is_stateful_model=False): + with self.cached_session(): + self.set_up_test_config(use_numpy, use_validation_data, with_batch_norm) + self.skip_unsupported_test_configuration(distribution) + + # Train, eval, and predict datasets are created with the same input numpy + # arrays. + x_train, y_train, x_predict = self.get_data() + + # The model is built once and the initial weights are saved. + # This is used to initialize the model for both the distribution and + # non-distribution run. + model = self.get_model() + initial_weights = model.get_weights() + + def input_fn(dist): + return get_correctness_test_inputs( + use_numpy, use_validation_data, dist, x_train, y_train, x_predict) + + results_with_ds = fit_eval_and_predict( + initial_weights, input_fn=input_fn, model_fn=self.get_model, + distribution=distribution, is_stateful_model=is_stateful_model) + results_without_ds = fit_eval_and_predict( + initial_weights, input_fn=input_fn, model_fn=self.get_model, + distribution=None, is_stateful_model=is_stateful_model) + + # First, special case, for multi-replica distributed training, batch norm + # is not aggregated globally. So it is expected to have different weights. 
+ if (self.with_batch_norm and + distribution.num_replicas_in_sync > 1): + with self.assertRaises(AssertionError): + compare_results(results_with_ds, results_without_ds, distribution, + testcase=self) + else: + compare_results(results_with_ds, results_without_ds, distribution, + testcase=self) + + def run_dynamic_lr_test(self, distribution): + with self.cached_session(): + self.set_up_test_config() + self.skip_unsupported_test_configuration(distribution) + + x_train, y_train, _ = self.get_data() + model = self.get_model() + initial_weights = model.get_weights() + update_freq = None + + if (isinstance(distribution, tpu_strategy.TPUStrategy) and + distribution.extended.steps_per_run > 1): + # For TPUStrategy with steps_per_run > 1, the callback is not invoked + # every step. So, to compare the CPU/TPU, we let the CPU to behave the + # same as TPU. + update_freq = distribution.extended.steps_per_run + + def input_fn(dist): + """Generates training test given test configuration.""" + training_epochs = 2 + global_batch_size = 64 + batch_size = get_batch_size(global_batch_size, dist) + + training_inputs = { + 'batch_size': batch_size, + 'x': x_train, + 'y': y_train, + 'epochs': training_epochs, + 'shuffle': False, + 'callbacks': [LearningRateBatchScheduler(update_freq)], + 'validation_data': (x_train, y_train) + } + # In this test case, we do not care eval and predict. 
+ eval_inputs, predict_inputs = None, None + return training_inputs, eval_inputs, predict_inputs + + results_with_ds = fit_eval_and_predict( + initial_weights, input_fn=input_fn, model_fn=self.get_model, + distribution=distribution) + results_without_ds = fit_eval_and_predict( + initial_weights, input_fn=input_fn, model_fn=self.get_model, + distribution=None) + compare_results(results_with_ds, results_without_ds, distribution, + testcase=self) + + +class TestDistributionStrategyEmbeddingModelCorrectnessBase( + TestDistributionStrategyCorrectnessBase): + """Base class to test correctness of Keras models with embedding layers.""" + + def get_data(self, + count=(_GLOBAL_BATCH_SIZE * _EVAL_STEPS), + min_words=5, + max_words=10, + max_word_id=19, + num_classes=2): + distribution = [] + for _ in range(num_classes): + dist = np.abs(np.random.randn(max_word_id)) + dist /= np.sum(dist) + distribution.append(dist) + + features = [] + labels = [] + for _ in range(count): + label = np.random.randint(0, num_classes, size=1)[0] + num_words = np.random.randint(min_words, max_words, size=1)[0] + word_ids = np.random.choice( + max_word_id, size=num_words, replace=True, p=distribution[label]) + word_ids = word_ids + labels.append(label) + features.append(word_ids) + + features = keras.preprocessing.sequence.pad_sequences( + features, maxlen=max_words) + x_train = np.asarray(features, dtype=np.float32) + y_train = np.asarray(labels, dtype=np.int32).reshape((count, 1)) + x_predict = x_train[:_GLOBAL_BATCH_SIZE] + return x_train, y_train, x_predict + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/distribute/python/keras_dnn_correctness_test.py b/tensorflow/contrib/distribute/python/keras_dnn_correctness_test.py new file mode 100644 index 0000000000000000000000000000000000000000..dae32188917cce9209b8e51032ef808352bc257c --- /dev/null +++ b/tensorflow/contrib/distribute/python/keras_dnn_correctness_test.py @@ -0,0 +1,171 @@ +# Copyright 2019 The TensorFlow 
Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Correctness tests for tf.keras DNN model using DistributionStrategy.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python import keras_correctness_test_base +from tensorflow.python import keras +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.eager import test +from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras +from tensorflow.python.training import gradient_descent + + +def all_strategy_combinations_with_eager_and_graph_modes(): + return combinations.combine(distribution=keras_correctness_test_base. + all_strategies, + mode=['graph', 'eager']) + + +def all_strategy_combinations_with_graph_mode(): + return combinations.combine(distribution=keras_correctness_test_base. + all_strategies, mode=['graph']) + + +class TestDistributionStrategyDnnCorrectness( + keras_correctness_test_base.TestDistributionStrategyCorrectnessBase): + + def get_model(self, initial_weights=None, distribution=None): + with keras_correctness_test_base.MaybeDistributionScope(distribution): + # We add few non-linear layers to make it non-trivial. 
+ model = keras.Sequential() + model.add(keras.layers.Dense(10, activation='relu', input_shape=(1,))) + model.add(keras.layers.Dense(10, activation='relu')) + model.add(keras.layers.Dense(10, activation='relu')) + model.add(keras.layers.Dense(1)) + + if initial_weights: + model.set_weights(initial_weights) + + model.compile( + loss=keras.losses.mean_squared_error, + optimizer=gradient_descent_keras.SGD(0.5), + metrics=['mse']) + return model + + def get_data(self): + # TODO(xiejw): Change this back to 10000, once we support final partial + # batch. + num_samples = 9984 + x_train = np.random.rand(num_samples, 1) + y_train = 3 * x_train + x_train = x_train.astype('float32') + y_train = y_train.astype('float32') + x_predict = np.array([[1.], [2.], [3.], [4.]], dtype=np.float32) + return x_train, y_train, x_predict + + @combinations.generate(keras_correctness_test_base. + all_strategy_and_input_config_combinations()) + def test_dnn_correctness(self, distribution, use_numpy, use_validation_data): + self.run_correctness_test(distribution, use_numpy, use_validation_data) + + @combinations.generate(all_strategy_combinations_with_graph_mode()) + def test_dnn_with_dynamic_learning_rate(self, distribution): + self.run_dynamic_lr_test(distribution) + + +class TestDistributionStrategyDnnMetricCorrectness( + keras_correctness_test_base.TestDistributionStrategyCorrectnessBase): + + def get_model(self, distribution=None): + with distribution.scope(): + model = keras.Sequential() + model.add(keras.layers.Dense(1, + input_shape=(1,), + kernel_initializer='ones')) + model.compile( + loss=keras.losses.mean_squared_error, + optimizer=gradient_descent.GradientDescentOptimizer(0.5), + metrics=[keras.metrics.BinaryAccuracy()]) + return model + + def run_metric_correctness_test(self, distribution): + with self.cached_session(): + self.set_up_test_config() + self.skip_unsupported_test_configuration(distribution) + + x_train, y_train, _ = self.get_data() + model = 
self.get_model(distribution=distribution) + + batch_size = 64 + batch_size = (keras_correctness_test_base. + get_batch_size(batch_size, distribution)) + train_dataset = dataset_ops.Dataset.from_tensor_slices((x_train, y_train)) + train_dataset = (keras_correctness_test_base. + batch_wrapper(train_dataset, batch_size, distribution)) + + history = model.fit(x=train_dataset, epochs=2, steps_per_epoch=10) + self.assertEqual(history.history['binary_accuracy'], [1.0, 1.0]) + + @combinations.generate(all_strategy_combinations_with_eager_and_graph_modes()) + def test_simple_dnn_metric_correctness(self, distribution): + self.run_metric_correctness_test(distribution) + + +class TestDistributionStrategyDnnMetricEvalCorrectness( + keras_correctness_test_base.TestDistributionStrategyCorrectnessBase): + + def get_model(self, distribution=None): + with distribution.scope(): + model = keras.Sequential() + model.add( + keras.layers.Dense( + 3, activation='relu', input_dim=4, kernel_initializer='ones')) + model.add( + keras.layers.Dense( + 1, activation='sigmoid', kernel_initializer='ones')) + model.compile( + loss='mae', + metrics=['accuracy', keras.metrics.BinaryAccuracy()], + optimizer=gradient_descent.GradientDescentOptimizer(0.001)) + return model + + def run_eval_metrics_correctness_test(self, distribution): + with self.cached_session(): + self.set_up_test_config() + self.skip_unsupported_test_configuration(distribution) + + model = self.get_model(distribution=distribution) + + # verify correctness of stateful and stateless metrics. + x = np.ones((100, 4)).astype('float32') + y = np.ones((100, 1)).astype('float32') + dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).repeat() + dataset = (keras_correctness_test_base. + batch_wrapper(dataset, 4, distribution)) + outs = model.evaluate(dataset, steps=10) + self.assertEqual(outs[1], 1.) + self.assertEqual(outs[2], 1.) 
+ + y = np.zeros((100, 1)).astype('float32') + dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).repeat() + dataset = (keras_correctness_test_base. + batch_wrapper(dataset, 4, distribution)) + outs = model.evaluate(dataset, steps=10) + self.assertEqual(outs[1], 0.) + self.assertEqual(outs[2], 0.) + + @combinations.generate(all_strategy_combinations_with_eager_and_graph_modes()) + def test_identity_model_metric_eval_correctness(self, distribution): + self.run_eval_metrics_correctness_test(distribution) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/distribute/python/keras_embedding_model_correctness_test.py b/tensorflow/contrib/distribute/python/keras_embedding_model_correctness_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e881bb70ecc428e3f972cde5f19c1b61b1dc0f0b --- /dev/null +++ b/tensorflow/contrib/distribute/python/keras_embedding_model_correctness_test.py @@ -0,0 +1,150 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Correctness test for tf.keras Embedding models using DistributionStrategy.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python import keras_correctness_test_base +from tensorflow.python import keras +from tensorflow.python.eager import test +from tensorflow.python.training import gradient_descent + + +class DistributionStrategyEmbeddingModelCorrectnessTest( + keras_correctness_test_base. + TestDistributionStrategyEmbeddingModelCorrectnessBase): + + def get_model(self, max_words=10, initial_weights=None, distribution=None): + with keras_correctness_test_base.MaybeDistributionScope(distribution): + word_ids = keras.layers.Input( + shape=(max_words,), dtype=np.int32, name='words') + word_embed = keras.layers.Embedding(input_dim=20, + output_dim=10)(word_ids) + if self.use_distributed_dense: + word_embed = keras.layers.TimeDistributed(keras.layers.Dense(4))( + word_embed) + avg = keras.layers.GlobalAveragePooling1D()(word_embed) + preds = keras.layers.Dense(2, activation='softmax')(avg) + model = keras.Model(inputs=[word_ids], outputs=[preds]) + + if initial_weights: + model.set_weights(initial_weights) + + model.compile( + optimizer=gradient_descent.GradientDescentOptimizer( + learning_rate=0.1), + loss='sparse_categorical_crossentropy', + metrics=['sparse_categorical_accuracy']) + return model + + @combinations.generate(keras_correctness_test_base. + test_combinations_for_embedding_model()) + def test_embedding_model_correctness(self, distribution, use_numpy, + use_validation_data): + + self.use_distributed_dense = False + self.run_correctness_test(distribution, use_numpy, use_validation_data) + + @combinations.generate(keras_correctness_test_base. 
+ test_combinations_for_embedding_model()) + def test_embedding_time_distributed_model_correctness(self, + distribution, + use_numpy, + use_validation_data): + self.use_distributed_dense = True + self.run_correctness_test(distribution, use_numpy, use_validation_data) + + +class DistributionStrategySiameseEmbeddingModelCorrectnessTest( + keras_correctness_test_base. + TestDistributionStrategyEmbeddingModelCorrectnessBase): + + def get_model(self, max_words=10, initial_weights=None, distribution=None): + with keras_correctness_test_base.MaybeDistributionScope(distribution): + word_ids_a = keras.layers.Input( + shape=(max_words,), dtype=np.int32, name='words_a') + word_ids_b = keras.layers.Input( + shape=(max_words,), dtype=np.int32, name='words_b') + + def submodel(embedding, word_ids): + word_embed = embedding(word_ids) + rep = keras.layers.GlobalAveragePooling1D()(word_embed) + return keras.Model(inputs=[word_ids], outputs=[rep]) + + word_embed = keras.layers.Embedding( + input_dim=20, + output_dim=10, + input_length=max_words, + embeddings_initializer=keras.initializers.RandomUniform(0, 1)) + + a_rep = submodel(word_embed, word_ids_a).outputs[0] + b_rep = submodel(word_embed, word_ids_b).outputs[0] + sim = keras.layers.Dot(axes=1, normalize=True)([a_rep, b_rep]) + + model = keras.Model(inputs=[word_ids_a, word_ids_b], outputs=[sim]) + + if initial_weights: + model.set_weights(initial_weights) + + model.compile( + optimizer=gradient_descent.GradientDescentOptimizer( + learning_rate=0.1), + loss='mse', + metrics=['mse']) + return model + + def get_data(self, + count=(keras_correctness_test_base._GLOBAL_BATCH_SIZE * + keras_correctness_test_base._EVAL_STEPS), + min_words=5, + max_words=10, + max_word_id=19, + num_classes=2): + features_a, labels_a, _ = (super( + DistributionStrategySiameseEmbeddingModelCorrectnessTest, self). 
+ get_data(count, min_words, max_words, + max_word_id, num_classes)) + + features_b, labels_b, _ = (super( + DistributionStrategySiameseEmbeddingModelCorrectnessTest, self). + get_data(count, min_words, max_words, + max_word_id, num_classes)) + + y_train = np.zeros((count, 1), dtype=np.float32) + y_train[labels_a == labels_b] = 1.0 + y_train[labels_a != labels_b] = -1.0 + # TODO(b/123360757): Add tests for using list as inputs for multi-input + # models. + x_train = { + 'words_a': features_a, + 'words_b': features_b, + } + x_predict = x_train + + return x_train, y_train, x_predict + + @combinations.generate(keras_correctness_test_base. + test_combinations_for_embedding_model()) + def test_siamese_embedding_model_correctness(self, distribution, use_numpy, + use_validation_data): + self.run_correctness_test(distribution, use_numpy, use_validation_data) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/distribute/python/keras_image_model_correctness_test.py b/tensorflow/contrib/distribute/python/keras_image_model_correctness_test.py new file mode 100644 index 0000000000000000000000000000000000000000..f625664372dfb6814ccbe9539f6abe018d2a4447 --- /dev/null +++ b/tensorflow/contrib/distribute/python/keras_image_model_correctness_test.py @@ -0,0 +1,92 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Correctness tests for tf.keras CNN models using DistributionStrategy.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python import keras_correctness_test_base +from tensorflow.python import keras +from tensorflow.python.eager import test +from tensorflow.python.training import gradient_descent + + +class DistributionStrategyCnnCorrectnessTest( + keras_correctness_test_base.TestDistributionStrategyCorrectnessBase): + + def get_model(self, initial_weights=None, distribution=None): + with keras_correctness_test_base.MaybeDistributionScope(distribution): + image = keras.layers.Input(shape=(28, 28, 3), name='image') + c1 = keras.layers.Conv2D( + name='conv1', filters=16, kernel_size=(3, 3), strides=(4, 4))( + image) + if self.with_batch_norm: + c1 = keras.layers.BatchNormalization(name='bn1')(c1) + c1 = keras.layers.MaxPooling2D(pool_size=(2, 2))(c1) + logits = keras.layers.Dense( + 10, activation='softmax', name='pred')( + keras.layers.Flatten()(c1)) + model = keras.Model(inputs=[image], outputs=[logits]) + + if initial_weights: + model.set_weights(initial_weights) + + model.compile( + optimizer=gradient_descent.GradientDescentOptimizer( + learning_rate=0.1), + loss='sparse_categorical_crossentropy', + metrics=['sparse_categorical_accuracy']) + + return model + + def get_data(self, + count=keras_correctness_test_base._GLOBAL_BATCH_SIZE + * keras_correctness_test_base._EVAL_STEPS, + shape=(28, 28, 3), + num_classes=10): + centers = np.random.randn(num_classes, *shape) + + features = [] + labels = [] + for _ in range(count): + label = np.random.randint(0, num_classes, size=1)[0] + offset = np.random.normal(loc=0, scale=0.1, size=np.prod(shape)) + offset = offset.reshape(shape) + labels.append(label) 
+ features.append(centers[label] + offset) + + x_train = np.asarray(features, dtype=np.float32) + y_train = np.asarray(labels, dtype=np.float32).reshape((count, 1)) + x_predict = x_train + return x_train, y_train, x_predict + + @combinations.generate(keras_correctness_test_base. + all_strategy_and_input_config_combinations()) + def test_cnn_correctness(self, distribution, use_numpy, use_validation_data): + self.run_correctness_test(distribution, use_numpy, use_validation_data) + + @combinations.generate(keras_correctness_test_base. + all_strategy_and_input_config_combinations()) + def test_cnn_with_batch_norm_correctness(self, distribution, use_numpy, + use_validation_data): + self.run_correctness_test(distribution, use_numpy, use_validation_data, + with_batch_norm=True) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/distribute/python/keras_lstm_model_correctness_test.py b/tensorflow/contrib/distribute/python/keras_lstm_model_correctness_test.py new file mode 100644 index 0000000000000000000000000000000000000000..9ed2dfa206cdf4be24a88b1d54090487c1873399 --- /dev/null +++ b/tensorflow/contrib/distribute/python/keras_lstm_model_correctness_test.py @@ -0,0 +1,65 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Correctness tests for tf.keras LSTM model using DistributionStrategy.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python import keras_correctness_test_base +from tensorflow.python import keras +from tensorflow.python.eager import test +from tensorflow.python.training import gradient_descent + + +class DistributionStrategyLstmModelCorrectnessTest( + keras_correctness_test_base. + TestDistributionStrategyEmbeddingModelCorrectnessBase): + + def get_model(self, max_words=10, initial_weights=None, distribution=None): + with keras_correctness_test_base.MaybeDistributionScope(distribution): + word_ids = keras.layers.Input( + shape=(max_words,), dtype=np.int32, name='words') + word_embed = keras.layers.Embedding(input_dim=20, + output_dim=10)(word_ids) + lstm_embed = keras.layers.LSTM(units=4, + return_sequences=False)(word_embed) + + preds = keras.layers.Dense(2, activation='softmax')(lstm_embed) + model = keras.Model(inputs=[word_ids], outputs=[preds]) + + if initial_weights: + model.set_weights(initial_weights) + + model.compile( + optimizer=gradient_descent.GradientDescentOptimizer( + learning_rate=0.1), + loss='sparse_categorical_crossentropy', + metrics=['sparse_categorical_accuracy']) + return model + + @combinations.generate(keras_correctness_test_base. 
+ test_combinations_for_embedding_model()) + def test_lstm_model_correctness(self, + distribution, + use_numpy, + use_validation_data): + self.run_correctness_test(distribution, use_numpy, use_validation_data) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py b/tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py index 6dfd85bcc4f3784e2744fd876a7190cc9581d96a..952b11932b83d16558ac9f5ce780886d94e72744 100644 --- a/tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py +++ b/tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py @@ -18,24 +18,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import shutil -import tempfile from absl.testing import parameterized import numpy as np -import six from tensorflow.contrib.distribute.python import combinations -from tensorflow.core.protobuf import config_pb2 from tensorflow.python import keras -from tensorflow.python.data.ops import dataset_ops from tensorflow.python.distribute import distribution_strategy_context as ds_context -from tensorflow.python.estimator import run_config -from tensorflow.python.estimator import training -from tensorflow.python.estimator.canned import dnn_linear_combined -from tensorflow.python.estimator.canned import prediction_keys -from tensorflow.python.estimator.export import export -from tensorflow.python.estimator.inputs import numpy_io -from tensorflow.python.feature_column import feature_column_lib as feature_column +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -44,103 +33,7 @@ from tensorflow.python.keras.optimizer_v2 import gradient_descent from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables -from 
tensorflow.python.platform import gfile from tensorflow.python.platform import test -from tensorflow.python.summary.writer import writer_cache - - -class KerasOptimizerV2IntegrationTest(test.TestCase, parameterized.TestCase): - - def setUp(self): - self._model_dir = tempfile.mkdtemp() - - def dataset_input_fn(self, x, y, batch_size): - - def input_fn(): - dataset = dataset_ops.Dataset.from_tensor_slices((x, y)) - dataset = dataset.repeat(1).batch(batch_size) - return dataset - - return input_fn - - @combinations.generate( - combinations.combine( - mode=['graph'], - distribution=[ - combinations.one_device_strategy, - combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.mirrored_strategy_with_two_gpus, - combinations.core_mirrored_strategy_with_gpu_and_cpu, - combinations.core_mirrored_strategy_with_two_gpus - ], - use_train_and_evaluate=[True, False])) - def test_complete_flow_with_mode(self, distribution, use_train_and_evaluate): - label_dimension = 2 - input_dimension = label_dimension - batch_size = 10 - data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32) - data = data.reshape(batch_size, label_dimension) - train_input_fn = self.dataset_input_fn( - x={'x': data}, - y=data, - batch_size=batch_size // distribution.num_replicas_in_sync) - eval_input_fn = self.dataset_input_fn( - x={'x': data}, - y=data, - batch_size=batch_size // distribution.num_replicas_in_sync) - predict_input_fn = numpy_io.numpy_input_fn( - x={'x': data}, batch_size=batch_size, shuffle=False) - - linear_feature_columns = [ - feature_column.numeric_column('x', shape=(input_dimension,)) - ] - dnn_feature_columns = [ - feature_column.numeric_column('x', shape=(input_dimension,)) - ] - feature_columns = linear_feature_columns + dnn_feature_columns - session_config = config_pb2.ConfigProto( - log_device_placement=True, allow_soft_placement=True) - estimator = dnn_linear_combined.DNNLinearCombinedRegressor( - linear_feature_columns=linear_feature_columns, - 
dnn_hidden_units=(2, 2), - dnn_feature_columns=dnn_feature_columns, - label_dimension=label_dimension, - model_dir=self._model_dir, - dnn_optimizer=adam.Adam(0.001), - linear_optimizer=adam.Adam(0.001), - config=run_config.RunConfig( - train_distribute=distribution, - eval_distribute=distribution, - session_config=session_config)) - - num_steps = 2 - if use_train_and_evaluate: - scores, _ = training.train_and_evaluate( - estimator, training.TrainSpec(train_input_fn, max_steps=num_steps), - training.EvalSpec(eval_input_fn)) - else: - estimator.train(train_input_fn, steps=num_steps) - scores = estimator.evaluate(eval_input_fn) - - self.assertIn('loss', six.iterkeys(scores)) - - predictions = np.array([ - x[prediction_keys.PredictionKeys.PREDICTIONS] - for x in estimator.predict(predict_input_fn) - ]) - self.assertAllEqual((batch_size, label_dimension), predictions.shape) - - feature_spec = feature_column.make_parse_example_spec(feature_columns) - serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn( - feature_spec) - export_dir = estimator.export_savedmodel(tempfile.mkdtemp(), - serving_input_receiver_fn) - self.assertTrue(gfile.Exists(export_dir)) - - def tearDown(self): - if self._model_dir: - writer_cache.FileWriterCache.clear() - shutil.rmtree(self._model_dir) def get_model(): @@ -152,113 +45,80 @@ def get_model(): class MirroredStrategyOptimizerV2Test(test.TestCase, parameterized.TestCase): - @combinations.generate(combinations.combine( - distribution=[ - combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.core_mirrored_strategy_with_gpu_and_cpu], - mode=['graph'])) + @combinations.generate( + combinations.combine( + distribution=[ + combinations.core_mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_two_gpus, + combinations.parameter_server_strategy_with_two_gpus, + ], + mode=['graph', 'eager'])) def testKerasOptimizerWithUnequalInput(self, distribution): - def create_fn(): + with 
distribution.scope(): var = variables.Variable( 2.0, name='var', aggregation=variable_scope.VariableAggregation.SUM) - # grad for cpu is 1, grad for gpu is 2, avg grad is 1.5. - loss = math_ops.cast(_replica_id() + 1, dtype=dtypes.float32) * var optimizer = adam.Adam(learning_rate=0.01, beta_1=0.2, beta_2=0.2) - train_op = optimizer.minimize(loss, var_list=[var]) - m = optimizer.get_slot(var, 'm') - v = optimizer.get_slot(var, 'v') - return (var, m, v, train_op, optimizer.iterations) + all_vars = [] - devices = ['/device:GPU:0', '/device:CPU:0'] - with distribution.scope(): - (var, m, v, op, counter) = distribution.call_for_each_replica(create_fn) + def model_fn(): + + def loss_fn(): + replica_id = _replica_id() + return math_ops.cast(replica_id + 1, dtype=dtypes.float32) * var + + train_op = optimizer.minimize(loss_fn, var_list=[var]) + + return train_op, optimizer + + def train_fn(): + train_op, optimizer = distribution.extended.call_for_each_replica( + model_fn) + if not all_vars: + all_vars.append(var) + all_vars.append(optimizer.get_slot(var, 'm')) + all_vars.append(optimizer.get_slot(var, 'v')) + return distribution.group(train_op) + + if not context.executing_eagerly(): + with self.cached_session() as sess: + train_fn = sess.make_callable(train_fn()) self.evaluate(variables.global_variables_initializer()) - var_val = [2.0, 2.0, 2.0] - self.assertAllClose( - var_val, - self.evaluate( - [distribution.read_var(var), - var.get(devices[0]), - var.get(devices[1])])) - self.assertAllClose([0, 0, 0], - self.evaluate([ - distribution.read_var(counter), - counter.get(devices[0]), - counter.get(devices[1]) - ])) - train_op = distribution.unwrap(op) - self.evaluate(train_op) - # m(1) = beta1 * m(0) + (1-beta1) * grad = 0.2 * 0 + 0.8 * (1 + 2) / 2 - m_val = [1.2, 1.2, 1.2] - # assert slot variables in both replicas are the same. 
- self.assertAllClose( - m_val, - self.evaluate( - [distribution.read_var(m), - m.get(devices[0]), - m.get(devices[1])])) - # v(1) = beta2 * v(0) + (1-beta2) * grad^2 = 0.2 * 0 + 0.8 * 2.25 - v_val = [1.8, 1.8, 1.8] - self.assertAllClose( - v_val, - self.evaluate( - [distribution.read_var(v), - v.get(devices[0]), - v.get(devices[1])])) + # first step. + train_fn() # var(1) = var(0) - lr * m(1) * sqrt(1 - beta2) / sqrt(v(1)) / (1 - beta1) # = 2.0 - 0.01 * 1.2 * sqrt(0.8) / sqrt(1.8) / 0.8 - var_val = [1.99, 1.99, 1.99] - self.assertAllClose( - var_val, - self.evaluate( - [distribution.read_var(var), - var.get(devices[0]), - var.get(devices[1])])) - self.assertAllClose([1, 1, 1], - self.evaluate([ - distribution.read_var(counter), - counter.get(devices[0]), - counter.get(devices[1]) - ])) + self.assertAllClose(1.99, self.evaluate(all_vars[0])) + # m(1) = beta1 * m(0) + (1-beta1) * grad = 0.2 * 0 + 0.8 * (1 + 2) / 2 + self.assertAllClose(1.2, self.evaluate(all_vars[1])) + # v(1) = beta2 * v(0) + (1-beta2) * grad^2 = 0.2 * 0 + 0.8 * 2.25 + self.assertAllClose(1.8, self.evaluate(all_vars[2])) - self.evaluate(train_op) + # second step. 
+ train_fn() + # var(1) = var(0) - lr * 2 = 1.98 + self.assertAllClose(1.98, self.evaluate(all_vars[0])) # m(2) = beta1 * m(1) + (1-beta1) * grad = 0.2 * 1.2 + 0.8 * 1.5 - m_val = [1.44, 1.44, 1.44] - self.assertAllClose( - m_val, - self.evaluate( - [distribution.read_var(m), - m.get(devices[0]), - m.get(devices[1])])) + self.assertAllClose(1.44, self.evaluate(all_vars[1])) # v(2) = beta2 * v(1) + (1-beta2) * grad^2 = 0.2 * 1.8 + 0.8 * 2.25 - v_val = [2.16, 2.16, 2.16] - self.assertAllClose( - v_val, - self.evaluate( - [distribution.read_var(v), - v.get(devices[0]), - v.get(devices[1])])) - self.assertAllClose([2, 2, 2], - self.evaluate([ - distribution.read_var(counter), - counter.get(devices[0]), - counter.get(devices[1]) - ])) + self.assertAllClose(2.16, self.evaluate(all_vars[2])) - @combinations.generate(combinations.combine( - distribution=[ - combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.core_mirrored_strategy_with_gpu_and_cpu], - mode=['graph'])) + @combinations.generate( + combinations.combine( + distribution=[ + combinations.core_mirrored_strategy_with_gpu_and_cpu, + combinations.parameter_server_strategy_with_two_gpus, + ], + mode=['graph', 'eager'])) def testOptimizerWithKerasModelAndNumpyArrays(self, distribution): with self.cached_session(): - model = get_model() - optimizer = gradient_descent.SGD(0.001) - loss = 'mse' - metrics = ['mae'] - model.compile(optimizer, loss, metrics=metrics, distribute=distribution) + with distribution.scope(): + model = get_model() + optimizer = gradient_descent.SGD(0.001) + loss = 'mse' + metrics = ['mae'] + model.compile(optimizer, loss, metrics=metrics) inputs = np.zeros((64, 3), dtype=np.float32) targets = np.zeros((64, 4), dtype=np.float32) diff --git a/tensorflow/contrib/distribute/python/keras_stateful_lstm_model_correctness_test.py b/tensorflow/contrib/distribute/python/keras_stateful_lstm_model_correctness_test.py new file mode 100644 index 
0000000000000000000000000000000000000000..f5faf6c36b880a72bafc8d082cff2816f3b11a76 --- /dev/null +++ b/tensorflow/contrib/distribute/python/keras_stateful_lstm_model_correctness_test.py @@ -0,0 +1,99 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for stateful tf.keras LSTM models using DistributionStrategy.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python import keras_correctness_test_base +from tensorflow.python import keras +from tensorflow.python.eager import test +from tensorflow.python.training import gradient_descent + + +def strategies_for_stateful_embedding_model(): + """Returns TPUStrategy with single core device assignment.""" + + return [combinations.tpu_strategy_one_core, + combinations.tpu_strategy_one_step_one_core] + + +def test_combinations_for_stateful_embedding_model(): + return ( + combinations.combine( + distribution=strategies_for_stateful_embedding_model(), + mode='graph', + use_numpy=False, + use_validation_data=False + )) + + +class DistributionStrategyStatefulLstmModelCorrectnessTest( + keras_correctness_test_base. 
+ TestDistributionStrategyEmbeddingModelCorrectnessBase): + + def get_model(self, max_words=10, initial_weights=None, distribution=None): + batch_size = keras_correctness_test_base._GLOBAL_BATCH_SIZE + + with keras_correctness_test_base.MaybeDistributionScope(distribution): + word_ids = keras.layers.Input( + shape=(max_words,), + batch_size=batch_size, + dtype=np.int32, name='words') + word_embed = keras.layers.Embedding(input_dim=20, + output_dim=10)(word_ids) + lstm_embed = keras.layers.LSTM(units=4, + return_sequences=False, + stateful=True)(word_embed) + + preds = keras.layers.Dense(2, activation='softmax')(lstm_embed) + model = keras.Model(inputs=[word_ids], outputs=[preds]) + + if initial_weights: + model.set_weights(initial_weights) + + model.compile( + optimizer=gradient_descent.GradientDescentOptimizer( + learning_rate=0.1), + loss='sparse_categorical_crossentropy', + metrics=['sparse_categorical_accuracy']) + return model + + @combinations.generate(test_combinations_for_stateful_embedding_model()) + def test_stateful_lstm_model_correctness(self, + distribution, + use_numpy, + use_validation_data): + self.run_correctness_test(distribution, use_numpy, use_validation_data, + is_stateful_model=True) + + @combinations.generate(keras_correctness_test_base. + test_combinations_with_tpu_strategies()) + def test_incorrectly_use_multiple_cores_for_stateful_lstm_model( + self, distribution, use_numpy, use_validation_data): + with self.assertRaisesRegexp(ValueError, + 'Single core must be used for computation ' + 'on stateful models. 
Consider adding ' + '`device_assignment` parameter to ' + 'TPUStrategy'): + self.run_correctness_test(distribution, use_numpy, use_validation_data, + is_stateful_model=True) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py index 1d002819745f1959b535ffa534be8f1a6b93b31d..dd975c6c36d5d5387035e9da4170e4072406d79c 100644 --- a/tensorflow/contrib/distribute/python/keras_test.py +++ b/tensorflow/contrib/distribute/python/keras_test.py @@ -17,7 +17,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections import os +import tempfile from absl.testing import parameterized import numpy as np @@ -32,7 +34,6 @@ from tensorflow.python.estimator import keras as keras_lib from tensorflow.python.estimator import run_config as run_config_lib from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes -from tensorflow.python.framework import random_seed from tensorflow.python.framework import test_util from tensorflow.python.keras import testing_utils from tensorflow.python.keras.engine import distributed_training_utils @@ -48,6 +49,9 @@ _TRAIN_SIZE = 200 _INPUT_SIZE = (10,) _NUM_CLASS = 2 +# Note: Please make sure the tests in this file are also covered in +# keras_backward_compat_test for features that are supported with both APIs. + # TODO(anjalisridhar): Add a decorator that will allow us to run these tests as # part of the tf.keras unit tests suite. 
@@ -68,6 +72,18 @@ def simple_functional_model(): return model +def simple_multi_inputs_multi_outputs_model(): + input_a = keras.layers.Input(shape=(16,), name='input_a') + input_b = keras.layers.Input(shape=(16,), name='input_b') + + merged = keras.layers.concatenate([input_a, input_b], name='merge') + output_c = keras.layers.Dense(3, activation='softmax', name='dense_2')(merged) + output_d = keras.layers.Dense(2, activation='softmax', name='dense_3')(merged) + model = keras.models.Model( + inputs=[input_a, input_b], outputs=[output_c, output_d]) + return model + + def multi_inputs_multi_outputs_model(): input_a = keras.layers.Input(shape=(16,), name='input_a') input_b = keras.layers.Input(shape=(16,), name='input_b') @@ -165,7 +181,9 @@ def get_multi_inputs_multi_outputs_data(): return (train_data, test_data) -def batch_wrapper(dataset, batch_size, distribution): +def batch_wrapper(dataset, batch_size, distribution, repeat=None): + if repeat: + dataset = dataset.repeat(repeat) # TPUs currently require fully defined input shapes, drop_remainder ensures # the input will have fully defined shapes. if isinstance(distribution, tpu_strategy.TPUStrategy): @@ -212,111 +230,65 @@ def multi_input_output_model(): return model -def get_correctness_test_inputs(use_numpy, with_distribution, - x_train, y_train, x_predict): - """Generates the inputs for correctness check when enable Keras with DS.""" - global_batch_size = 64 - batch_size = global_batch_size - # TODO(b/118776054): Use global batch size for Keras/DS support. 
- use_per_core_batch_size = ( - with_distribution and - not distributed_training_utils.global_batch_size_supported( - with_distribution)) - if use_per_core_batch_size: - batch_size //= with_distribution.num_replicas_in_sync - - if use_numpy: - training_inputs = { - 'batch_size': batch_size, - 'x': x_train, - 'y': y_train, - 'epochs': 1, - 'shuffle': False, - } - eval_inputs = { - 'batch_size': batch_size, - 'x': x_train, - 'y': y_train, - } - predict_inputs = { - 'x': np.array(x_predict, dtype=np.float32), - } - else: - # For dataset inputs, we do not pass batch_size to - # keras.fit/evaluate/predict. The batch size is part of the dataset. - train_dataset = dataset_ops.Dataset.from_tensor_slices( - (x_train, y_train)) - x = batch_wrapper(train_dataset, batch_size, with_distribution) - - training_inputs = { - 'batch_size': None, - 'x': x, - 'y': None, - 'epochs': 1, - 'shuffle': False, - 'steps_per_epoch': len(x_train) // global_batch_size, - } - eval_inputs = { - 'batch_size': None, - 'x': x, - 'y': None, - 'steps': 20, - } - predict_batch_size = len(x_predict) - if use_per_core_batch_size: - predict_batch_size //= with_distribution.num_replicas_in_sync - predict_dataset = dataset_ops.Dataset.from_tensor_slices(x_predict) - predict_dataset = batch_wrapper(predict_dataset, - predict_batch_size, with_distribution) - predict_inputs = { - 'steps': 1, - 'x': predict_dataset, - } - - return training_inputs, eval_inputs, predict_inputs - - -strategies = [combinations.default_strategy, - combinations.one_device_strategy, - combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.mirrored_strategy_with_two_gpus, - combinations.core_mirrored_strategy_with_gpu_and_cpu, - combinations.core_mirrored_strategy_with_two_gpus, - combinations.tpu_strategy, # steps_per_run=2 - combinations.tpu_strategy_one_step] +strategies_minus_tpu = [ + combinations.default_strategy, + combinations.one_device_strategy, + combinations.mirrored_strategy_with_gpu_and_cpu, + 
combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_two_gpus] + +tpu_strategies = [ + combinations.tpu_strategy, # steps_per_run=2 + combinations.tpu_strategy_one_step] def strategy_minus_tpu_combinations(): return combinations.combine( - distribution=[combinations.default_strategy, - combinations.one_device_strategy, - combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.mirrored_strategy_with_two_gpus, - combinations.core_mirrored_strategy_with_gpu_and_cpu, - combinations.core_mirrored_strategy_with_two_gpus], - mode=['graph']) + distribution=strategies_minus_tpu, + mode=['graph', 'eager']) -def strategy_combinations(): +def tpu_strategy_combinations(): return combinations.combine( - distribution=strategies, + distribution=tpu_strategies, mode=['graph']) -def strategy_and_optimizer_combinations(): - return combinations.combine( - distribution=strategies, - optimizer=[combinations.adagrad_optimizer_v1_fn, - combinations.adam_optimizer_v1_fn, - combinations.gradient_descent_optimizer_v1_fn, - combinations.rmsprop_optimizer_v1_fn], - mode=['graph']) +def all_strategy_combinations(): + return strategy_minus_tpu_combinations() + tpu_strategy_combinations() -def strategy_and_inputs(): +def all_strategy_combinations_minus_default(): + strategy_minus_default_combinations = combinations.combine( + distribution=[ + combinations.one_device_strategy, + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_two_gpus], + mode=['graph', 'eager']) + return strategy_minus_default_combinations + tpu_strategy_combinations() + + +def strategy_and_optimizer_combinations(): + return combinations.times( + all_strategy_combinations(), + combinations.combine(optimizer=[ + combinations.adagrad_optimizer_v1_fn, + 
combinations.adagrad_optimizer_keras_v2_fn, + combinations.adam_optimizer_v1_fn, + combinations.adam_optimizer_keras_v2_fn, + combinations.gradient_descent_optimizer_v1_fn, + combinations.gradient_descent_optimizer_keras_v2_fn, + combinations.rmsprop_optimizer_v1_fn, + combinations.rmsprop_optimizer_keras_v2_fn + ])) + + +def strategy_for_numpy_input_combinations(): return combinations.combine( - distribution=strategies, - use_numpy=[True, False], + distribution=strategies_minus_tpu + tpu_strategies, mode=['graph']) @@ -337,7 +309,9 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase, @combinations.generate(combinations.combine( distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_gpu_and_cpu, combinations.core_mirrored_strategy_with_two_gpus], mode=['graph'])) def test_train_functional_with_distribution_strategy(self, distribution): @@ -365,7 +339,9 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase, @combinations.generate(combinations.combine( distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_gpu_and_cpu, combinations.core_mirrored_strategy_with_two_gpus], mode=['graph'])) def test_train_sequential_with_distribution_strategy(self, distribution): @@ -392,8 +368,8 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase, @combinations.generate(combinations.combine( distribution=[ - combinations.mirrored_strategy_with_two_gpus, - combinations.core_mirrored_strategy_with_two_gpus], + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], mode=['graph'])) def test_multi_inputs_multi_outputs_with_input_fn_as_dict(self, distribution): train_data, test_data = get_multi_inputs_multi_outputs_data() @@ -444,8 +420,8 @@ class 
TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase, @combinations.generate(combinations.combine( distribution=[ - combinations.mirrored_strategy_with_two_gpus, - combinations.core_mirrored_strategy_with_two_gpus], + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], mode=['graph'])) def test_keras_optimizer_with_distribution_strategy(self, distribution): keras_model = simple_sequential_model() @@ -471,16 +447,7 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase, class TestDistributionStrategyWithNumpyArrays(test.TestCase, parameterized.TestCase): - @combinations.generate(strategy_combinations()) - def test_creating_var_with_numpy_arrays(self, distribution): - with self.cached_session(): - x = np.asarray(np.random.random((64, 3)), dtype=np.float32) - var_x = distributed_training_utils.get_var_for_numpy(distribution, x) - val = self.evaluate(var_x.value()) - # Verify that the numpy value is copied to the variable. 
- self.assertAllEqual(x, val) - - @combinations.generate(strategy_combinations()) + @combinations.generate(strategy_for_numpy_input_combinations()) def test_calculating_input_params_no_steps_no_batch_size(self, distribution): # Calculate the per_replica_batch_size scaling factor for strategies # that use per_core_batch_size @@ -511,7 +478,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, distributed_training_utils.get_input_params( distribution, input_63_samples, steps=None, batch_size=None) - @combinations.generate(strategy_combinations()) + @combinations.generate(strategy_for_numpy_input_combinations()) def test_calculating_input_params_with_steps_no_batch_size(self, distribution): # Calculate the per_replica_batch_size scaling factor for strategies @@ -557,7 +524,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, distributed_training_utils.get_input_params( distribution, input_63_samples, steps=1, batch_size=None) - @combinations.generate(strategy_combinations()) + @combinations.generate(strategy_for_numpy_input_combinations()) def test_calculating_input_params_no_steps_with_batch_size(self, distribution): # Calculate the per_replica_batch_size scaling factor for strategies @@ -591,7 +558,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, distributed_training_utils.get_input_params( distribution, input_64_samples, steps=None, batch_size=3) - @combinations.generate(strategy_combinations()) + @combinations.generate(strategy_for_numpy_input_combinations()) def test_calculating_input_params_with_steps_with_batch_size(self, distribution): with self.cached_session(): @@ -608,45 +575,46 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, distributed_training_utils.get_input_params( distribution, input_64_samples, steps=10, batch_size=13) - @combinations.generate(strategy_combinations()) + @combinations.generate(strategy_for_numpy_input_combinations()) def test_calling_model_with_numpy_arrays(self, distribution): 
with self.cached_session(): - model = get_model() - - optimizer = gradient_descent.GradientDescentOptimizer(0.001) - loss = 'mse' - metrics = ['mae'] - model.compile(optimizer, loss, metrics=metrics, distribute=distribution) - - inputs = np.zeros((64, 3), dtype=np.float32) - targets = np.zeros((64, 4), dtype=np.float32) - - # Call fit with validation data - model.fit(inputs, targets, epochs=1, batch_size=2, verbose=0, - validation_data=(inputs, targets)) - - # TODO(anjalisridhar): We need tests for when the batch size and steps are - # smaller and results in a 0 batch_size and steps value. - model.evaluate(inputs, targets) - # with steps - model.evaluate(inputs, targets, steps=2) - # with batch_size - model.evaluate(inputs, targets, batch_size=8) - - model.predict(inputs) - # with steps - model.predict(inputs, steps=2) - # with batch_size - model.predict(inputs, batch_size=8) - - @combinations.generate(strategy_combinations()) + with distribution.scope(): + model = get_model() + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae'] + model.compile(optimizer, loss, metrics=metrics) + + inputs = np.zeros((64, 3), dtype=np.float32) + targets = np.zeros((64, 4), dtype=np.float32) + + # Call fit with validation data + model.fit(inputs, targets, epochs=1, batch_size=2, verbose=0, + validation_data=(inputs, targets)) + + # TODO(anjalisridhar): We need tests for when the batch size and steps + # are smaller and results in a 0 batch_size and steps value. 
+ model.evaluate(inputs, targets) + # with steps + model.evaluate(inputs, targets, steps=2) + # with batch_size + model.evaluate(inputs, targets, batch_size=8) + + model.predict(inputs) + # with steps + model.predict(inputs, steps=2) + # with batch_size + model.predict(inputs, batch_size=8) + + @combinations.generate(strategy_for_numpy_input_combinations()) def test_calling_model_with_nested_numpy_arrays(self, distribution): with self.cached_session(): - model = multi_input_output_model() - - optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.001) - loss = 'mse' - model.compile(optimizer, loss, distribute=distribution) + with distribution.scope(): + model = multi_input_output_model() + optimizer = gradient_descent.GradientDescentOptimizer( + learning_rate=0.001) + loss = 'mse' + model.compile(optimizer, loss) input_a_np = np.asarray(np.random.random((64, 3)), dtype=np.float32) input_b_np = np.asarray(np.random.random((64, 5)), dtype=np.float32) @@ -673,28 +641,32 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, # with batch_size model.predict(inputs, batch_size=8) - @combinations.generate(strategy_minus_tpu_combinations()) + @combinations.generate(combinations.combine( + distribution=strategies_minus_tpu, mode=['graph'])) def test_numpy_with_sample_weights(self, distribution): - model = get_model() - optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) - loss = 'mse' - model.compile(optimizer, loss, distribute=distribution) + with self.cached_session(): + with distribution.scope(): + model = get_model() + optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) + loss = 'mse' + model.compile(optimizer, loss) - inputs = np.zeros((20, 3), np.float32) - targets = np.zeros((20, 4), np.float32) - sample_weights = np.ones((20), np.float32) + inputs = np.zeros((20, 3), np.float32) + targets = np.zeros((20, 4), np.float32) + sample_weights = np.ones((20), np.float32) - model.fit(inputs, targets, sample_weight=sample_weights, 
epochs=1, - steps_per_epoch=2, verbose=1) + model.fit(inputs, targets, sample_weight=sample_weights, epochs=1, + steps_per_epoch=2, verbose=1) - @combinations.generate(strategy_combinations()) + @combinations.generate(strategy_for_numpy_input_combinations()) def test_flatten_predict_outputs(self, distribution): with self.cached_session(): - model = multi_input_output_model() - - optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.001) - loss = 'mse' - model.compile(optimizer, loss, distribute=distribution) + with distribution.scope(): + model = multi_input_output_model() + optimizer = gradient_descent.GradientDescentOptimizer( + learning_rate=0.001) + loss = 'mse' + model.compile(optimizer, loss) # We take 6 input samples with each input having a dimension of 3 or 5. input_a_np = np.asarray(np.random.random((6, 3)), dtype=np.float32) @@ -711,19 +683,74 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, self.assertAllEqual([6, 7], outs[0].shape) self.assertAllEqual([6, 7], outs[1].shape) + @combinations.generate(tpu_strategy_combinations()) + def test_predict_with_partial_batch(self, distribution): + with self.cached_session(): + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + + with distribution.scope(): + model_with_ds_strategy = get_model() + model_with_ds_strategy.compile(optimizer, loss) + + cpu_model = get_model() + cpu_model.compile(optimizer, loss) + + inputs = np.zeros((10, 3), dtype=np.float32) + + # As sample size is 10, we batch by 4 so that the last batch is + # a partial batch. Also `fit()` using numpy array as inputs without + # distribution strategy uses entire sample as a single batch. As so, + # we remove parameters `batch_size` and `steps`. 
+ cpu_model.set_weights(model_with_ds_strategy.get_weights()) + self.assertAllClose( + model_with_ds_strategy.predict(inputs, batch_size=4, steps=3), + cpu_model.predict(inputs), + atol=1e-5, rtol=1e-5) + + @combinations.generate(tpu_strategy_combinations()) + def test_predict_multi_output_model_with_partial_batch( + self, distribution): + with self.cached_session(): + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + + with distribution.scope(): + model_with_ds_strategy = simple_multi_inputs_multi_outputs_model() + model_with_ds_strategy.compile(optimizer, loss) + + cpu_model = simple_multi_inputs_multi_outputs_model() + cpu_model.compile(optimizer, loss) + + input_data, _ = get_multi_inputs_multi_outputs_data() + input_dict = { + 'input_a': input_data['input_a'], + 'input_b': input_data['input_b'], + } + + # As sample size is 200, we batch by 18 so that the last batch is + # a partial batch. Also `fit()` using numpy array as inputs without + # distribution strategy uses entire sample as a single batch. As so, + # we remove parameters `batch_size` and `steps`. 
+ cpu_model.set_weights(model_with_ds_strategy.get_weights()) + self.assertAllClose( + model_with_ds_strategy.predict(input_dict, batch_size=18, steps=12), + cpu_model.predict(input_dict), + atol=1e-4, rtol=1e-4) + class TestDistributionStrategyWithDatasets(test.TestCase, parameterized.TestCase): - @combinations.generate(strategy_combinations()) + @combinations.generate(all_strategy_combinations()) def test_calling_model_on_same_dataset(self, distribution): with self.cached_session(): - model = get_model() - - optimizer = gradient_descent.GradientDescentOptimizer(0.001) - loss = 'mse' - metrics = ['mae', keras.metrics.CategoricalAccuracy()] - model.compile(optimizer, loss, metrics=metrics, distribute=distribution) + with distribution.scope(): + model = get_model() + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae', keras.metrics.CategoricalAccuracy()] + model.compile(optimizer, loss, metrics=metrics) dataset = get_dataset(distribution) @@ -734,23 +761,22 @@ class TestDistributionStrategyWithDatasets(test.TestCase, validation_data=dataset, validation_steps=2) model.predict(get_predict_dataset(distribution), steps=2) - @combinations.generate(strategy_combinations()) + @combinations.generate(all_strategy_combinations()) def test_model_interleaved_eval_same_as_direct_eval(self, distribution): with self.cached_session(): - user_controlled_model = get_model() - user_controlled_model.compile( - gradient_descent.GradientDescentOptimizer(0.001), - loss='mse', - metrics=['mae', keras.metrics.CategoricalAccuracy()], - distribute=distribution) - - interleaved_model = get_model() - interleaved_model.set_weights(user_controlled_model.get_weights()) - interleaved_model.compile( - gradient_descent.GradientDescentOptimizer(0.001), - loss='mse', - metrics=['mae', keras.metrics.CategoricalAccuracy()], - distribute=distribution) + with distribution.scope(): + user_controlled_model = get_model() + user_controlled_model.compile( + 
gradient_descent.GradientDescentOptimizer(0.001), + loss='mse', + metrics=['mae', keras.metrics.CategoricalAccuracy()]) + + interleaved_model = get_model() + interleaved_model.set_weights(user_controlled_model.get_weights()) + interleaved_model.compile( + gradient_descent.GradientDescentOptimizer(0.001), + loss='mse', + metrics=['mae', keras.metrics.CategoricalAccuracy()]) dataset = get_dataset(distribution) @@ -782,15 +808,16 @@ class TestDistributionStrategyWithDatasets(test.TestCase, distribution=[ combinations.mirrored_strategy_with_gpu_and_cpu, combinations.core_mirrored_strategy_with_gpu_and_cpu], - mode=['graph'])) + mode=['graph', 'eager'])) def test_fit_with_tuple_and_dict_dataset_inputs(self, distribution): with self.cached_session(): - model = multi_input_output_model() - - optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.001) - loss = 'mse' - metrics = ['mae', keras.metrics.CategoricalAccuracy()] - model.compile(optimizer, loss, metrics=metrics, distribute=distribution) + with distribution.scope(): + model = multi_input_output_model() + optimizer = gradient_descent.GradientDescentOptimizer( + learning_rate=0.001) + loss = 'mse' + metrics = ['mae', keras.metrics.CategoricalAccuracy()] + model.compile(optimizer, loss, metrics=metrics) input_a_np = np.random.random((10, 3)) input_b_np = np.random.random((10, 5)) @@ -814,15 +841,51 @@ class TestDistributionStrategyWithDatasets(test.TestCase, model.fit(dataset_dict, epochs=1, steps_per_epoch=2, verbose=1) - @combinations.generate(strategy_combinations()) - def test_fit_eval_and_predict_methods_on_dataset(self, distribution): + # TODO(b/122743976): Include TPUStrategy for this test as well once + # step inference is supported. 
+ @combinations.generate(strategy_minus_tpu_combinations()) + def test_fit_eval_and_predict_methods_on_dataset_without_steps( + self, distribution): with self.cached_session(): - model = get_model() + with distribution.scope(): + model = get_model() + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae', keras.metrics.CategoricalAccuracy()] + model.compile(optimizer, loss, metrics=metrics) + + inputs = np.zeros((1000, 3), dtype=np.float32) + targets = np.zeros((1000, 4), dtype=np.float32) + # steps/steps_per_epoch are calculated when using numpy arrays as + # input data. + fit_with_numpy = model.fit(inputs, targets, epochs=1, + batch_size=10).history + eval_with_numpy = model.evaluate(inputs, targets, batch_size=10) + predict_with_numpy = model.predict(inputs, batch_size=10) - optimizer = gradient_descent.GradientDescentOptimizer(0.001) - loss = 'mse' - metrics = ['mae', keras.metrics.CategoricalAccuracy()] - model.compile(optimizer, loss, metrics=metrics, distribute=distribution) + dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) + dataset = dataset.batch(10) + fit_with_ds = model.fit(dataset, epochs=1).history + eval_with_ds = model.evaluate(dataset) + predict_dataset = dataset_ops.Dataset.from_tensor_slices(inputs) + predict_dataset = predict_dataset.batch(10, drop_remainder=True) + predict_with_ds = model.predict(predict_dataset) + self.assertAllClose( + fit_with_numpy, fit_with_ds, atol=1e-4, rtol=1e-4) + self.assertAllClose( + eval_with_numpy, eval_with_ds, atol=1e-4, rtol=1e-4) + self.assertAllClose( + predict_with_numpy, predict_with_ds, atol=1e-4, rtol=1e-4) + + @combinations.generate(all_strategy_combinations()) + def test_fit_eval_and_predict_methods_on_dataset(self, distribution): + with self.cached_session(): + with distribution.scope(): + model = get_model() + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae', keras.metrics.CategoricalAccuracy()] 
+ model.compile(optimizer, loss, metrics=metrics) dataset = get_dataset(distribution) @@ -833,10 +896,10 @@ class TestDistributionStrategyWithDatasets(test.TestCase, @combinations.generate(strategy_and_optimizer_combinations()) def test_fit_eval_and_predict_with_optimizer(self, distribution, optimizer): with self.cached_session(): - model = get_model() - - loss = 'mse' - model.compile(optimizer(), loss, distribute=distribution) + with distribution.scope(): + model = get_model() + loss = 'mse' + model.compile(optimizer(), loss) dataset = get_dataset(distribution) @@ -846,35 +909,39 @@ class TestDistributionStrategyWithDatasets(test.TestCase, @combinations.generate(strategy_minus_tpu_combinations()) def test_dataset_with_sample_weights(self, distribution): - model = get_model() - optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) - loss = 'mse' - model.compile(optimizer, loss, distribute=distribution) - - inputs = np.zeros((10, 3), np.float32) - targets = np.zeros((10, 4), np.float32) - sample_weights = np.ones((10), np.float32) - dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets, - sample_weights)) - dataset = dataset.repeat() - dataset = dataset.batch(10) - - model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1) - model.evaluate(dataset, steps=2, verbose=1) - model.predict(dataset, steps=2) + with self.cached_session(): + with distribution.scope(): + model = get_model() + optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) + loss = 'mse' + model.compile(optimizer, loss) + + inputs = np.zeros((10, 3), np.float32) + targets = np.zeros((10, 4), np.float32) + sample_weights = np.ones((10), np.float32) + dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets, + sample_weights)) + dataset = dataset.repeat() + dataset = dataset.batch(10) + + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1) + model.evaluate(dataset, steps=2, verbose=1) + model.predict(dataset, steps=2) @combinations.generate(combinations.combine( 
distribution=[ - combinations.mirrored_strategy_with_two_gpus, - combinations.core_mirrored_strategy_with_two_gpus], - mode=['graph'])) - def test_dataset_wrong_input_shape(self, distribution): + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=['graph', 'eager'])) + # TODO(b/120943676, b/120957836): Re-enable once the validation code is + # restored. + def DISABLED_test_dataset_wrong_input_shape(self, distribution): with self.cached_session(): - model = get_model() - - optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) - loss = 'mse' - model.compile(optimizer, loss, distribute=distribution) + with distribution.scope(): + model = get_model() + optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) + loss = 'mse' + model.compile(optimizer, loss) # Wrong input shape inputs = np.zeros((10, 5), dtype=np.float32) @@ -888,15 +955,17 @@ class TestDistributionStrategyWithDatasets(test.TestCase, model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0) @combinations.generate(combinations.combine( - distribution=[combinations.mirrored_strategy_with_two_gpus], - mode=['graph'])) - def test_dataset_no_batch_input_validation(self, distribution): + distribution=[combinations.mirrored_strategy_with_gpu_and_cpu], + mode=['graph', 'eager'])) + # TODO(b/120943676, b/120957836): Re-enable once the validation code is + # restored. 
+ def DISABLED_test_dataset_no_batch_input_validation(self, distribution): with self.cached_session(): - model = get_model() - - optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) - loss = 'mse' - model.compile(optimizer, loss, distribute=distribution) + with distribution.scope(): + model = get_model() + optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) + loss = 'mse' + model.compile(optimizer, loss) # User forgets to batch the dataset inputs = np.zeros((10, 3), dtype=np.float32) @@ -912,11 +981,11 @@ class TestDistributionStrategyWithDatasets(test.TestCase, mode=['graph'])) def test_dataset_input_shape_fully_defined(self, distribution): with self.cached_session(): - model = get_model() - - optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) - loss = 'mse' - model.compile(optimizer, loss, distribute=distribution) + with distribution.scope(): + model = get_model() + optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) + loss = 'mse' + model.compile(optimizer, loss) dataset = get_dataset(distribution) # Input shapes are not fully known. Batch dimension is unknown as we are @@ -928,24 +997,27 @@ class TestDistributionStrategyWithDatasets(test.TestCase, @combinations.generate(combinations.combine( distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_gpu_and_cpu, combinations.core_mirrored_strategy_with_two_gpus], - mode=['graph'])) + mode=['graph', 'eager'])) def test_learning_phase_value(self, distribution): # TODO(anjalisridhar): Modify this test to use Lambdas since we can compare # meaningful values. Currently we don't pass the learning phase if the # Lambda layer uses the learning phase. 
with self.cached_session(): - x = keras.layers.Input(shape=(1,), name='input') - y = keras.layers.Dense(1, kernel_initializer='ones')(x) - z = keras.layers.Dropout(0.9999)(y) - model = keras.Model(x, z) - initial_weights = model.get_weights() + with distribution.scope(): + x = keras.layers.Input(shape=(1,), name='input') + y = keras.layers.Dense(1, kernel_initializer='ones')(x) + z = keras.layers.Dropout(0.9999)(y) + model = keras.Model(x, z) + initial_weights = model.get_weights() - optimizer = gradient_descent.GradientDescentOptimizer(0.005) - loss = 'mse' - metrics = ['acc'] - model.compile(optimizer, loss, metrics=metrics, distribute=distribution) + optimizer = gradient_descent.GradientDescentOptimizer(0.005) + loss = 'mse' + metrics = ['acc'] + model.compile(optimizer, loss, metrics=metrics) batch_size = 8 if isinstance(distribution, mirrored_strategy.CoreMirroredStrategy): @@ -959,7 +1031,8 @@ class TestDistributionStrategyWithDatasets(test.TestCase, hist = model.fit(dataset, epochs=1, steps_per_epoch=20, verbose=1) self.assertAlmostEqual(hist.history['acc'][0], 0, 0) - model.set_weights(initial_weights) + with distribution.scope(): + model.set_weights(initial_weights) # TODO(psv/anjalisridhar): Enable these lines after we fix b/117431185. 
# evaluate_output = model.evaluate(dataset, steps=20) # self.assertAlmostEqual(evaluate_output[1], 1, 0) @@ -973,14 +1046,14 @@ class TestDistributionStrategyWithDatasets(test.TestCase, ref_output = np.ones((160, 1), dtype=np.float32) self.assertArrayNear(output, ref_output, 1e-1) - @combinations.generate(strategy_minus_tpu_combinations()) + @combinations.generate(all_strategy_combinations()) def testOptimizerWithCallbacks(self, distribution): with self.cached_session(): - model = get_model() - - optimizer = gradient_descent_keras.SGD(0.01) - loss = 'mse' - model.compile(optimizer, loss, distribute=distribution) + with distribution.scope(): + model = get_model() + optimizer = gradient_descent_keras.SGD(0.01) + loss = 'mse' + model.compile(optimizer, loss) dataset = get_dataset(distribution) @@ -989,11 +1062,185 @@ class TestDistributionStrategyWithDatasets(test.TestCase, model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, callbacks=[keras.callbacks.LearningRateScheduler(schedule)]) - grouped_models = distribution.unwrap(model._grouped_model) + self.assertAllClose(0.001, keras.backend.get_value(model.optimizer.lr)) + + @combinations.generate(tpu_strategy_combinations()) + def test_predict_with_dataset_with_partial_batch(self, distribution): + with self.cached_session(): + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + with distribution.scope(): - for m in grouped_models: - self.assertAllClose(0.001, keras.backend.get_value( - m.optimizer.lr), atol=1e-05, rtol=1e-05) + model_with_ds_strategy = get_model() + model_with_ds_strategy.compile(optimizer, loss) + + cpu_model = get_model() + cpu_model.compile(optimizer, loss) + + inputs = np.zeros((10, 3), dtype=np.float32) + dataset = dataset_ops.Dataset.from_tensor_slices((inputs)) + + # As sample size is 10, we batch by 4 so that the last batch is + # a partial batch. 
+ dataset_with_partial_batch = dataset.batch(4) + cpu_model.set_weights(model_with_ds_strategy.get_weights()) + + self.assertAllClose( + model_with_ds_strategy.predict(dataset_with_partial_batch, steps=3), + cpu_model.predict(dataset_with_partial_batch, steps=3), + atol=1e-5, rtol=1e-5) + + @combinations.generate(tpu_strategy_combinations()) + def test_predict_multi_output_model_with_dataset_with_partial_batch( + self, distribution): + with self.cached_session(): + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + + with distribution.scope(): + model_with_ds_strategy = simple_multi_inputs_multi_outputs_model() + model_with_ds_strategy.compile(optimizer, loss) + + cpu_model = simple_multi_inputs_multi_outputs_model() + cpu_model.compile(optimizer, loss) + + input_data, _ = get_multi_inputs_multi_outputs_data() + input_dict = { + 'input_a': input_data['input_a'], + 'input_b': input_data['input_b'], + } + + dataset = dataset_ops.Dataset.from_tensor_slices(input_dict) + + # As sample size is 200, we batch by 18 using 12 steps per epoch so + # that the last batch is a partial batch. + dataset_with_partial_batch = dataset.batch(18) + cpu_model.set_weights(model_with_ds_strategy.get_weights()) + + self.assertAllClose( + model_with_ds_strategy.predict(dataset_with_partial_batch, steps=12), + cpu_model.predict(dataset_with_partial_batch, steps=12), + atol=1e-4, rtol=1e-4) + + +class Counter(keras.callbacks.Callback): + """Counts the number of times each callback method was run. + + Attributes: + method_counts: dict. Contains the counts of time each callback method was + run. 
+ """ + + def __init__(self): + self.method_counts = collections.defaultdict(int) + methods_to_count = [ + 'on_batch_begin', 'on_batch_end', 'on_epoch_begin', 'on_epoch_end', + 'on_predict_batch_begin', 'on_predict_batch_end', 'on_predict_begin', + 'on_predict_end', 'on_test_batch_begin', 'on_test_batch_end', + 'on_test_begin', 'on_test_end', 'on_train_batch_begin', + 'on_train_batch_end', 'on_train_begin', 'on_train_end' + ] + for method_name in methods_to_count: + setattr(self, method_name, + self.wrap_with_counts(method_name, getattr(self, method_name))) + + def wrap_with_counts(self, method_name, method): + + def _call_and_count(*args, **kwargs): + self.method_counts[method_name] += 1 + return method(*args, **kwargs) + + return _call_and_count + + +class TestDistributionStrategyWithCallbacks(test.TestCase, + parameterized.TestCase): + + @combinations.generate(all_strategy_combinations()) + def test_callbacks_in_fit(self, distribution): + with distribution.scope(): + model = get_model() + model.compile(optimizer='sgd', loss='mse', metrics=['mae']) + + dataset = get_dataset(distribution) + counter = Counter() + + epochs = 2 + steps_per_epoch = 5 + validation_steps = 3 + + model.fit( + dataset, + epochs=epochs, + steps_per_epoch=steps_per_epoch, + verbose=0, + validation_data=dataset, + validation_steps=validation_steps, + callbacks=[counter]) + + if isinstance(distribution, tpu_strategy.TPUStrategy): + # TPU Strategy can have multi step training, from extended.steps_per_run + # if steps_per_run = 1, then num_batch_call_per_epoch = steps_per_epoch + steps_per_run = distribution.extended.steps_per_run + num_batch_call_per_epoch = steps_per_epoch // steps_per_run + if steps_per_epoch % steps_per_run: + num_batch_call_per_epoch += 1 + else: + num_batch_call_per_epoch = steps_per_epoch + + self.assertDictEqual( + counter.method_counts, { + 'on_batch_begin': epochs * num_batch_call_per_epoch, + 'on_batch_end': epochs * num_batch_call_per_epoch, + 'on_epoch_begin': 
epochs, + 'on_epoch_end': epochs, + 'on_test_batch_begin': epochs * validation_steps, + 'on_test_batch_end': epochs * validation_steps, + 'on_test_begin': epochs, + 'on_test_end': epochs, + 'on_train_batch_begin': epochs * num_batch_call_per_epoch, + 'on_train_batch_end': epochs * num_batch_call_per_epoch, + 'on_train_begin': 1, + 'on_train_end': 1 + }) + + @combinations.generate(all_strategy_combinations()) + def test_callbacks_in_eval(self, distribution): + with distribution.scope(): + model = get_model() + model.compile(optimizer='sgd', loss='mse', metrics=['mae']) + + dataset = get_dataset(distribution) + counter = Counter() + + model.evaluate(dataset, steps=5, callbacks=[counter]) + + self.assertDictEqual( + counter.method_counts, { + 'on_test_batch_begin': 5, + 'on_test_batch_end': 5, + 'on_test_begin': 1, + 'on_test_end': 1 + }) + + @combinations.generate(all_strategy_combinations()) + def test_callbacks_in_predict(self, distribution): + with distribution.scope(): + model = get_model() + model.compile(optimizer='sgd', loss='mse', metrics=['mae']) + + dataset = get_dataset(distribution) + counter = Counter() + + model.predict(get_predict_dataset(dataset), steps=5, callbacks=[counter]) + + self.assertDictEqual( + counter.method_counts, { + 'on_predict_batch_begin': 5, + 'on_predict_batch_end': 5, + 'on_predict_begin': 1, + 'on_predict_end': 1 + }) class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase): @@ -1002,22 +1249,23 @@ class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase): distribution=[ combinations.mirrored_strategy_with_gpu_and_cpu, combinations.core_mirrored_strategy_with_gpu_and_cpu], - mode=['graph'])) + mode=['graph', 'eager'])) def test_validating_dataset_input_tensors_with_shape_mismatch(self, distribution): with self.cached_session(): a = constant_op.constant([1, 2], shape=(1, 2)) b = constant_op.constant([[1, 2], [1, 2]], shape=(2, 2)) - x = values.DistributedValues({'/device:CPU:0': a, 
'/device:GPU:0': b}) - y = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': a}) - with distribution.scope(): - # Removed device and input tensor shape details from the error message - # since the order of the device and the corresponding input tensor shape - # is not deterministic over different runs. - with self.assertRaisesRegexp(ValueError, - 'Input tensor shapes do not match for ' - 'distributed tensor inputs ' - 'DistributedValues:.+'): + device_map = values.ReplicaDeviceMap(('/device:CPU:0', '/device:GPU:0')) + x = values.DistributedValues(device_map, (a, b)) + y = values.DistributedValues(device_map, (a, a)) + # Removed device and input tensor shape details from the error message + # since the order of the device and the corresponding input tensor shape + # is not deterministic over different runs. + with self.assertRaisesRegexp(ValueError, + 'Input tensor shapes do not match for ' + 'distributed tensor inputs ' + 'DistributedValues:.+'): + with distribution.scope(): distributed_training_utils.validate_distributed_dataset_inputs( distribution, x, y) @@ -1025,38 +1273,39 @@ class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase): distribution=[ combinations.mirrored_strategy_with_gpu_and_cpu, combinations.core_mirrored_strategy_with_gpu_and_cpu], - mode=['graph'])) + mode=['graph', 'eager'])) def test_validating_dataset_input_tensors_with_dtype_mismatch(self, distribution): with self.cached_session(): a = constant_op.constant([1, 2], shape=(1, 2), dtype=dtypes.int32) b = constant_op.constant([1, 2], shape=(1, 2), dtype=dtypes.float64) - x = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': b}) - y = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': a}) - with distribution.scope(): - # Removed device and input tensor dtype details from the error message - # since the order of the device and the corresponding input tensor dtype - # is not deterministic over different runs. 
- with self.assertRaisesRegexp(ValueError, - 'Input tensor dtypes do not match for ' - 'distributed tensor inputs ' - 'DistributedValues:.+'): + device_map = values.ReplicaDeviceMap(('/device:CPU:0', '/device:GPU:0')) + x = values.DistributedValues(device_map, (a, b)) + y = values.DistributedValues(device_map, (a, a)) + # Removed device and input tensor dtype details from the error message + # since the order of the device and the corresponding input tensor dtype + # is not deterministic over different runs. + with self.assertRaisesRegexp(ValueError, + 'Input tensor dtypes do not match for ' + 'distributed tensor inputs ' + 'DistributedValues:.+'): + with distribution.scope(): distributed_training_utils.validate_distributed_dataset_inputs( distribution, x, y) @combinations.generate(combinations.combine( distribution=[ - combinations.mirrored_strategy_with_two_gpus, - combinations.core_mirrored_strategy_with_two_gpus], - mode=['graph'])) + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=['graph', 'eager'])) def test_unsupported_features(self, distribution): with self.cached_session(): - model = get_model() - - optimizer = gradient_descent.GradientDescentOptimizer(0.001) - loss = 'mse' - metrics = ['mae'] - model.compile(optimizer, loss, metrics=metrics, distribute=distribution) + with distribution.scope(): + model = get_model() + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae'] + model.compile(optimizer, loss, metrics=metrics) dataset = get_dataset(distribution) @@ -1081,31 +1330,36 @@ class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase): verbose=0, sample_weight=sample_weight) - # Test with not specifying the `steps` argument. - with self.assertRaisesRegexp( - ValueError, 'you should specify the `steps_per_epoch` argument'): + # Test with not specifying the `steps` argument for dataset with infinite + # cardinality. 
+ dataset = dataset.repeat() + with self.assertRaisesRegexp(ValueError, 'When passing an infinitely ' + 'repeating dataset, you must specify the ' + '`steps_per_epoch` argument'): model.fit(dataset, epochs=1, verbose=0) - with self.assertRaisesRegexp(ValueError, - 'you should specify the `steps` argument'): + with self.assertRaisesRegexp(ValueError, 'When passing an infinitely ' + 'repeating dataset, you must specify the ' + '`steps` argument'): model.evaluate(dataset, verbose=0) - with self.assertRaisesRegexp(ValueError, - 'you should specify the `steps` argument'): + with self.assertRaisesRegexp(ValueError, 'When passing an infinitely ' + 'repeating dataset, you must specify the ' + '`steps` argument'): model.predict(dataset, verbose=0) @combinations.generate(combinations.combine( distribution=[ - combinations.mirrored_strategy_with_two_gpus, - combinations.core_mirrored_strategy_with_two_gpus], - mode=['graph'])) + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=['graph', 'eager'])) def test_calling_with_unsupported_predefined_callbacks(self, distribution): with self.cached_session(): - model = get_model() - - optimizer = gradient_descent.GradientDescentOptimizer(0.001) - loss = 'mse' - metrics = ['mae'] - model.compile(optimizer, loss, metrics=metrics, distribute=distribution) + with distribution.scope(): + model = get_model() + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae'] + model.compile(optimizer, loss, metrics=metrics) dataset = get_dataset(distribution) @@ -1122,12 +1376,6 @@ class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase): 'using'): model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, callbacks=[keras.callbacks.ReduceLROnPlateau()]) - with self.assertRaisesRegexp(ValueError, - 'histogram_freq in the TensorBoard callback ' - 'is not supported when using ' - 'DistributionStrategy.'): - model.fit(dataset, 
epochs=1, steps_per_epoch=2, verbose=0, - callbacks=[keras.callbacks.TensorBoard(histogram_freq=10)]) class TestDistributionStrategyWithLossMasking(test.TestCase, @@ -1137,21 +1385,21 @@ class TestDistributionStrategyWithLossMasking(test.TestCase, # work for TPU due to some invalid datatype. @combinations.generate(combinations.combine( distribution=[ - combinations.mirrored_strategy_with_two_gpus, - combinations.core_mirrored_strategy_with_two_gpus], - mode=['graph'])) + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=['graph', 'eager'])) def test_masking(self, distribution): with self.cached_session(): np.random.seed(1337) x = np.array([[[1], [1]], [[0], [0]]]) - model = keras.models.Sequential() - model.add(keras.layers.Masking(mask_value=0, input_shape=(2, 1))) - model.add( - keras.layers.TimeDistributed( - keras.layers.Dense(1, kernel_initializer='one'))) - model.compile(loss='mse', - optimizer=gradient_descent.GradientDescentOptimizer(0.01), - distribute=distribution) + with distribution.scope(): + model = keras.models.Sequential() + model.add(keras.layers.Masking(mask_value=0, input_shape=(2, 1))) + model.add( + keras.layers.TimeDistributed( + keras.layers.Dense(1, kernel_initializer='one'))) + model.compile(loss='mse', + optimizer=gradient_descent.GradientDescentOptimizer(0.01)) y = np.array([[[1], [1]], [[1], [1]]]) dataset = dataset_ops.Dataset.from_tensor_slices((x, y)) dataset = dataset.repeat(100) @@ -1163,15 +1411,18 @@ class TestDistributionStrategyWithLossMasking(test.TestCase, class TestDistributionStrategyWithNormalizationLayer( test.TestCase, parameterized.TestCase): - @combinations.generate(strategy_combinations()) - def test_batchnorm_correctness(self, distribution): + @combinations.generate(combinations.times( + all_strategy_combinations(), + combinations.combine(fused=[True, False]))) + def test_batchnorm_correctness(self, distribution, fused): with self.cached_session(): - model 
= keras.models.Sequential() - norm = keras.layers.BatchNormalization(input_shape=(10,), momentum=0.8) - model.add(norm) - model.compile(loss='mse', - optimizer=gradient_descent.GradientDescentOptimizer(0.01), - distribute=distribution) + with distribution.scope(): + model = keras.models.Sequential() + norm = keras.layers.BatchNormalization( + input_shape=(10,), momentum=0.8, fused=fused) + model.add(norm) + model.compile(loss='mse', + optimizer=gradient_descent.GradientDescentOptimizer(0.01)) # centered on 5.0, variance 10.0 x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 10)) @@ -1192,118 +1443,77 @@ class TestDistributionStrategyWithNormalizationLayer( np.testing.assert_allclose(out.std(), 1.0, atol=1e-1) -class TestDistributionStrategyCorrectness(test.TestCase, - parameterized.TestCase): +class TestDistributionStrategySaveLoadWeights(test.TestCase, + parameterized.TestCase): + + @combinations.generate(all_strategy_combinations_minus_default()) + def test_save_load_h5(self, distribution): + with self.cached_session(): + dataset = get_dataset(distribution) + with distribution.scope(): + model = get_model() + model.compile(gradient_descent_keras.SGD(0.01), 'mse') + model.fit(dataset, epochs=1, steps_per_epoch=1) + + weights_file = tempfile.mktemp('.h5') + model.save_weights(weights_file) + + model_2 = get_model() + model_2.compile(gradient_descent_keras.SGD(0.01), 'mse') + model_2.load_weights(weights_file) + model_2.predict(get_predict_dataset(distribution), steps=2) + model_2.fit(dataset, epochs=1, steps_per_epoch=1) + + @combinations.generate(all_strategy_combinations_minus_default()) + def test_save_load_checkpointable(self, distribution): + # TODO(sourabhbajaj): Test fails with optimizer v2 without h5 + with self.cached_session(): + dataset = get_dataset(distribution) + with distribution.scope(): + model = get_model() + model.compile(gradient_descent.GradientDescentOptimizer(0.01), 'mse') + model.fit(dataset, epochs=1, steps_per_epoch=1) + + weights_file 
= tempfile.mktemp() + model.save_weights(weights_file) + + model_2 = get_model() + model_2.compile(gradient_descent.GradientDescentOptimizer(0.01), 'mse') + model_2.load_weights(weights_file) + model_2.predict(get_predict_dataset(distribution), steps=2) + model_2.fit(dataset, epochs=1, steps_per_epoch=1) - @combinations.generate(strategy_combinations()) - def test_metric_correctness(self, distribution): + +class TestDistributionStrategyValidation(test.TestCase, + parameterized.TestCase): + + @combinations.generate(all_strategy_combinations_minus_default()) + def test_layer_outside_scope(self, distribution): with self.cached_session(): - keras.backend.set_image_data_format('channels_last') - num_samples = 10000 - - x_train = np.random.randint(0, 2, num_samples) - x_train = np.reshape(x_train, (num_samples, 1)) - y_train = x_train - x_train = x_train.astype('float32') - y_train = y_train.astype('float32') - - # Create identity model. - model = keras.Sequential() - model.add( - keras.layers.Dense(1, input_shape=(1,), kernel_initializer='ones')) - model.compile( - loss=keras.losses.mean_squared_error, - optimizer=gradient_descent.GradientDescentOptimizer(0.5), - metrics=[keras.metrics.BinaryAccuracy()], - distribute=distribution) - - batch_size = 64 - if not distributed_training_utils.global_batch_size_supported( - distribution): - batch_size //= distribution.num_replicas_in_sync - train_dataset = dataset_ops.Dataset.from_tensor_slices((x_train, y_train)) - train_dataset = batch_wrapper(train_dataset, batch_size, distribution) - - history = model.fit(x=train_dataset, epochs=1, steps_per_epoch=10) - self.assertEqual(history.history['binary_accuracy'], [1.0]) - - @combinations.generate(strategy_and_inputs()) - def test_correctness(self, distribution, use_numpy): + with self.assertRaisesRegexp( + ValueError, 'was not created in the distribution strategy'): + x = keras.layers.Input(shape=(3,), name='input') + y = keras.layers.Dense(4, name='dense')(x) + with 
distribution.scope(): + model = keras.Model(x, y) + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae', keras.metrics.CategoricalAccuracy()] + model.compile(optimizer, loss, metrics=metrics) + + @combinations.generate(all_strategy_combinations_minus_default()) + def test_model_outside_scope(self, distribution): with self.cached_session(): - tolerance = 1e-5 - - if isinstance(distribution, (mirrored_strategy.MirroredStrategy, - mirrored_strategy.CoreMirroredStrategy)): - # TODO(b/119257215): use the default one once the flakyness is fixed. - tolerance = 1e-4 - - keras.backend.set_image_data_format('channels_last') - np.random.seed(_RANDOM_SEED) - random_seed.set_random_seed(_RANDOM_SEED) - - # Train, eval, and predict datasets are created with the same input numpy - # arrays. - # TODO(xiejw): Change this back to 10000, once we support final partial - # batch. - num_samples = 9984 - x_train = np.random.rand(num_samples, 1) - y_train = 3 * x_train - x_train = x_train.astype('float32') - y_train = y_train.astype('float32') - x_predict = [[1.], [2.], [3.], [4.]] - - # The model is built once and the initial weights are saved. - # This is used to initialize the model for both the distribution and - # non-distribution run. In addition, we add few non-linear layers to make - # it non-trivial. - model = keras.Sequential() - model.add(keras.layers.Dense(10, activation='relu', input_shape=(1,))) - model.add(keras.layers.Dense(10, activation='relu')) - model.add(keras.layers.Dense(10, activation='relu')) - model.add(keras.layers.Dense(1)) - initial_weights = model.get_weights() - - def fit_eval_and_predict(with_distribution=None): - # We have initialized the model to the same weight for the distribution - # and non-distribution run. - model.set_weights(initial_weights) - # TODO(b/120245072): Also use gradient_descent_keras.SGD for - # TPUStrategy. 
- # pylint: disable=line-too-long - if with_distribution and with_distribution.__class__.__name__ == 'TPUStrategy': - # pylint: enable=line-too-long - optimizer = gradient_descent.GradientDescentOptimizer(0.5) - else: - optimizer = gradient_descent_keras.SGD(0.5) - model.compile( - loss=keras.losses.mean_squared_error, - optimizer=optimizer, - distribute=with_distribution) - - training_inputs, eval_inputs, predict_inputs = ( - get_correctness_test_inputs(use_numpy, with_distribution, - x_train, y_train, x_predict)) - - model.fit(**training_inputs) - eval_result = model.evaluate(**eval_inputs) - weights = model.get_weights() - predict_result = model.predict(**predict_inputs) - - return weights, eval_result, predict_result - - wts_with_ds, eval_with_ds, predict_with_ds = fit_eval_and_predict( - with_distribution=distribution) - wts_without_ds, eval_without_ds, predict_without_ds = ( - fit_eval_and_predict(with_distribution=None)) - - # Verify that the weights, eval results, predict outputs are the same - # within some limits of tolerance. 
- self.assertAllClose( - wts_with_ds, wts_without_ds, atol=tolerance, rtol=tolerance) - self.assertAllClose( - eval_with_ds, eval_without_ds, atol=tolerance, rtol=tolerance) - self.assertAllClose( - predict_with_ds, predict_without_ds, atol=tolerance, rtol=tolerance) + with self.assertRaisesRegexp( + ValueError, 'was not created in the distribution strategy'): + x = keras.layers.Input(shape=(3,), name='input') + y = keras.layers.Dense(4, name='dense')(x) + model = keras.Model(x, y) + with distribution.scope(): + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae', keras.metrics.CategoricalAccuracy()] + model.compile(optimizer, loss, metrics=metrics) if __name__ == '__main__': diff --git a/tensorflow/contrib/distribute/python/metrics_v1_test.py b/tensorflow/contrib/distribute/python/metrics_v1_test.py index 8ac659abe96370b751ed1556cc699fe20788a0fd..a663e809dd45ea099e1d8a08e681d07b05bee3c9 100644 --- a/tensorflow/contrib/distribute/python/metrics_v1_test.py +++ b/tensorflow/contrib/distribute/python/metrics_v1_test.py @@ -95,16 +95,15 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase): def _test_metric(self, distribution, dataset_fn, metric_fn, expected_fn): with ops.Graph().as_default(), distribution.scope(): - iterator = distribution.distribute_dataset( - dataset_fn).make_initializable_iterator() + iterator = distribution.make_input_fn_iterator(lambda _: dataset_fn()) if isinstance(distribution, tpu_strategy.TPUStrategy): def step_fn(ctx, inputs): - value, update = distribution.call_for_each_replica( - metric_fn, args=inputs) + value, update = distribution.extended.call_for_each_replica( + metric_fn, args=(inputs,)) ctx.set_non_tensor_output(name="value", output=value) return distribution.group(update) - ctx = distribution.run_steps_on_dataset( + ctx = distribution.extended.experimental_run_steps_on_iterator( step_fn, iterator, iterations=distribution.extended.steps_per_run) update = ctx.run_op value = 
ctx.non_tensor_outputs["value"] @@ -114,15 +113,14 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase): distribution.num_replicas_in_sync * distribution.extended.steps_per_run) else: - value, update = distribution.call_for_each_replica( - metric_fn, iterator.get_next()) + value, update = distribution.extended.call_for_each_replica( + metric_fn, args=(iterator.get_next(),)) update = distribution.group(update) # TODO(josh11b): Once we switch to using a global batch size for input, # replace "distribution.num_replicas_in_sync" with "1". batches_per_update = distribution.num_replicas_in_sync - self.evaluate(iterator.initializer) - self.evaluate(distribution.initialize()) + self.evaluate(iterator.initialize()) self.evaluate(variables.local_variables_initializer()) batches_consumed = 0 @@ -136,8 +134,6 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase): if batches_consumed >= 4: # Consume 4 input batches in total. break - self.evaluate(distribution.finalize()) - @combinations.generate(all_combinations() + tpu_combinations()) def testMean(self, distribution): def _dataset_fn(): diff --git a/tensorflow/contrib/distribute/python/minimize_loss_test.py b/tensorflow/contrib/distribute/python/minimize_loss_test.py index dcc9df4cda51b87e95fb166a726170a8817715fc..f06c9b75644b2890b7657f75e74e4e20a6f15705 100644 --- a/tensorflow/contrib/distribute/python/minimize_loss_test.py +++ b/tensorflow/contrib/distribute/python/minimize_loss_test.py @@ -41,12 +41,9 @@ from tensorflow.python.ops.losses import losses_impl class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): - def _get_iterator(self, ds): - if context.executing_eagerly(): - iterator = ds.make_one_shot_iterator() - else: - iterator = ds.make_initializable_iterator() - self.evaluate(iterator.initializer) + def _get_iterator(self, strategy, input_fn): + iterator = strategy.make_input_fn_iterator(lambda _: input_fn()) + self.evaluate(iterator.initialize()) return iterator @combinations.generate( 
@@ -67,15 +64,15 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): def step_fn(ctx, inputs): del ctx # Unused return distribution.group( - distribution.call_for_each_replica(model_fn, args=inputs)) + distribution.extended.call_for_each_replica( + model_fn, args=(inputs,))) - iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn)) + iterator = self._get_iterator(distribution, dataset_fn) def run_step(): - return distribution.run_steps_on_dataset( + return distribution.extended.experimental_run_steps_on_iterator( step_fn, iterator, iterations=2).run_op - self.evaluate(distribution.initialize()) if not context.executing_eagerly(): with self.cached_session() as sess: run_step = sess.make_callable(run_step()) @@ -84,12 +81,9 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): weights, biases = [], [] for _ in range(5): run_step() - weights.append(self.evaluate(layer.kernel)) biases.append(self.evaluate(layer.bias)) - self.evaluate(distribution.finalize()) - error = abs(numpy.add(numpy.squeeze(weights), numpy.squeeze(biases)) - 1) is_not_increasing = all(y <= x for x, y in zip(error, error[1:])) self.assertTrue(is_not_increasing) @@ -105,11 +99,11 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): model_fn, dataset_fn, layer = minimize_loss_example( optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss) - iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn)) + iterator = self._get_iterator(distribution, dataset_fn) def run_step(): return distribution.group( - distribution.call_for_each_replica( + distribution.extended.call_for_each_replica( model_fn, args=(iterator.get_next(),))) if not context.executing_eagerly(): @@ -152,7 +146,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): # `distribution.scope`. 
with variable_scope.variable_creator_scope( appending_creator), distribution.scope(): - model_fn, dataset_fn, layer = minimize_loss_example( + model_fn, dataset_fn, _ = minimize_loss_example( optimizer_fn, use_bias=True, use_callable_loss=True, @@ -161,24 +155,21 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): def step_fn(ctx, inputs): del ctx # Unused return distribution.group( - distribution.call_for_each_replica(model_fn, args=inputs)) + distribution.extended.call_for_each_replica( + model_fn, args=(inputs,))) - iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn)) + iterator = self._get_iterator(distribution, dataset_fn) def run_step(): - return distribution.run_steps_on_dataset( + return distribution.extended.experimental_run_steps_on_iterator( step_fn, iterator, iterations=1).run_op - self.evaluate(distribution.initialize()) if not context.executing_eagerly(): with self.cached_session() as sess: run_step = sess.make_callable(run_step()) self.evaluate(variables_lib.global_variables_initializer()) - run_step() - self.evaluate(distribution.finalize()) - def get_expected_variables(optimizer_fn, num_parameter_devices): variables_map = { "GradientDescent": ["dense/kernel", "dense/bias"], @@ -197,7 +188,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): self.assertEqual( get_expected_variables(optimizer_fn, - len(distribution.parameter_devices)), + len(distribution.extended.parameter_devices)), set(created_variables)) @combinations.generate( @@ -230,18 +221,18 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): def step_fn(ctx, inputs): del ctx # Unused fetches = distribution.unwrap( - distribution.call_for_each_replica(model_fn, args=inputs)) + distribution.extended.call_for_each_replica( + model_fn, args=(inputs,))) if update_ops_in_cross_replica_mode: - fetches += ops.get_collection(ops.GraphKeys.UPDATE_OPS) + fetches += tuple(ops.get_collection(ops.GraphKeys.UPDATE_OPS)) return 
control_flow_ops.group(fetches) - iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn)) + iterator = self._get_iterator(distribution, dataset_fn) def run_step(): - return distribution.run_steps_on_dataset( + return distribution.extended.experimental_run_steps_on_iterator( step_fn, iterator, iterations=1).run_op - self.evaluate(distribution.initialize()) if not context.executing_eagerly(): with self.cached_session() as sess: run_step = sess.make_callable(run_step()) @@ -267,8 +258,6 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): expected_moving_mean - averaged_batch_mean(i)) * (1.0 - momentum)) self.assertNear(expected_moving_means[i], moving_means[i], 0.0001) - self.evaluate(distribution.finalize()) - @combinations.generate( combinations.times( combinations.combine( @@ -302,8 +291,8 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): with distribution.scope(): all_vars = [] - def model_fn(x, y): - + def model_fn(inputs): + x, y = inputs def loss_fn(): # Use fixed initialization to make the steps deterministic. 
w = variable_scope.get_variable("w", initializer=[[2.]]) @@ -327,15 +316,15 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): def step_fn(ctx, inputs): del ctx # Unused return distribution.group( - distribution.call_for_each_replica(model_fn, args=inputs)) + distribution.extended.call_for_each_replica( + model_fn, args=(inputs,))) - iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn)) + iterator = self._get_iterator(distribution, dataset_fn) def run_step(): - return distribution.run_steps_on_dataset( + return distribution.extended.experimental_run_steps_on_iterator( step_fn, iterator, iterations=1).run_op - self.evaluate(distribution.initialize()) if not context.executing_eagerly(): with self.cached_session() as sess: run_step = sess.make_callable(run_step()) @@ -370,8 +359,6 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): # One of the mean loss reductions. self.assertNear(weight, 2 + 10.6, 0.0001) - self.evaluate(distribution.finalize()) - @combinations.generate( combinations.times( combinations.distributions_and_v1_optimizers(), @@ -412,8 +399,8 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): return (train_op, loss) def step_fn(output_context, inputs): - (train_op, loss) = distribution.call_for_each_replica( - model_fn, args=(output_context,) + inputs) + (train_op, loss) = distribution.extended.call_for_each_replica( + model_fn, args=(output_context, inputs)) output_context.set_last_step_output( name="cross_replica_loss_reduced", output=loss, @@ -423,7 +410,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): output=loss) return distribution.group(train_op) - iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn)) + iterator = self._get_iterator(distribution, dataset_fn) def run_step(): initial_loss = lambda: constant_op.constant(1e7) @@ -439,11 +426,11 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): 
"cross_replica_loss_not_reduced": distribution.unwrap(distribution.broadcast(initial_loss())) } - ctx = distribution.run_steps_on_dataset( + ctx = distribution.extended.experimental_run_steps_on_iterator( step_fn, iterator, iterations=2, initial_loop_values=initial_loop_values) - self.assertEqual({key1: [value1]}, ctx.non_tensor_outputs) + self.assertEqual({key1: (value1,)}, ctx.non_tensor_outputs) self._verify_loss_output( initial_loss(), loss_output=ctx.last_step_outputs["replica_loss_reduced"], @@ -458,7 +445,6 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): reduced=False, distribution=distribution) return (ctx.run_op, ctx.last_step_outputs["replica_loss_reduced"]) - self.evaluate(distribution.initialize()) if not context.executing_eagerly(): with self.cached_session() as sess: run_step = sess.make_callable(run_step()) @@ -471,8 +457,6 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): weights.append(self.evaluate(layer.kernel)) biases.append(self.evaluate(layer.bias)) - self.evaluate(distribution.finalize()) - loss_is_not_increasing = all(y <= x for x, y in zip(losses, losses[1:])) self.assertTrue(loss_is_not_increasing) diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index 4a594f056e96a2a48563d9902bdeed8458b847e4..5391e083fc9b3ed99cc64bbed11bdeb8dea07f93 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -18,18 +18,15 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import functools - -from tensorflow.python.distribute import device_util from tensorflow.python.distribute import distribute_lib +from tensorflow.python.distribute import input_lib from tensorflow.python.distribute import mirrored_strategy -from tensorflow.python.distribute import values # pylint: 
disable=protected-access,invalid-name _call_for_each_replica = mirrored_strategy._call_for_each_replica -_reduce_non_distributed_value = mirrored_strategy._reduce_non_distributed_value _create_mirrored_variable = mirrored_strategy._create_mirrored_variable +all_local_devices = mirrored_strategy.all_local_devices CoreMirroredStrategy = mirrored_strategy.MirroredStrategy CoreMirroredExtended = mirrored_strategy.MirroredExtended # pylint: enable=protected-access,invalid-name @@ -49,8 +46,8 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): distributed environment. There are several important concepts for distributed TensorFlow, e.g. - `client`, `job`, 'task', `cluster`, `in-graph replication` and - 'synchronous training' and they have already been defined in the + `client`, `job`, `task`, `cluster`, `in-graph replication` and + `synchronous training` and they have already been defined in the [TensorFlow's documentation](https://www.tensorflow.org/deploy/distributed). The distribution strategy inherits these concepts as well and in addition to that we also clarify several more concepts: @@ -105,6 +102,61 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): auto_shard_dataset) super(MirroredStrategy, self).__init__(extended) + # Override to change the documentation to reflect the different handling of + # global vs. local batch size between core and contrib. + def make_dataset_iterator(self, dataset): # pylint: disable=useless-super-delegation + """Makes an iterator for input provided via `dataset`. + + NOTE: The batch size of the `dataset` argument is treated differently for + this contrib version of `MirroredStrategy`. + + Data from the given dataset will be distributed evenly across all the + compute replicas. We will assume that the input dataset is batched by the + per-replica batch size. + + The user could also use `make_input_fn_iterator` if they want to + customize which input is fed to which replica/worker etc. 
+ + Args: + dataset: `tf.data.Dataset` that will be distributed evenly across all + replicas. + + Returns: + An `tf.distribute.InputIterator` which returns inputs for each step of the + computation. User should call `initialize` on the returned iterator. + """ + return super(MirroredStrategy, self).make_dataset_iterator(dataset) + + # Override to change the documentation to reflect the different handling of + # global vs. local batch size between core and contrib. + def experimental_make_numpy_iterator( # pylint: disable=useless-super-delegation + self, numpy_input, batch_size, num_epochs=1, shuffle=1024, session=None): + """Makes an iterator for input provided via a nest of numpy arrays. + + NOTE: The `batch_size` argument here has different behavior for this + contrib version of `MirroredStrategy`. + + Args: + numpy_input: A nest of NumPy input arrays that will be distributed evenly + across all replicas. + batch_size: The number of entries from the array we should consume in one + step of the computation, across all replicas. This is the per-replica + batch size. The global batch size will be this times + `num_replicas_in_sync`. + num_epochs: The number of times to iterate through the examples. A value + of `None` means repeat forever. + shuffle: Size of buffer to use for shuffling the input examples. + Use `None` to disable shuffling. + session: (TensorFlow v1.x graph execution only) A session used for + initialization. + + Returns: + An `tf.distribute.InputIterator` which returns inputs for each step of the + computation. User should call `initialize` on the returned iterator. 
+ """ + return super(MirroredStrategy, self).experimental_make_numpy_iterator( + numpy_input, batch_size, num_epochs, shuffle, session) + class MirroredExtended(CoreMirroredExtended): """Implementation of (contrib) MirroredStrategy.""" @@ -115,8 +167,13 @@ class MirroredExtended(CoreMirroredExtended): num_gpus_per_worker=None, cross_device_ops=None, auto_shard_dataset=False): - super(MirroredExtended, self).__init__( - container_strategy, devices, num_gpus_per_worker, cross_device_ops) + if devices is None: + devices = mirrored_strategy.all_local_devices(num_gpus_per_worker) + elif num_gpus_per_worker is not None: + raise ValueError( + "Must only specify one of `devices` and `num_gpus_per_worker`.") + super(MirroredExtended, self).__init__(container_strategy, devices, + cross_device_ops) self._auto_shard_dataset = auto_shard_dataset def _make_dataset_iterator(self, dataset): @@ -131,24 +188,10 @@ class MirroredExtended(CoreMirroredExtended): Returns: An `InputIterator` which returns inputs for each step of the computation. """ - if self._cluster_spec: - worker_device_pairs = self._worker_devices - else: - worker = device_util.canonicalize("/device:CPU:0") - worker_device_pairs = [(worker, self._devices)] - return values.DatasetIterator(dataset, worker_device_pairs) - - def _distribute_dataset(self, dataset_fn): - if self._cluster_spec: - return values.MultiWorkerDataset( - functools.partial(self._call_dataset_fn, dataset_fn), - self._worker_devices, - auto_shard=self._auto_shard_dataset) - else: - return values.PerReplicaDataset( - self._call_dataset_fn(dataset_fn), self._devices) + return input_lib.DatasetIterator(dataset, self._input_workers) # TODO(priyag): Delete this once all strategies use global batch size. 
@property def _global_batch_size(self): + """The contrib version of Mirrored strategy uses per-replica batch size.""" return False diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py index fee37daa424b8ada9f18b2046599a62647d8c33d..0b8df787e6b1bde8dce30ea420a3f0e19da23ca4 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py @@ -39,6 +39,7 @@ from tensorflow.python.eager import function from tensorflow.python.eager import test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import func_graph from tensorflow.python.framework import ops from tensorflow.python.keras.engine import training as keras_training from tensorflow.python.keras.layers import core as keras_core @@ -65,8 +66,10 @@ GPU_TEST = "test_gpu" in sys.argv[0] combinations.core_mirrored_strategy_with_gpu_and_cpu, combinations.core_mirrored_strategy_with_two_gpus], mode=["graph", "eager"])) -class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase, - parameterized.TestCase): +class MirroredTwoDeviceDistributionTest( + strategy_test_lib.DistributionTestBase, + strategy_test_lib.TwoDeviceDistributionTestBase, + parameterized.TestCase): def testMinimizeLoss(self, distribution): if context.executing_eagerly(): @@ -100,7 +103,7 @@ class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase, expected = sum(range(distribution.num_replicas_in_sync)) self.assertEqual(expected, self.evaluate(reduced)) - def testMakeInputFnIterator(self, distribution): + def testMakeInputFnIteratorWithDataset(self, distribution): dataset_fn = lambda: dataset_ops.Dataset.range(10) expected_values = [[i, i+1] for i in range(0, 10, 2)] @@ -113,9 +116,47 @@ class 
MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase, self._test_input_fn_iterator(iterator, distribution.extended.worker_devices, expected_values) + def testMakeInputFnIteratorWithCallable(self, distribution): + def fn(): + dataset = dataset_ops.Dataset.range(2).interleave( + (lambda _: dataset_ops.Dataset.range(10)), cycle_length=2) + it = dataset.make_one_shot_iterator() + return it.get_next + expected_values = [[i, i] for i in range(0, 10)] + + input_fn = self._input_fn_to_test_input_context( + fn, + expected_num_replicas_in_sync=2, + expected_num_input_pipelines=1, + expected_input_pipeline_id=0) + iterator = distribution.make_input_fn_iterator(input_fn) + self._test_input_fn_iterator(iterator, distribution.extended.worker_devices, + expected_values, test_reinitialize=False) + + def testNumpyIterator(self, distribution): + self._test_numpy_iterator(distribution) + def testGlobalStepUpdate(self, distribution): self._test_global_step_update(distribution) + def testAllReduceSum(self, distribution): + self._test_all_reduce_sum(distribution) + + def testAllReduceSumGradients(self, distribution): + self._test_all_reduce_sum_gradients(distribution) + + def testAllReduceSumGradientTape(self, distribution): + self._test_all_reduce_sum_gradient_tape(distribution) + + def testAllReduceMean(self, distribution): + self._test_all_reduce_mean(distribution) + + def testAllReduceMeanGradients(self, distribution): + self._test_all_reduce_mean_gradients(distribution) + + def testAllReduceMeanGradientTape(self, distribution): + self._test_all_reduce_mean_gradient_tape(distribution) + def one_device_combinations(): return combinations.combine( @@ -127,25 +168,42 @@ def one_device_combinations(): mode=["graph", "eager"]) +@combinations.generate(one_device_combinations()) class MirroredOneDeviceDistributionTest( strategy_test_lib.DistributionTestBase, + strategy_test_lib.OneDeviceDistributionTestBase, parameterized.TestCase): - 
@combinations.generate(one_device_combinations()) def testMinimizeLoss(self, distribution): if context.executing_eagerly(): self._test_minimize_loss_eager(distribution) else: self._test_minimize_loss_graph(distribution) - @combinations.generate(one_device_combinations()) def testReplicaId(self, distribution): self._test_replica_id(distribution) - @combinations.generate(one_device_combinations()) def testCallAndMergeExceptions(self, distribution): self._test_call_and_merge_exceptions(distribution) + def testAllReduceSum(self, distribution): + self._test_all_reduce_sum(distribution) + + def testAllReduceSumGradients(self, distribution): + self._test_all_reduce_sum_gradients(distribution) + + def testAllReduceSumGradientTape(self, distribution): + self._test_all_reduce_sum_gradient_tape(distribution) + + def testAllReduceMean(self, distribution): + self._test_all_reduce_mean(distribution) + + def testAllReduceMeanGradients(self, distribution): + self._test_all_reduce_mean_gradients(distribution) + + def testAllReduceMeanGradientTape(self, distribution): + self._test_all_reduce_mean_gradient_tape(distribution) + class MirroredStrategyVariableCreatorStackTest( test.TestCase, parameterized.TestCase): @@ -179,9 +237,37 @@ class MirroredStrategyVariableCreatorStackTest( variable_scope.variable_creator_scope(main_thread_creator): result = distribution.extended.call_for_each_replica(model_fn) result = distribution.unwrap(result) - expected = ["main_thread:thread_0", "main_thread:thread_1"] + expected = ("main_thread:thread_0", "main_thread:thread_1") self.assertEqual(expected, result) +@combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=["graph", "eager"])) +class MirroredStrategyCallForEachReplicaTest(test.TestCase): + + def testExecutingEagerlyOutsideFunction(self, distribution): + """Verify we preserve the value of 
executing_eagerly_outside_functions().""" + def model_fn(): + return ops.executing_eagerly_outside_functions() + + originally = ops.executing_eagerly_outside_functions() + with distribution.scope(): + in_scope = ops.executing_eagerly_outside_functions() + in_model_fn = distribution.extended.call_for_each_replica(model_fn) + unwrapped = distribution.unwrap(in_model_fn) + self.assertEqual(in_scope, unwrapped[0]) + self.assertEqual(in_scope, originally) + + # Verify this all again, but this time in a FuncGraph. + with func_graph.FuncGraph("fg").as_default(), distribution.scope(): + in_scope = ops.executing_eagerly_outside_functions() + in_model_fn = distribution.extended.call_for_each_replica(model_fn) + unwrapped = distribution.unwrap(in_model_fn) + self.assertEqual(in_scope, unwrapped[0]) + self.assertEqual(in_scope, originally) + @combinations.generate(combinations.combine( distribution=[ @@ -190,6 +276,29 @@ class MirroredStrategyVariableCreatorStackTest( mode=["graph", "eager"])) class MirroredStrategyVariableCreationTest(test.TestCase): + # TODO(priyag): Modify more tests to use this helper and check more + # properties. 
+ def _test_mv_properties(self, var, name, strategy): + self.assertIsInstance(var, values.MirroredVariable) + self.assertEqual(name, var.name) + self.assertIs(strategy, var.distribute_strategy) + for d in var.devices: + self.assertEqual(d, var.get(d).device) + self.assertIs(strategy, var.get(d)._distribute_strategy) # pylint: disable=protected-access + + def testVariableInFuncGraph(self, distribution): + def model_fn(): + v = variable_scope.variable(2.0, name="bar") + ds_context.get_replica_context().merge_call(lambda _: _) + return v + + with func_graph.FuncGraph("fg").as_default(), distribution.scope(): + v1 = variable_scope.variable(1.0, name="foo") + v2 = distribution.extended.call_for_each_replica(model_fn) + + self._test_mv_properties(v1, "foo:0", distribution) + self._test_mv_properties(v2, "bar:0", distribution) + def testSingleVariable(self, distribution): def model_fn(): # This variable should be created only once across the threads because of @@ -201,8 +310,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): with distribution.scope(): result = distribution.extended.call_for_each_replica(model_fn) - self.assertIsInstance(result, values.MirroredVariable) - self.assertEqual("foo:0", result.name) + self._test_mv_properties(result, "foo:0", distribution) def testUnnamedVariable(self, distribution): def model_fn(): @@ -212,9 +320,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): with distribution.scope(): result = distribution.extended.call_for_each_replica(model_fn) - self.assertIsInstance(result, values.MirroredVariable) - # Default name of "Variable" will be used. 
- self.assertEqual("Variable:0", result.name) + self._test_mv_properties(result, "Variable:0", distribution) def testMultipleVariables(self, distribution): def model_fn(): @@ -227,8 +333,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): with distribution.scope(): result = distribution.extended.call_for_each_replica(model_fn) for i, v in enumerate(result): - self.assertIsInstance(v, values.MirroredVariable) - self.assertEqual("foo" + str(i) + ":0", v.name) + self._test_mv_properties(v, "foo" + str(i) + ":0", distribution) def testMultipleVariablesWithSameCanonicalName(self, distribution): def model_fn(): @@ -278,14 +383,9 @@ class MirroredStrategyVariableCreationTest(test.TestCase): (layer2.kernel, layer2.bias), (layer3.kernel, layer3.bias)] - ds = distribution.distribute_dataset( - lambda: dataset_ops.Dataset.from_tensors([[1.]]).repeat(10)) - if context.executing_eagerly(): - iterator = ds.make_one_shot_iterator() - else: - iterator = ds.make_initializable_iterator() - self.evaluate([iterator.initializer]) - + iterator = distribution.make_input_fn_iterator( + lambda _: dataset_ops.Dataset.from_tensors([[1.]]).repeat(10)) + self.evaluate(iterator.initialize()) features = iterator.get_next() with distribution.scope(): @@ -512,10 +612,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase): return v with distribution.scope(): - names = values.DistributedValues({ - "/device:CPU:0": "foo", - "/device:GPU:0": "bar" - }) + device_map = values.ReplicaDeviceMap(("/device:CPU:0", "/device:GPU:0")) + names = values.DistributedValues(device_map, ("foo", "bar")) with self.assertRaises(RuntimeError): _ = distribution.extended.call_for_each_replica(model_fn, args=(names,)) @@ -649,6 +747,15 @@ class MirroredStrategyVariableCreationTest(test.TestCase): distribution.extended.worker_devices[0]).read_value())) self.assertEqual(10.0, self.evaluate(ret_v_sum)) + def testVarDistributeStrategy(self, distribution): + with distribution.scope(): + mirrored = 
variable_scope.variable(1.0) + replica_local = variable_scope.variable( + 1.0, + synchronization=variable_scope.VariableSynchronization.ON_READ) + self.assertIs(distribution, mirrored.distribute_strategy) + self.assertIs(distribution, replica_local.distribute_strategy) + @combinations.generate(combinations.combine( distribution=[ @@ -757,21 +864,23 @@ class MirroredStrategyNameScopeTest(test.TestCase): self.assertEqual("c/replica_1:0", c1.name) -@combinations.generate(combinations.combine( - distribution=[ - combinations.NamedDistribution( - "Mirrored3Devices", - # pylint: disable=g-long-lambda - lambda: mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:GPU:1", "/device:CPU:0"]), - required_gpus=2), - combinations.NamedDistribution( - "CoreMirrored3Devices", - # pylint: disable=g-long-lambda - lambda: mirrored_strategy.CoreMirroredStrategy( - ["/device:GPU:0", "/device:GPU:1", "/device:CPU:0"]), - required_gpus=2)], - mode=["graph", "eager"])) +@combinations.generate( + combinations.combine( + distribution=[ + combinations.NamedDistribution( + "Mirrored3Devices", + # pylint: disable=g-long-lambda + lambda: mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:GPU:1", "/device:CPU:0"]), + required_gpus=2), + combinations.NamedDistribution( + "CoreMirrored3Devices", + # pylint: disable=g-long-lambda + lambda: mirrored_strategy.CoreMirroredStrategy( + ["/device:GPU:0", "/device:GPU:1", "/device:CPU:0"]), + required_gpus=2) + ], + mode=["graph", "eager"])) class MirroredThreeDeviceDistributionTest( strategy_test_lib.DistributionTestBase, parameterized.TestCase): @@ -1075,7 +1184,7 @@ class ReplicaLocalVariableAssignTest(test.TestCase): # When we read the value using `read_var` we should see the SUM of each of # values on each of the replicas. 
self.assertEqual(2.0, self.evaluate( - distribution.read_var(replica_local_var))) + distribution.extended.read_var(replica_local_var))) # Assigning 6.0 in cross replica context will assign a value of # 6.0/num_replicas to each replica. tlv_ops = replica_local_var.assign(6.0) @@ -1084,7 +1193,7 @@ class ReplicaLocalVariableAssignTest(test.TestCase): # The value on all the replicas are added before being returned by # `read_var`. self.assertEqual(6.0, self.evaluate( - distribution.read_var(replica_local_var))) + distribution.extended.read_var(replica_local_var))) def testAssignReplicaLocalVarMeanAggregation(self, distribution): def model_fn(): @@ -1103,13 +1212,13 @@ class ReplicaLocalVariableAssignTest(test.TestCase): # When we read the value using `read_var` we should see the MEAN of values # on all replicas which is the value assigned in replica context. self.assertEqual(1.0, self.evaluate( - distribution.read_var(replica_local_var))) + distribution.extended.read_var(replica_local_var))) tlv_ops = replica_local_var.assign(6.0) self.evaluate(tlv_ops) # On reading the replica local var we should get the MEAN of all values # which is equal to the value assigned. 
self.assertEqual(6.0, self.evaluate( - distribution.read_var(replica_local_var))) + distribution.extended.read_var(replica_local_var))) class MockModel(object): @@ -1162,14 +1271,14 @@ class MirroredStrategyDefunTest(test.TestCase): result = distribution.extended.call_for_each_replica( model_fn, args=[mock_model] + inputs) - for device in devices: - device_result = values.select_device(device, result) - device_expected_result = values.select_device(device, expected_result) + for r in range(len(devices)): + device_result = values.select_replica(r, result) + device_expected_result = values.select_replica(r, expected_result) self.assertAllClose(device_expected_result, self.evaluate(device_result)) for defun in defuns: - # PolymorphicFunctions are specialized to the current device stack, so + # `Function`s are specialized to the current device stack, so # call_for_each has one trace per device. To check that the expected set # of variables was accessed on each trace, we first retrieve each # device-specific graph function. 
@@ -1245,9 +1354,9 @@ class MirroredStrategyDefunTest(test.TestCase): def fn1(mock_model, factor): return mock_model(factor) - factors = values.PerReplica({"CPU:0": 5.0, "GPU:0": 3.0}) - expected_result = values.PerReplica({"CPU:0": 5.0 * 1.25, - "GPU:0": 3.0 * 1.25}) + device_map = values.ReplicaDeviceMap(("/device:CPU:0", "/device:GPU:0")) + factors = values.PerReplica(device_map, (5.0, 3.0)) + expected_result = values.PerReplica(device_map, (5.0 * 1.25, 3.0 * 1.25)) self._call_and_check(distribution, fn1, [factors], expected_result, [fn1]) def testTrain(self, distribution): @@ -1283,14 +1392,14 @@ class MirroredStrategyDefunTest(test.TestCase): combinations.NamedDistribution( "Mirrored", # pylint: disable=g-long-lambda - lambda: mirrored_strategy.CoreMirroredStrategy( - num_gpus_per_worker=context.num_gpus()), + lambda: mirrored_strategy.MirroredStrategy(num_gpus_per_worker= + context.num_gpus()), required_gpus=1), combinations.NamedDistribution( "CoreMirrored", # pylint: disable=g-long-lambda lambda: mirrored_strategy.CoreMirroredStrategy( - num_gpus_per_worker=context.num_gpus()), + mirrored_strategy.all_local_devices()), required_gpus=1) ], mode=["graph"])) @@ -1324,7 +1433,7 @@ class MultiWorkerMirroredStrategyTest( self.assertEqual(a.device, "/job:worker/task:0") self.assertEqual(b.device, "/job:worker/task:0/device:CPU:0") - def testMakeInputFnIterator(self, distribution): + def testMakeInputFnIteratorWithDataset(self, distribution): self._configure_distribution_strategy(distribution) dataset_fn = lambda: dataset_ops.Dataset.range(100) num_gpus = context.num_gpus() @@ -1345,6 +1454,32 @@ class MultiWorkerMirroredStrategyTest( self._test_input_fn_iterator( iterator, distribution.extended.worker_devices, expected_values, sess) + def testMakeInputFnIteratorWithCallable(self, distribution): + self._configure_distribution_strategy(distribution) + def fn(): + dataset = dataset_ops.Dataset.range(100) + it = dataset.make_one_shot_iterator() + return it.get_next + 
num_gpus = context.num_gpus() + num_workers = 2 + + expected_values = [] + for i in range(0, 100, num_gpus): + expected_values.append([i+j for j in range(num_gpus)] * num_workers) + + with context.graph_mode(), self.cached_session() as sess: + # `expected_input_pipeline_id` is None because the input_fn will be called + # multiple times, each with a different input_pipeline_id. + input_fn = self._input_fn_to_test_input_context( + fn, + expected_num_replicas_in_sync=num_workers*num_gpus, + expected_num_input_pipelines=num_workers, + expected_input_pipeline_id=None) + iterator = distribution.make_input_fn_iterator(input_fn) + self._test_input_fn_iterator( + iterator, distribution.extended.worker_devices, expected_values, sess, + test_reinitialize=False) + def testUpdateConfigProto(self, distribution): distribution.configure(cluster_spec={"worker": ["fake1", "fake2"]}) @@ -1374,7 +1509,7 @@ class MultiWorkerMirroredStrategyTestWithChief( def testMinimizeLossGraphCoreMirroredStrategy(self): strategy = mirrored_strategy.CoreMirroredStrategy( - num_gpus_per_worker=context.num_gpus()) + mirrored_strategy.all_local_devices()) strategy.configure(cluster_spec=self._cluster_spec) self._test_minimize_loss_graph(strategy, learning_rate=0.05) diff --git a/tensorflow/contrib/distribute/python/monitor.py b/tensorflow/contrib/distribute/python/monitor.py index 17b7ab74f63f42e1ee14a82d3bffdd1df9b25857..53e35ea6b75088a3de9866973f872da4a4ce25d6 100644 --- a/tensorflow/contrib/distribute/python/monitor.py +++ b/tensorflow/contrib/distribute/python/monitor.py @@ -51,7 +51,7 @@ class Monitor(object): else: if session is None: raise ValueError("Should provide a `session` in Graph mode.") - session.run(step_callable._iterator.initializer) # pylint: disable=protected-access + session.run(step_callable.initialize()) self._run_step = session.make_callable(step_callable()) session.run(variables.global_variables_initializer()) diff --git a/tensorflow/contrib/distribute/python/monitor_test.py 
b/tensorflow/contrib/distribute/python/monitor_test.py index 16be839e1d155003b9490fbe3da6ab85b7d2d78a..c0651610cafc06a6d5f4206f4e64d27020fae30b 100644 --- a/tensorflow/contrib/distribute/python/monitor_test.py +++ b/tensorflow/contrib/distribute/python/monitor_test.py @@ -23,9 +23,9 @@ import numpy from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python import monitor as monitor_lib -from tensorflow.contrib.distribute.python import one_device_strategy from tensorflow.contrib.distribute.python.single_loss_example import single_loss_example from tensorflow.python.client import session +from tensorflow.python.distribute import one_device_strategy from tensorflow.python.eager import context from tensorflow.python.eager import test from tensorflow.python.framework import ops diff --git a/tensorflow/contrib/distribute/python/moving_averages_test.py b/tensorflow/contrib/distribute/python/moving_averages_test.py index c492d8bafc9024ed059f05b92e5466f3702726b9..c4622cdd2af2f6a9c936fe554bcc2eb76f805fdc 100644 --- a/tensorflow/contrib/distribute/python/moving_averages_test.py +++ b/tensorflow/contrib/distribute/python/moving_averages_test.py @@ -53,7 +53,7 @@ class AssignMovingAveragesTest(test.TestCase, parameterized.TestCase): return var, assign with distribution.scope(), self.cached_session() as sess: - var, assign = distribution.call_for_each_replica(replica_fn) + var, assign = distribution.extended.call_for_each_replica(replica_fn) variables.global_variables_initializer().run() self.assertAllClose([10.0, 11.0], var.eval()) sess.run(distribution.unwrap(assign)) @@ -79,7 +79,7 @@ class AssignMovingAveragesTest(test.TestCase, parameterized.TestCase): return var, assign.op with distribution.scope(), self.cached_session() as sess: - var, assign_op = distribution.call_for_each_replica(replica_fn) + var, assign_op = distribution.extended.call_for_each_replica(replica_fn) variables.global_variables_initializer().run() 
self.assertAllClose([0.0, 0.0], var.eval()) sess.run(distribution.unwrap(assign_op)) @@ -139,6 +139,27 @@ class AssignMovingAveragesTest(test.TestCase, parameterized.TestCase): (2.0 * 0.25 + 0.0) / (1.0 * 0.25 + 1.0)], var.eval()) + @combinations.generate(all_combinations) + def testAssignVariable(self, distribution): + + def replica_fn(): + var = variables.Variable([10.0, 11.0]) + # Here we expect to check the case when input value are variable. + val = variables.Variable([1., 2.]) + decay = 0.25 + assign = moving_averages.assign_moving_average( + var, val, decay, zero_debias=False) + return var, assign + + with distribution.scope(), self.cached_session() as sess: + var, assign = distribution.extended.call_for_each_replica(replica_fn) + variables.global_variables_initializer().run() + self.assertAllClose([10.0, 11.0], var.eval()) + sess.run(distribution.unwrap(assign)) + self.assertAllClose( + [10 * 0.25 + 1. * (1 - 0.25), 11 * 0.25 + 2. * (1 - 0.25)], + var.eval()) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distribute/python/multi_worker_test_base.py b/tensorflow/contrib/distribute/python/multi_worker_test_base.py index 147c9b83f866fd364ea23cf7988692a7b5f61b9c..b05aac431f65b4281d9ed9c2fa95c210d55f4008 100644 --- a/tensorflow/contrib/distribute/python/multi_worker_test_base.py +++ b/tensorflow/contrib/distribute/python/multi_worker_test_base.py @@ -40,6 +40,7 @@ from tensorflow.python.client import session from tensorflow.python.estimator import run_config from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import coordinator from tensorflow.python.training import server_lib ASSIGNED_PORTS = set() @@ -360,6 +361,7 @@ class IndependentWorkerTestBase(test.TestCase): self._mock_os_env = MockOsEnv() self._mock_context = test.mock.patch.object(os, 'environ', self._mock_os_env) + self._coord = coordinator.Coordinator() super(IndependentWorkerTestBase, 
self).setUp() self._mock_context.__enter__() @@ -368,8 +370,9 @@ class IndependentWorkerTestBase(test.TestCase): super(IndependentWorkerTestBase, self).tearDown() def _task_thread(self, task_fn, tf_config, *args, **kwargs): - os.environ['TF_CONFIG'] = json.dumps(tf_config) - task_fn(*args, **kwargs) + with self._coord.stop_on_exception(): + os.environ['TF_CONFIG'] = json.dumps(tf_config) + task_fn(*args, **kwargs) def _run_task_in_thread(self, task_fn, cluster_spec, task_type, task_id, *args, **kwargs): @@ -403,3 +406,6 @@ class IndependentWorkerTestBase(test.TestCase): *args, **kwargs) threads[task_type].append(t) return threads + + def join_independent_workers(self, worker_threads): + self._coord.join(worker_threads) diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py index e322b6acb84c166a885c9aaa3002f331903a5063..13a501394ee1fec2dfc1427f6d16d3a4624d7747 100644 --- a/tensorflow/contrib/distribute/python/one_device_strategy.py +++ b/tensorflow/contrib/distribute/python/one_device_strategy.py @@ -18,202 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import six +from tensorflow.python.distribute import one_device_strategy -from tensorflow.python.distribute import device_util -from tensorflow.python.distribute import distribute_lib -from tensorflow.python.distribute import values -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.util import nest - - -# TODO(josh11b): Replace asserts in this file with if ...: raise ... 
- - -class OneDeviceStrategy(distribute_lib.DistributionStrategy): - """A distribution strategy for running on a single device.""" - # TODO(josh11b): Do we wrap values in types to generate errors if you are - # doing something that won't work with other DistributionStrategy - # implementations? - - def __init__(self, device): - super(OneDeviceStrategy, self).__init__(OneDeviceExtended(self, device)) - - -class OneDeviceExtended(distribute_lib.DistributionStrategyExtended): - """Implementation of OneDeviceStrategy.""" - - def __init__(self, container_strategy, device): - super(OneDeviceExtended, self).__init__(container_strategy) - self._device = device - self._default_device = device - - def _create_variable(self, next_creator, *args, **kwargs): - colocate_with = kwargs.pop("colocate_with", None) - if colocate_with is None: - with ops.device(self._device): - return next_creator(*args, **kwargs) - if isinstance(colocate_with, six.string_types): - with ops.device(colocate_with): - return next_creator(*args, **kwargs) - if (isinstance(colocate_with, list) and len(colocate_with) == 1 and - isinstance(colocate_with[0], six.string_types)): - with ops.device(colocate_with[0]): - return next_creator(*args, **kwargs) - with ops.colocate_with(colocate_with): - return next_creator(*args, **kwargs) - - def _make_dataset_iterator(self, dataset): - """Make iterator from dataset without splitting the batch.""" - worker = device_util.canonicalize("/device:CPU:0") - worker_device_pairs = [(worker, [self._device])] - return values.DatasetIterator(dataset, worker_device_pairs) - - def _distribute_dataset(self, dataset_fn): - return values.PerReplicaDataset( - self._call_dataset_fn(dataset_fn), [self._device]) - - def _make_input_fn_iterator( - self, - input_fn, - replication_mode=distribute_lib.InputReplicationMode.PER_WORKER): - worker = device_util.canonicalize("/device:CPU:0") - worker_device_pairs = [(worker, [self._device])] - return values.InputFunctionIterator( - input_fn, 
worker_device_pairs, - [distribute_lib.InputContext()]) - - def _broadcast_to(self, tensor, destinations): - del destinations - return tensor - - # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed. - def _experimental_run_steps_on_iterator(self, fn, iterator, iterations, - initial_loop_values=None): - if initial_loop_values is None: - initial_loop_values = {} - initial_loop_values = nest.flatten(initial_loop_values) - - ctx = values.MultiStepContext() - def body(i, *args): - """A wrapper around `fn` to create the while loop body.""" - del args - fn_inputs = iterator.get_next() - if not isinstance(fn_inputs, tuple): - fn_inputs = (fn_inputs,) - fn_result = fn(ctx, fn_inputs) - flat_last_step_outputs = nest.flatten(ctx.last_step_outputs) - with ops.control_dependencies([fn_result]): - return [i + 1] + flat_last_step_outputs - - # We capture the control_flow_context at this point, before we run `fn` - # inside a while_loop. This is useful in cases where we might need to exit - # these contexts and get back to the outer context to do some things, for - # e.g. create an op which should be evaluated only once at the end of the - # loop on the host. One such usage is in creating metrics' value op. - self._outer_control_flow_context = ( - ops.get_default_graph()._get_control_flow_context()) # pylint: disable=protected-access - - # TODO(priyag): Use max_iterations instead of an explicit counter. - cond = lambda i, *args: i < iterations - i = constant_op.constant(0) - loop_result = control_flow_ops.while_loop( - cond, body, [i] + initial_loop_values, name="", - parallel_iterations=1, back_prop=False, swap_memory=False, - return_same_structure=True) - del self._outer_control_flow_context - - ctx.run_op = control_flow_ops.group(loop_result) - - # Convert the last_step_outputs from a list to the original dict structure - # of last_step_outputs. 
- last_step_tensor_outputs = loop_result[1:] - last_step_tensor_outputs_dict = nest.pack_sequence_as( - ctx.last_step_outputs, last_step_tensor_outputs) - - ctx._set_last_step_outputs(last_step_tensor_outputs_dict) # pylint: disable=protected-access - return ctx - - def _call_for_each_replica(self, fn, args, kwargs): - strategy = self._container_strategy() - with ops.device(self._device), _OneDeviceReplicaContext(strategy): - return fn(*args, **kwargs) - - def _reduce_to(self, reduce_op, value, destinations): - del reduce_op, destinations - return value - - def _update(self, var, fn, args, kwargs, group): - # The implementations of _update() and _update_non_slot() are identical - # except _update() passes `var` as the first argument to `fn()`. - return self._update_non_slot(var, fn, (var,) + tuple(args), kwargs, group) - - def _update_non_slot(self, colocate_with, fn, args, kwargs, group): - del colocate_with - with ops.device(self._device), distribute_lib.UpdateContext(self._device): - result = fn(*args, **kwargs) - if group: - return result - else: - return nest.map_structure(self._unwrap, result) - - def read_var(self, replica_local_var): - """Read the aggregate value of a replica-local variable.""" - return array_ops.identity(replica_local_var) - - def _unwrap(self, value): - return [value] - - def value_container(self, value): - return value - - @property - def _num_replicas_in_sync(self): - return 1 - - @property - def worker_devices(self): - return [self._device] - - @property - def parameter_devices(self): - return [self._device] - - def non_slot_devices(self, var_list): - del var_list - return [self._device] - - @property - def experimental_should_init(self): - return True - - @property - def should_checkpoint(self): - return True - - @property - def should_save_summary(self): - return True - - # TODO(priyag): Delete this once all strategies use global batch size. 
- @property - def _global_batch_size(self): - return True - - -class _OneDeviceReplicaContext(distribute_lib.ReplicaContext): - """ReplicaContext for OneDeviceStrategy.""" - - def __init__(self, distribution_strategy): - distribute_lib.ReplicaContext.__init__( - self, - distribution_strategy, - replica_id_in_sync_group=constant_op.constant(0, dtypes.int32)) - - @property - def devices(self): - return [self._distribution_strategy.extended.worker_devices[0]] +OneDeviceStrategy = one_device_strategy.OneDeviceStrategy diff --git a/tensorflow/contrib/distribute/python/one_device_strategy_test.py b/tensorflow/contrib/distribute/python/one_device_strategy_test.py index d46cd6f529e363f76bfa2b22339add63530cfde8..906bffc8525688f63474c3f1fbc5d7f0a024431b 100644 --- a/tensorflow/contrib/distribute/python/one_device_strategy_test.py +++ b/tensorflow/contrib/distribute/python/one_device_strategy_test.py @@ -18,14 +18,16 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.distribute.python import one_device_strategy from tensorflow.contrib.distribute.python import strategy_test_lib from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import one_device_strategy from tensorflow.python.eager import test from tensorflow.python.framework import test_util -class OneDeviceStrategyTest(strategy_test_lib.DistributionTestBase): +class OneDeviceStrategyTest( + strategy_test_lib.DistributionTestBase, + strategy_test_lib.OneDeviceDistributionTestBase): def _get_distribution_strategy(self): return one_device_strategy.OneDeviceStrategy("/device:CPU:0") @@ -44,7 +46,7 @@ class OneDeviceStrategyTest(strategy_test_lib.DistributionTestBase): self._test_call_and_merge_exceptions(self._get_distribution_strategy()) @test_util.run_in_graph_and_eager_modes - def testMakeInputFnIterator(self): + def testMakeInputFnIteratorWithDataset(self): d = 
one_device_strategy.OneDeviceStrategy("/device:CPU:0") dataset_fn = lambda: dataset_ops.Dataset.range(10) expected_values = [[i] for i in range(10)] @@ -57,6 +59,46 @@ class OneDeviceStrategyTest(strategy_test_lib.DistributionTestBase): self._test_input_fn_iterator( iterator, d.extended.worker_devices, expected_values) + @test_util.run_in_graph_and_eager_modes + def testMakeInputFnIteratorWithCallable(self): + d = one_device_strategy.OneDeviceStrategy("/device:CPU:0") + def fn(): + dataset = dataset_ops.Dataset.range(10) + it = dataset.make_one_shot_iterator() + return it.get_next + expected_values = [[i] for i in range(10)] + input_fn = self._input_fn_to_test_input_context( + fn, + expected_num_replicas_in_sync=1, + expected_num_input_pipelines=1, + expected_input_pipeline_id=0) + iterator = d.make_input_fn_iterator(input_fn) + self._test_input_fn_iterator( + iterator, d.extended.worker_devices, expected_values, + test_reinitialize=False) + + @test_util.run_in_graph_and_eager_modes + def testNumpyIterator(self): + self._test_numpy_iterator(self._get_distribution_strategy()) + + def testAllReduceSum(self): + self._test_all_reduce_sum(self._get_distribution_strategy()) + + def testAllReduceSumGradients(self): + self._test_all_reduce_sum_gradients(self._get_distribution_strategy()) + + def testAllReduceSumGradientTape(self): + self._test_all_reduce_sum_gradient_tape(self._get_distribution_strategy()) + + def testAllReduceMean(self): + self._test_all_reduce_mean(self._get_distribution_strategy()) + + def testAllReduceMeanGradients(self): + self._test_all_reduce_mean_gradients(self._get_distribution_strategy()) + + def testAllReduceMeanGradientTape(self): + self._test_all_reduce_mean_gradient_tape(self._get_distribution_strategy()) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distribute/python/optimizer_v2_test.py b/tensorflow/contrib/distribute/python/optimizer_v2_test.py index 
fa4705af7cb592119f56686d1f693a156f7b4b13..e388061b17a9b92dedbbf9839049b13c8575a22c 100644 --- a/tensorflow/contrib/distribute/python/optimizer_v2_test.py +++ b/tensorflow/contrib/distribute/python/optimizer_v2_test.py @@ -41,21 +41,17 @@ class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase): with distribution.scope(): model_fn, dataset_fn, layer = minimize_loss_example( optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss) - - ds = distribution.distribute_dataset(dataset_fn) - if context.executing_eagerly(): - iterator = ds.make_one_shot_iterator() - else: - iterator = ds.make_initializable_iterator() + iterator = distribution.make_input_fn_iterator(lambda _: dataset_fn()) def run_step(): - return control_flow_ops.group(distribution.unwrap( - distribution.call_for_each_replica( - model_fn, args=(iterator.get_next(),)))) + return control_flow_ops.group( + distribution.unwrap( + distribution.extended.call_for_each_replica( + model_fn, args=(iterator.get_next(),)))) if not context.executing_eagerly(): with self.cached_session() as sess: - sess.run(iterator.initializer) + sess.run(iterator.initialize()) run_step = sess.make_callable(run_step()) self.evaluate(variables.global_variables_initializer()) diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy.py b/tensorflow/contrib/distribute/python/parameter_server_strategy.py index eaeb4d703015fc0762359b24dc23888c01e69111..e42bc50fdc4e5e93c998708b0790fdea7768faf2 100644 --- a/tensorflow/contrib/distribute/python/parameter_server_strategy.py +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy.py @@ -18,34 +18,24 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import copy - -from tensorflow.contrib.distribute.python import mirrored_strategy -from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib -from tensorflow.python.distribute import device_util from 
tensorflow.python.distribute import distribute_lib -from tensorflow.python.distribute import multi_worker_util -from tensorflow.python.distribute import values -from tensorflow.python.eager import context -from tensorflow.python.framework import device as tf_device -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import resource_variable_ops -from tensorflow.python.ops import variable_scope as vs -from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.training import device_setter -from tensorflow.python.util import nest +from tensorflow.python.distribute import input_lib +from tensorflow.python.distribute import parameter_server_strategy +from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver +from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver + +# pylint: disable=protected-access,invalid-name,line-too-long +CoreParameterServerStrategy = parameter_server_strategy.ParameterServerStrategy +CoreParameterServerExtended = parameter_server_strategy.ParameterServerStrategyExtended -_LOCAL_CPU = "/device:CPU:0" -_LOCAL_GPU_0 = "/device:GPU:0" +# pylint: enable=protected-access,invalid-name,line-too-long -# TODO(yuefengz): maybe cache variables on local CPU. -# TODO(yuefengz): we may want to set session options to disallow communication -# between workers. class ParameterServerStrategy(distribute_lib.DistributionStrategy): """A parameter server DistributionStrategy. + *** contrib version *** + This strategy class works for both local training and between-graph replicated training for multiple workers. If `cluster_spec` is specified, either passed in to __init__() method or parsed from the @@ -80,9 +70,9 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): variables. 3) It is also not recommended to open a colocation scope (i.e. calling - `tf.colocate_with`) under the strategy's scope. 
For colocating variables, - use `distribution.colocate_vars_with` instead. Colocation of ops will possibly - create conflicts of device assignment. + `tf.colocate_with`) under the strategy's scope. For colocating variables, use + `strategy.extended.colocate_vars_with` instead. Colocation of ops will + possibly create conflicts of device assignment. """ def __init__(self, num_gpus_per_worker=0): @@ -99,432 +89,84 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): super(ParameterServerStrategy, self).__init__( ParameterServerExtended(self, num_gpus_per_worker)) + # Override to change the documentation to reflect the different handling of + # global vs. local batch size between core and contrib. + def make_dataset_iterator(self, dataset): # pylint: disable=useless-super-delegation + """Makes an iterator for input provided via `dataset`. -class ParameterServerExtended(distribute_lib.DistributionStrategyExtended): - """Implementation of ParameterServerStrategy.""" + NOTE: The batch size of the `dataset` argument is treated differently for + this contrib version of `ParameterServerStrategy`. - def __init__(self, container_strategy, num_gpus_per_worker): - super(ParameterServerExtended, self).__init__(container_strategy) - self._num_gpus_per_worker = num_gpus_per_worker - self._initialize_local(num_gpus_per_worker) + Data from the given dataset will be distributed evenly across all the + compute replicas. We will assume that the input dataset is batched by the + per-replica batch size. - # We typically don't need to do all-reduce in this strategy. - self._cross_device_ops = ( - cross_device_ops_lib.ReductionToOneDeviceCrossDeviceOps( - reduce_to_device=_LOCAL_CPU)) - - def _initialize_multi_worker(self, num_gpus_per_worker, cluster_spec, - task_type, task_id): - """Initialize devices for multiple workers. - - It creates variable devices and compute devices. Variables and operations - will be assigned to them respectively. 
We have one compute device per - replica. The variable device is a device function or device string. The - default variable device assigns variables to parameter servers in a - round-robin fashion. + The user could also use `make_input_fn_iterator` if they want to + customize which input is fed to which replica/worker etc. Args: - num_gpus_per_worker: number of local GPUs or GPUs per worker. - cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the - cluster configurations. - task_type: the current task type. - task_id: the current task id. + dataset: `tf.data.Dataset` that will be distributed evenly across all + replicas. - Raises: - ValueError: if the cluster_spec doesn't have ps jobs. + Returns: + An `tf.distribute.InputIterator` which returns inputs for each step of the + computation. User should call `initialize` on the returned iterator. """ - assert cluster_spec - if not task_type or task_id is None: - raise ValueError("When `cluster_spec` is given, you must also specify " - "`task_type` and `task_id`") - cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec) - - self._worker_device = "/job:%s/task:%d" % (self._task_type, self._task_id) - - # Define compute devices which is a list of device strings and one for each - # replica. When there are GPUs, replicate operations on these GPUs. - # Otherwise, place operations on CPU. - if num_gpus_per_worker > 0: - self._compute_devices = [ - "%s/device:GPU:%d" % (self._worker_device, i) - for i in range(num_gpus_per_worker) - ] - else: - self._compute_devices = [self._worker_device] - - self._compute_devices = list( - map(device_util.resolve, self._compute_devices)) - self._canonical_compute_device_set = set(self._compute_devices) - - # In distributed mode, place variables on ps jobs in a round-robin fashion. - # Note that devices returned from `replica_device_setter` are not - # canonical and therefore we don't canonicalize all variable devices to - # make them consistent. 
- # TODO(yuefengz): support passing a strategy object to control variable - # assignment. - # TODO(yuefengz): merge the logic of replica_device_setter into this - # class. - num_ps_replicas = len(cluster_spec.as_dict().get("ps", [])) - if num_ps_replicas == 0: - raise ValueError("The cluster spec needs to have `ps` jobs.") - self._variable_device = device_setter.replica_device_setter( - ps_tasks=num_ps_replicas, - worker_device=self._worker_device, - merge_devices=True, - cluster=cluster_spec) - - # The `_parameter_devices` is needed for the `parameter_devices` property - # and is a list of all variable devices. Here parameter devices are all - # tasks of the "ps" job. - self._parameter_devices = map("/job:ps/task:{}".format, - range(num_ps_replicas)) - - # Add a default device so that ops without specified devices will not end up - # on other workers. - self._default_device = self._worker_device - - self._is_chief = multi_worker_util.is_chief(cluster_spec, task_type, - task_id) - self._cluster_spec = cluster_spec - self._task_type = task_type - self._task_id = task_id - - logging.info( - "Multi-worker ParameterServerStrategy with " - "cluster_spec = %r, task_type = %r, task_id = %r, " - "num_ps_replicas = %r, is_chief = %r, compute_devices = %r, " - "variable_device = %r", cluster_spec.as_dict(), task_type, task_id, - num_ps_replicas, self._is_chief, self._compute_devices, - self._variable_device) - - def _initialize_local(self, num_gpus_per_worker): - """Initialize internal devices for local training.""" - self._worker_device = device_util.canonicalize("/device:CPU:0") - # Define compute devices which is a list of device strings and one for each - # replica. When there are GPUs, replicate operations on these GPUs. - # Otherwise, place operations on CPU. 
- if num_gpus_per_worker > 0: - self._compute_devices = list( - map("/device:GPU:{}".format, range(num_gpus_per_worker))) - else: - self._compute_devices = [_LOCAL_CPU] - - self._compute_devices = list( - map(device_util.resolve, self._compute_devices)) - self._canonical_compute_device_set = set(self._compute_devices) - - # If there is only one GPU, put everything on that GPU. Otherwise, place - # variables on CPU. - if num_gpus_per_worker == 1: - assert len(list(self._compute_devices)) == 1 - self._variable_device = _LOCAL_GPU_0 - self._parameter_devices = [_LOCAL_GPU_0] - else: - self._variable_device = _LOCAL_CPU - self._parameter_devices = [_LOCAL_CPU] - - self._is_chief = True - self._cluster_spec = None - self._task_type = None - self._task_id = None - - logging.info( - "ParameterServerStrategy with compute_devices = %r, " - "variable_device = %r", self._compute_devices, self._variable_device) - - def _distribute_dataset(self, dataset_fn): - """Distributes the dataset to each local GPU.""" - return values.PerReplicaDataset( - self._call_dataset_fn(dataset_fn), self._compute_devices, True) - - def _make_dataset_iterator(self, dataset): - worker_device_pairs = [(self._worker_device, self._compute_devices)] - return values.DatasetIterator(dataset, worker_device_pairs, - self._num_replicas_in_sync) - - def _make_input_fn_iterator( - self, - input_fn, - replication_mode=distribute_lib.InputReplicationMode.PER_WORKER): - """Distributes the dataset to each local GPU.""" - if self._cluster_spec: - input_pipeline_id = multi_worker_util.id_in_cluster( - self._cluster_spec, self._task_type, self._task_id) - num_input_pipelines = multi_worker_util.worker_count( - self._cluster_spec, self._task_type) - else: - input_pipeline_id = 0 - num_input_pipelines = 1 - input_context = distribute_lib.InputContext( - num_input_pipelines=num_input_pipelines, - input_pipeline_id=input_pipeline_id, - num_replicas_in_sync=self._num_replicas_in_sync) - worker_device_pairs = 
[(self._worker_device, self._compute_devices)] - return values.InputFunctionIterator( - input_fn, worker_device_pairs, [input_context]) - - def _broadcast_to(self, tensor, destinations): - # This is both a fast path for Python constants, and a way to delay - # converting Python values to a tensor until we know what type it - # should be converted to. Otherwise we have trouble with: - # global_step.assign_add(1) - # since the `1` gets broadcast as an int32 but global_step is int64. - if isinstance(tensor, (float, int)): - return tensor - if not cross_device_ops_lib.check_destinations(destinations): - destinations = self._compute_devices - return self._cross_device_ops.broadcast(tensor, destinations) - - def _allow_variable_partition(self): - return not context.executing_eagerly() - - # TODO(yuefengz): not all ops in device_setter.STANDARD_PS_OPS will go through - # this creator, such as "MutableHashTable". - def _create_variable(self, next_creator, *args, **kwargs): - if self._num_replicas_in_sync > 1: - aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE) - if aggregation not in ( - vs.VariableAggregation.NONE, - vs.VariableAggregation.SUM, - vs.VariableAggregation.MEAN, - vs.VariableAggregation.ONLY_FIRST_REPLICA - ): - raise ValueError("Invalid variable aggregation mode: " + aggregation + - " for variable: " + kwargs["name"]) - - def var_creator(*args, **kwargs): - """Create an AggregatingVariable and fix up collections.""" - # Record what collections this variable should be added to. - collections = kwargs.pop("collections", None) - if collections is None: - collections = [ops.GraphKeys.GLOBAL_VARIABLES] - kwargs["collections"] = [] - - # Create and wrap the variable. - v = next_creator(*args, **kwargs) - wrapped = values.AggregatingVariable(v, aggregation) - - # Add the wrapped variable to the requested collections. - # The handling of eager mode and the global step matches - # ResourceVariable._init_from_args(). 
- if not context.executing_eagerly(): - g = ops.get_default_graph() - # If "trainable" is True, next_creator() will add the contained - # variable to the TRAINABLE_VARIABLES collection, so we manually - # remove it and replace with the wrapper. We can't set "trainable" - # to False for next_creator() since that causes functions like - # implicit_gradients to skip those variables. - if kwargs.get("trainable", True): - collections.append(ops.GraphKeys.TRAINABLE_VARIABLES) - l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES) - l.remove(v) - g.add_to_collections(collections, wrapped) - elif ops.GraphKeys.GLOBAL_STEP in collections: - ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, wrapped) - - return wrapped - else: - var_creator = next_creator - - if "colocate_with" in kwargs: - with ops.device(None): - with ops.colocate_with(kwargs["colocate_with"]): - return var_creator(*args, **kwargs) - - with ops.colocate_with(None, ignore_existing=True): - with ops.device(self._variable_device): - return var_creator(*args, **kwargs) - - def _call_for_each_replica(self, fn, args, kwargs): - # pylint: disable=protected-access - return mirrored_strategy._call_for_each_replica( - self._container_strategy(), fn, args, kwargs) + return super(ParameterServerStrategy, self).make_dataset_iterator(dataset) - def _verify_destinations_not_different_worker(self, destinations): - if not self._cluster_spec: - return - if destinations is None: - return - for d in cross_device_ops_lib.get_devices_from(destinations): - d_spec = tf_device.DeviceSpec.from_string(d) - if d_spec.job == self._task_type and d_spec.task != self._task_id: - raise ValueError( - "Cannot reduce to another worker: %r, current worker is %r" % - (d, self._worker_device)) + # Override to change the documentation to reflect the different handling of + # global vs. local batch size between core and contrib. 
+ def experimental_make_numpy_iterator( # pylint: disable=useless-super-delegation + self, numpy_input, batch_size, num_epochs=1, shuffle=1024, session=None): + """Makes an iterator for input provided via a nest of numpy arrays. - def _reduce_to(self, reduce_op, value, destinations): - self._verify_destinations_not_different_worker(destinations) - if not isinstance(value, values.DistributedValues): - # pylint: disable=protected-access - return mirrored_strategy._reduce_non_distributed_value( - self, reduce_op, value, destinations) - return self._cross_device_ops.reduce( - reduce_op, value, destinations=destinations) - - def _batch_reduce_to(self, reduce_op, value_destination_pairs): - for _, destinations in value_destination_pairs: - self._verify_destinations_not_different_worker(destinations) - return self._cross_device_ops.batch_reduce(reduce_op, - value_destination_pairs) - - def _select_single_value(self, structured): - """Select any single values in `structured`.""" - - def _select_fn(x): # pylint: disable=g-missing-docstring - if isinstance(x, values.Mirrored): - if len(x.devices) == 1: - return list(x._index.values())[0] # pylint: disable=protected-access - else: - raise ValueError( - "You cannot update variable with a Mirrored object with multiple " - "components %r when using ParameterServerStrategy. You must " - "specify a single value or a Mirrored with a single value." % x) - elif isinstance(x, values.PerReplica): - raise ValueError( - "You cannot update variable with a PerReplica object %r when using " - "ParameterServerStrategy. You must specify a single value or a " - "Mirrored with a single value" % x) - else: - return x - - return nest.map_structure(_select_fn, structured) - - def _update(self, var, fn, args, kwargs, group): - if isinstance(var, values.AggregatingVariable): - var = var.get() - if not isinstance(var, resource_variable_ops.ResourceVariable): - raise ValueError( - "You can not update `var` %r. It must be a Variable." 
% var) - with ops.colocate_with(var), distribute_lib.UpdateContext(var.device): - result = fn(var, *self._select_single_value(args), - **self._select_single_value(kwargs)) - if group: - return result - else: - return nest.map_structure(self._unwrap, result) - - # TODO(yuefengz): does it need to call _select_single_value? - def _update_non_slot(self, colocate_with, fn, args, kwargs, group): - with ops.device( - colocate_with.device), distribute_lib.UpdateContext(colocate_with): - result = fn(*args, **kwargs) - if group: - return result - else: - return nest.map_structure(self._unwrap, result) - - def _unwrap(self, val): - if isinstance(val, values.DistributedValues): - # Return in a deterministic order. - if set(val.devices) == self._canonical_compute_device_set: - return [val.get(device=d) for d in self._compute_devices] - return [val.get(device=d) for d in sorted(val.devices)] - return [val] - - def value_container(self, val): - if (hasattr(val, "_aggregating_container") and - not isinstance(val, values.AggregatingVariable)): - wrapper = val._aggregating_container() # pylint: disable=protected-access - if wrapper is not None: - return wrapper - return val - - def read_var(self, var): - # No need to distinguish between normal variables and replica-local - # variables. - return array_ops.identity(var) - - def _configure(self, - session_config=None, - cluster_spec=None, - task_type=None, - task_id=None): - """Configures the strategy class. - - The strategy object will be re-initialized if `cluster_spec` is given but - was not passed in the constructor. + NOTE: The `batch_size` argument here has different behavior for this + contrib version of `ParameterServerStrategy`. Args: - session_config: not used currently. - cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the - cluster configurations. - task_type: the current task type. - task_id: the current task id. - - Raises: - ValueError: if `cluster_spec` is given but `task_type` or `task_id` is - not. 
+ numpy_input: A nest of NumPy input arrays that will be distributed evenly + across all replicas. + batch_size: The number of entries from the array we should consume in one + step of the computation, across all replicas. This is the per-replica + batch size. The global batch size will be this times + `num_replicas_in_sync`. + num_epochs: The number of times to iterate through the examples. A value + of `None` means repeat forever. + shuffle: Size of buffer to use for shuffling the input examples. + Use `None` to disable shuffling. + session: (TensorFlow v1.x graph execution only) A session used for + initialization. + + Returns: + An `tf.distribute.InputIterator` which returns inputs for each step of the + computation. User should call `initialize` on the returned iterator. """ - if not self._cluster_spec and cluster_spec: - # If a `cluster_spec` is already passed in, do nothing here. - # TODO(yuefengz): check `cluster_spec` is the same if this object has - # already been initialized with a `cluster_spec`. 
- if task_type is None or task_id is None: - raise ValueError("When `cluster_spec` is given, must also specify " - "`task_type` and `task_id`.") - self._cluster_spec = multi_worker_util.normalize_cluster_spec( - cluster_spec) - self._task_type = task_type - self._task_id = task_id - self._initialize_multi_worker(self._num_gpus_per_worker, - self._cluster_spec, task_type, task_id) - - if session_config: - session_config.CopyFrom(self._update_config_proto(session_config)) - - def _update_config_proto(self, config_proto): - updated_config = copy.deepcopy(config_proto) - if not self._cluster_spec: - updated_config.isolate_session_state = True - return updated_config + return super(ParameterServerStrategy, + self).experimental_make_numpy_iterator( + numpy_input, batch_size, num_epochs, shuffle, session) - updated_config.isolate_session_state = False - assert self._task_type - assert self._task_id is not None - - # The device filters prevent communication between workers. - if self._task_type not in ["chief", "worker"]: - return updated_config - del updated_config.device_filters[:] - updated_config.device_filters.extend( - ["/job:%s/task:%d" % (self._task_type, self._task_id), "/job:ps"]) - return updated_config - - @property - def _num_replicas_in_sync(self): - return len(self._compute_devices) - - @property - def worker_devices(self): - # Make a copy to prevent users from accidentally mutating our copy. - return list(self._compute_devices) - - @property - def parameter_devices(self): - return list(self._parameter_devices) - - def non_slot_devices(self, var_list): - return min(var_list, key=lambda x: x.name) - - @property - def experimental_between_graph(self): - # TODO(yuefengz): Should this return False in the local case? 
- return True - - @property - def experimental_should_init(self): - return self._is_chief +class ParameterServerExtended(CoreParameterServerExtended): + """Implementation of ParameterServerStrategy.""" - @property - def should_checkpoint(self): - return self._is_chief + def __init__(self, container_strategy, num_gpus_per_worker): + # Use TFConfigClusterResolver to parse TF_CONFIG. We don't want to change + # the constructor's interface to allow customized cluster resolver. Use + # SimpleClusterResolver to override num_accelerators. + tfconfig = TFConfigClusterResolver() + cluster_resolver = SimpleClusterResolver( + cluster_spec=tfconfig.cluster_spec(), + task_type=tfconfig.task_type, + task_id=tfconfig.task_id, + num_accelerators=num_gpus_per_worker) + super(ParameterServerExtended, self).__init__( + container_strategy, cluster_resolver=cluster_resolver) - @property - def should_save_summary(self): - return self._is_chief + def _make_dataset_iterator(self, dataset): + return input_lib.DatasetIterator(dataset, self._input_workers) # TODO(priyag): Delete this once all strategies use global batch size. 
@property def _global_batch_size(self): + """The contrib version of PS strategy uses per-replica batch size.""" return False diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py index 83d7473666a65e438a1c0119d2a12bf54e53c8fc..fede253d13804087476fef8b7211a6bfe5789906 100644 --- a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py @@ -29,10 +29,13 @@ from tensorflow.contrib.distribute.python import strategy_test_lib from tensorflow.core.protobuf import config_pb2 from tensorflow.python.data.ops import dataset_ops from tensorflow.python.distribute import device_util +from tensorflow.python.distribute import distribute_lib from tensorflow.python.distribute import distribution_strategy_context as ds_context from tensorflow.python.distribute import multi_worker_util +from tensorflow.python.distribute import parameter_server_strategy as core_parameter_server_strategy from tensorflow.python.distribute import reduce_util from tensorflow.python.distribute import values +from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.estimator import run_config @@ -45,10 +48,12 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gradients from tensorflow.python.ops import partitioned_variables +from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.training import training_util +from tensorflow.python.training.server_lib import ClusterSpec CHIEF = run_config.TaskType.CHIEF WORKER = run_config.TaskType.WORKER @@ -62,6 
+67,57 @@ def _get_replica_id_integer(): return replica_id +class MockCoreParameterServerStrategy(distribute_lib.DistributionStrategy): + """Mock the strategy to allow cluster resolver as an argument.""" + + def __init__(self, cluster_resolver): + super(MockCoreParameterServerStrategy, self).__init__( + core_parameter_server_strategy.ParameterServerStrategyExtended( + self, cluster_resolver=cluster_resolver)) + + +def create_test_objects(cluster_spec=None, + task_type=None, + task_id=None, + num_gpus=None, + sess_config=None, + use_core_strategy=False): + sess_config = sess_config or config_pb2.ConfigProto() + if num_gpus is None: + num_gpus = context.num_gpus() + if use_core_strategy: + if cluster_spec and task_type and task_id is not None: + cluster_resolver = SimpleClusterResolver( + cluster_spec=multi_worker_util.normalize_cluster_spec(cluster_spec), + task_type=task_type, + task_id=task_id, + num_accelerators=num_gpus) + target = 'grpc://' + cluster_spec[WORKER][task_id] + else: + cluster_resolver = SimpleClusterResolver( + ClusterSpec({}), num_accelerators=num_gpus) + target = '' + + distribution = MockCoreParameterServerStrategy(cluster_resolver) + sess_config = copy.deepcopy(sess_config) + sess_config = distribution.update_config_proto(sess_config) + else: + distribution = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=num_gpus) + if task_type: + sess_config = copy.deepcopy(sess_config) + distribution.configure( + session_config=sess_config, + cluster_spec=cluster_spec, + task_type=task_type, + task_id=task_id) + target = 'grpc://' + cluster_spec[WORKER][task_id] + else: + target = '' + + return distribution, target, sess_config + + class ParameterServerStrategyTestBase( multi_worker_test_base.MultiWorkerTestBase): @@ -75,24 +131,27 @@ class ParameterServerStrategyTestBase( self._sess_config = config_pb2.ConfigProto(allow_soft_placement=True) super(ParameterServerStrategyTestBase, self).setUp() - def _get_test_objects(self, 
task_type, task_id, num_gpus): - distribution = parameter_server_strategy.ParameterServerStrategy( - num_gpus_per_worker=num_gpus) - if not task_type: - return distribution, '', self._sess_config - - sess_config = copy.deepcopy(self._sess_config) - distribution.configure( - session_config=sess_config, + def _get_test_objects(self, + task_type, + task_id, + num_gpus, + use_core_strategy=False): + return create_test_objects( cluster_spec=self._cluster_spec, task_type=task_type, - task_id=task_id) - return (distribution, 'grpc://' + self._cluster_spec[WORKER][task_id], - sess_config) - - def _test_device_assignment_distributed(self, task_type, task_id, num_gpus): + task_id=task_id, + num_gpus=num_gpus, + sess_config=self._sess_config, + use_core_strategy=use_core_strategy) + + def _test_device_assignment_distributed(self, + task_type, + task_id, + num_gpus, + use_core_strategy=False): worker_device = '/job:%s/replica:0/task:%d' % (task_type, task_id) - d, _, sess_config = self._get_test_objects(task_type, task_id, num_gpus) + d, _, sess_config = self._get_test_objects( + task_type, task_id, num_gpus, use_core_strategy=use_core_strategy) with ops.Graph().as_default(), \ self.cached_session(target=self._default_target, config=sess_config) as sess, \ @@ -131,7 +190,7 @@ class ParameterServerStrategyTestBase( '/job:worker/replica:0/task:0/%s' % last_part_device) # The colocate_vars_with can override the distribution's device. 
- with d.colocate_vars_with(x): + with d.extended.colocate_vars_with(x): y = variable_scope.get_variable( 'y', initializer=20.0, aggregation=variable_scope.VariableAggregation.SUM) @@ -177,7 +236,7 @@ class ParameterServerStrategyTestBase( self.assertIn('/job:ps/', h.device) return y_add, z_add, f - y, z, f = d.call_for_each_replica(model_fn) + y, z, f = d.extended.call_for_each_replica(model_fn) self.assertNotEqual(y, None) self.assertNotEqual(z, None) self.assertNotEqual(f, None) @@ -190,9 +249,10 @@ class ParameterServerStrategyTestBase( self.assertEqual(f_val, 46.0) def _test_device_assignment_distributed_enable_partitioner( - self, task_type, task_id, num_gpus): - d, _, sess_config = self._get_test_objects(task_type, task_id, num_gpus) - num_shards = len(d.parameter_devices) + self, task_type, task_id, num_gpus, use_core_strategy=False): + d, _, sess_config = self._get_test_objects( + task_type, task_id, num_gpus, use_core_strategy=use_core_strategy) + num_shards = len(d.extended.parameter_devices) partitioner = partitioned_variables.fixed_size_partitioner(num_shards) with ops.Graph().as_default(), \ self.cached_session(target=self._default_target, @@ -224,39 +284,18 @@ class ParameterServerStrategyTestBase( self.assertEqual(var.device, '/job:ps/task:%d' % part_id) self.assertEqual(var.device, x_add[part_id].device) - # The colocate_vars_with can override the distribution's device. 
- with d.colocate_vars_with(x_add[0]): - y = variable_scope.get_variable( - 'y', - initializer=constant_op.constant([20.0, 10.0]), - aggregation=variable_scope.VariableAggregation.SUM, - partitioner=partitioner) - y_add = y.assign_add( - [array_ops.identity(x_add[0]), - array_ops.identity(x_add[1])]) - - for part_id, var in enumerate(y): - self.assertEqual(var.device, '/job:ps/task:0') - self.assertEqual(y_add[part_id].device, var.device) - self.assertEqual(var.device, x_add[0].device) - - return x_add, y_add + return x_add - x, y = d.call_for_each_replica(model_fn) + x = d.extended.call_for_each_replica(model_fn) if context.num_gpus() >= 1: variables.global_variables_initializer().run() - x_val, y_val = sess.run([x, y]) + x_val = sess.run(x) if num_gpus < 1: self.assertEqual(x_val, [13.0, 25.0]) - self.assertEqual(y_val, [33.0, 35.0]) else: x_expect = [10.0 + 3 * num_gpus, 20.0 + 5 * num_gpus] - y_expect = [ - 20.0 + x_expect[0] * num_gpus, 10.0 + x_expect[1] * num_gpus - ] self.assertEqual(x_val, x_expect) - self.assertEqual(y_val, y_expect) def _test_device_assignment_local(self, d, @@ -305,7 +344,7 @@ class ParameterServerStrategyTestBase( self.assertEqual(e.device, device_util.canonicalize('/device:GPU:2')) # The colocate_vars_with can override the distribution's device. 
- with d.colocate_vars_with(x): + with d.extended.colocate_vars_with(x): y = variable_scope.get_variable( 'y', initializer=20.0, aggregation=variable_scope.VariableAggregation.SUM) @@ -348,7 +387,7 @@ class ParameterServerStrategyTestBase( device_util.canonicalize(h.device)) return y_add, z_add, f - y, z, f = d.call_for_each_replica(model_fn) + y, z, f = d.extended.call_for_each_replica(model_fn) self.assertNotEqual(y, None) self.assertNotEqual(z, None) self.assertNotEqual(f, None) @@ -360,9 +399,13 @@ class ParameterServerStrategyTestBase( self.assertEqual(z_val, 43.0) self.assertEqual(f_val, 46.0) - def _test_simple_increment(self, task_type, task_id, num_gpus): + def _test_simple_increment(self, + task_type, + task_id, + num_gpus, + use_core_strategy=False): d, master_target, sess_config = self._get_test_objects( - task_type, task_id, num_gpus) + task_type, task_id, num_gpus, use_core_strategy=use_core_strategy) if d.extended._cluster_spec: num_workers = len(d.extended._cluster_spec.as_dict().get(WORKER)) if 'chief' in d.extended._cluster_spec.as_dict(): @@ -395,7 +438,7 @@ class ParameterServerStrategyTestBase( train_op = control_flow_ops.group(x_add, y_add, z_add) return x, y, z, train_op - x, y, z, train_op = d.call_for_each_replica(model_fn) + x, y, z, train_op = d.extended.call_for_each_replica(model_fn) train_op = d.group(train_op) if context.num_gpus() < d.extended._num_gpus_per_worker: @@ -430,9 +473,13 @@ class ParameterServerStrategyTestBase( y_val == 20.0 + 1.0 * num_workers * d.num_replicas_in_sync and z_val == 30.0 + 1.0 * num_workers) - def _test_minimize_loss_graph(self, task_type, task_id, num_gpus): + def _test_minimize_loss_graph(self, + task_type, + task_id, + num_gpus, + use_core_strategy=False): d, master_target, sess_config = self._get_test_objects( - task_type, task_id, num_gpus) + task_type, task_id, num_gpus, use_core_strategy=use_core_strategy) if task_type: # Multi-worker assert hasattr(d.extended, '_cluster_spec') and 
d.extended._cluster_spec @@ -472,20 +519,20 @@ class ParameterServerStrategyTestBase( def step(): """Perform one optimization step.""" # Run forward & backward to get gradients, variables list. - g_v = d.call_for_each_replica(grad_fn, args=(one,)) + g_v = d.extended.call_for_each_replica(grad_fn, args=(one,)) # Update the variables using the gradients and the update() function. before_list = [] after_list = [] for g, v in g_v: - fetched = d.read_var(v) + fetched = d.extended.read_var(v) before_list.append(fetched) with ops.control_dependencies([fetched]): # TODO(yuefengz): support non-Mirrored variable as destinations. g = d.extended.reduce_to( reduce_util.ReduceOp.SUM, g, destinations=v) with ops.control_dependencies( - d.update(v, update, g, grouped=False)): - after_list.append(d.read_var(v)) + d.extended.update(v, update, args=(g,), group=False)): + after_list.append(d.extended.read_var(v)) return before_list, after_list before_out, after_out = step() @@ -518,10 +565,16 @@ class ParameterServerStrategyTestBase( self.assertLess(error_after, error_before) return error_after < error_before - def _test_input_fn_iterator(self, task_type, task_id, num_gpus, input_fn, - expected_values): + def _test_input_fn_iterator(self, + task_type, + task_id, + num_gpus, + input_fn, + expected_values, + test_reinitialize=True, + use_core_strategy=False): distribution, master_target, config = self._get_test_objects( - task_type, task_id, num_gpus) + task_type, task_id, num_gpus, use_core_strategy=use_core_strategy) devices = distribution.extended.worker_devices with ops.Graph().as_default(), \ @@ -532,27 +585,31 @@ class ParameterServerStrategyTestBase( for expected_value in expected_values: next_element = iterator.get_next() - computed_value = sess.run( - [values.select_device(d, next_element) for d in devices]) + computed_value = sess.run([values.select_replica(r, next_element) + for r in range(len(devices))]) self.assertEqual(expected_value, computed_value) with 
self.assertRaises(errors.OutOfRangeError): next_element = iterator.get_next() - sess.run([values.select_device(d, next_element) for d in devices]) + sess.run([values.select_replica(r, next_element) + for r in range(len(devices))]) # After re-initializing the iterator, should be able to iterate again. - sess.run(iterator.initialize()) + if test_reinitialize: + sess.run(iterator.initialize()) - for expected_value in expected_values: - next_element = iterator.get_next() - computed_value = sess.run( - [values.select_device(d, next_element) for d in devices]) - self.assertEqual(expected_value, computed_value) + for expected_value in expected_values: + next_element = iterator.get_next() + computed_value = sess.run([values.select_replica(r, next_element) + for r in range(len(devices))]) + self.assertEqual(expected_value, computed_value) -class ParameterServerStrategyTest(ParameterServerStrategyTestBase, - strategy_test_lib.DistributionTestBase, - parameterized.TestCase): +class ParameterServerStrategyTest( + ParameterServerStrategyTestBase, + strategy_test_lib.DistributionTestBase, + strategy_test_lib.TwoDeviceDistributionTestBase, + parameterized.TestCase): @classmethod def setUpClass(cls): @@ -560,111 +617,173 @@ class ParameterServerStrategyTest(ParameterServerStrategyTestBase, num_workers=3, num_ps=2) cls._default_target = 'grpc://' + cls._cluster_spec[WORKER][0] - def test_num_replicas_in_sync(self): - distribution = parameter_server_strategy.ParameterServerStrategy( - num_gpus_per_worker=2) + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def test_num_replicas_in_sync(self, use_core_strategy): + strategy, _, _ = create_test_objects( + num_gpus=2, use_core_strategy=use_core_strategy) # All the devices on a given worker are in sync which in this case is the # number of gpus on each worker. 
- self.assertEqual(2, distribution.num_replicas_in_sync) + self.assertEqual(2, strategy.num_replicas_in_sync) - def testDeviceAssignmentLocalCPU(self): - distribution = parameter_server_strategy.ParameterServerStrategy( - num_gpus_per_worker=0) + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def testDeviceAssignmentLocalCPU(self, use_core_strategy): + strategy, _, _ = create_test_objects( + num_gpus=0, use_core_strategy=use_core_strategy) self._test_device_assignment_local( - distribution, compute_device='CPU', variable_device='CPU', num_gpus=0) + strategy, compute_device='CPU', variable_device='CPU', num_gpus=0) - def testDeviceAssignmentLocalOneGPU(self): - distribution = parameter_server_strategy.ParameterServerStrategy( - num_gpus_per_worker=1) + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def testDeviceAssignmentLocalOneGPU(self, use_core_strategy): + strategy, _, _ = create_test_objects( + num_gpus=1, use_core_strategy=use_core_strategy) self._test_device_assignment_local( - distribution, compute_device='GPU', variable_device='GPU', num_gpus=1) + strategy, compute_device='GPU', variable_device='GPU', num_gpus=1) - def testDeviceAssignmentLocalTwoGPUs(self): - distribution = parameter_server_strategy.ParameterServerStrategy( - num_gpus_per_worker=2) + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def testDeviceAssignmentLocalTwoGPUs(self, use_core_strategy): + strategy, _, _ = create_test_objects( + num_gpus=2, use_core_strategy=use_core_strategy) self._test_device_assignment_local( - distribution, compute_device='GPU', variable_device='CPU', num_gpus=2) + strategy, compute_device='GPU', variable_device='CPU', num_gpus=2) @combinations.generate( - combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) - def testDeviceAssignmentDistributed(self, num_gpus): - 
self._test_device_assignment_distributed('worker', 1, num_gpus) + combinations.combine( + mode=['graph'], num_gpus=[0, 1, 2], use_core_strategy=[True, False])) + def testDeviceAssignmentDistributed(self, num_gpus, use_core_strategy): + self._test_device_assignment_distributed( + 'worker', 1, num_gpus, use_core_strategy=use_core_strategy) @combinations.generate( - combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) - def testDeviceAssignmentDistributedEnablePartitioner(self, num_gpus): + combinations.combine( + mode=['graph'], num_gpus=[0, 1, 2], use_core_strategy=[True, False])) + def testDeviceAssignmentDistributedEnablePartitioner(self, num_gpus, + use_core_strategy): self._test_device_assignment_distributed_enable_partitioner( - 'worker', 1, num_gpus) + 'worker', 1, num_gpus, use_core_strategy=use_core_strategy) - def testSimpleBetweenGraph(self): - self._run_between_graph_clients(self._test_simple_increment, - self._cluster_spec, context.num_gpus()) + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def testSimpleBetweenGraph(self, use_core_strategy): + self._run_between_graph_clients( + self._test_simple_increment, + self._cluster_spec, + context.num_gpus(), + use_core_strategy=use_core_strategy) @combinations.generate( - combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) - def testLocalSimpleIncrement(self, num_gpus): - self._test_simple_increment(None, 0, num_gpus) + combinations.combine( + mode=['graph'], num_gpus=[0, 1, 2], use_core_strategy=[True, False])) + def testLocalSimpleIncrement(self, num_gpus, use_core_strategy): + self._test_simple_increment(None, 0, num_gpus, use_core_strategy) @combinations.generate( - combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) - def testMinimizeLossGraphDistributed(self, num_gpus): - self._run_between_graph_clients(self._test_minimize_loss_graph, - self._cluster_spec, num_gpus) + combinations.combine( + mode=['graph'], num_gpus=[0, 1, 2], 
use_core_strategy=[True, False])) + def testMinimizeLossGraphDistributed(self, num_gpus, use_core_strategy): + self._run_between_graph_clients( + self._test_minimize_loss_graph, + self._cluster_spec, + num_gpus, + use_core_strategy=use_core_strategy) @combinations.generate( - combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) - def testMinimizeLossGraphLocal(self, num_gpus): - self._test_minimize_loss_graph(None, None, num_gpus) + combinations.combine( + mode=['graph'], num_gpus=[0, 1, 2], use_core_strategy=[True, False])) + def testMinimizeLossGraphLocal(self, num_gpus, use_core_strategy): + self._test_minimize_loss_graph(None, None, num_gpus, use_core_strategy) # TODO(priyag): Refactor this and other multi worker tests. @combinations.generate( - combinations.combine(mode=['graph'], num_gpus=[1, 2], required_gpus=1)) - def testMakeInputFnIteratorDistributed(self, num_gpus): + combinations.combine( + mode=['graph'], + num_gpus=[1, 2], + required_gpus=1, + use_core_strategy=[True, False], + use_dataset=[True, False])) + def testMakeInputFnIteratorDistributed(self, num_gpus, use_core_strategy, + use_dataset): if context.num_gpus() < num_gpus: self.skipTest('Not enough GPUs') - dataset_fn = lambda: dataset_ops.Dataset.range(100) + if use_dataset: + fn = lambda: dataset_ops.Dataset.range(100) + else: + def fn(): + dataset = dataset_ops.Dataset.range(100) + it = dataset.make_one_shot_iterator() + return it.get_next expected_values = [[i+j for j in range(num_gpus)] for i in range(0, 100, num_gpus)] input_fn = self._input_fn_to_test_input_context( - dataset_fn, + fn, expected_num_replicas_in_sync=num_gpus, expected_num_input_pipelines=3, expected_input_pipeline_id=1) # because task_id = 1 - self._test_input_fn_iterator('worker', 1, num_gpus, - input_fn, expected_values) + self._test_input_fn_iterator( + 'worker', + 1, + num_gpus, + input_fn, + expected_values, + test_reinitialize=use_dataset, + use_core_strategy=use_core_strategy) @combinations.generate( - 
combinations.combine(mode=['graph'], num_gpus=[1, 2], required_gpus=1)) - def testMakeInputFnIteratorLocal(self, num_gpus): + combinations.combine( + mode=['graph'], + num_gpus=[1, 2], + required_gpus=1, + use_core_strategy=[True, False], + use_dataset=[True, False])) + def testMakeInputFnIteratorLocal(self, num_gpus, use_core_strategy, + use_dataset): if context.num_gpus() < num_gpus: self.skipTest('Not enough GPUs') - dataset_fn = lambda: dataset_ops.Dataset.range(100) + if use_dataset: + fn = lambda: dataset_ops.Dataset.range(100) + else: + def fn(): + dataset = dataset_ops.Dataset.range(100) + it = dataset.make_one_shot_iterator() + return it.get_next expected_values = [[i+j for j in range(num_gpus)] for i in range(0, 100, num_gpus)] input_fn = self._input_fn_to_test_input_context( - dataset_fn, + fn, expected_num_replicas_in_sync=num_gpus, expected_num_input_pipelines=1, expected_input_pipeline_id=0) # only one worker and pipeline for local. - self._test_input_fn_iterator(None, None, num_gpus, - input_fn, expected_values) + self._test_input_fn_iterator( + None, + None, + num_gpus, + input_fn, + expected_values, + test_reinitialize=use_dataset, + use_core_strategy=use_core_strategy) - def testGlobalStepUpdate(self): - strategy = parameter_server_strategy.ParameterServerStrategy( - num_gpus_per_worker=context.num_gpus()) + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def testGlobalStepUpdate(self, use_core_strategy): + strategy, _, _ = create_test_objects(use_core_strategy=use_core_strategy) self._test_global_step_update(strategy) - def testUpdateConfigProtoMultiWorker(self): - distribution = parameter_server_strategy.ParameterServerStrategy( - num_gpus_per_worker=2) - distribution.configure( + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def testUpdateConfigProtoMultiWorker(self, use_core_strategy): + strategy, _, _ = create_test_objects( + num_gpus=2, 
use_core_strategy=use_core_strategy) + strategy.configure( cluster_spec=self._cluster_spec, task_type='worker', task_id=1) config_proto = config_pb2.ConfigProto(device_filters=['to_be_overridden']) - new_config = distribution.update_config_proto(config_proto) + new_config = strategy.update_config_proto(config_proto) # Verify device filters. self.assertEqual(['/job:worker/task:1', '/job:ps'], @@ -673,16 +792,48 @@ class ParameterServerStrategyTest(ParameterServerStrategyTestBase, # Verify isolate_session_state self.assertFalse(new_config.isolate_session_state) - def testUpdateConfigProtoLocal(self): - distribution = parameter_server_strategy.ParameterServerStrategy( - num_gpus_per_worker=2) + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def testUpdateConfigProtoLocal(self, use_core_strategy): + strategy, _, _ = create_test_objects( + num_gpus=2, use_core_strategy=use_core_strategy) config_proto = config_pb2.ConfigProto() - new_config = distribution.update_config_proto(config_proto) + new_config = strategy.update_config_proto(config_proto) # Verify isolate_session_state self.assertTrue(new_config.isolate_session_state) + def testAllReduceSum(self): + distribution = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=2) + self._test_all_reduce_sum(distribution) + + def testAllReduceSumGradients(self): + distribution = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=2) + self._test_all_reduce_sum_gradients(distribution) + + def testAllReduceSumGradientTape(self): + distribution = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=2) + self._test_all_reduce_sum_gradient_tape(distribution) + + def testAllReduceMean(self): + distribution = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=2) + self._test_all_reduce_mean(distribution) + + def testAllReduceMeanGradients(self): + distribution = 
parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=2) + self._test_all_reduce_mean_gradients(distribution) + + def testAllReduceMeanGradientTape(self): + distribution = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=2) + self._test_all_reduce_mean_gradient_tape(distribution) + class ParameterServerStrategyWithChiefTest(ParameterServerStrategyTestBase, parameterized.TestCase): @@ -693,20 +844,31 @@ class ParameterServerStrategyWithChiefTest(ParameterServerStrategyTestBase, num_workers=3, num_ps=2, has_chief=True) cls._default_target = 'grpc://' + cls._cluster_spec[CHIEF][0] - def testSimpleBetweenGraph(self): - self._run_between_graph_clients(self._test_simple_increment, - self._cluster_spec, context.num_gpus()) + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def testSimpleBetweenGraph(self, use_core_strategy): + self._run_between_graph_clients( + self._test_simple_increment, + self._cluster_spec, + context.num_gpus(), + use_core_strategy=use_core_strategy) @combinations.generate( - combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) - def testMinimizeLossGraph(self, num_gpus): - self._run_between_graph_clients(self._test_minimize_loss_graph, - self._cluster_spec, num_gpus) + combinations.combine( + mode=['graph'], num_gpus=[0, 1, 2], use_core_strategy=[True, False])) + def testMinimizeLossGraph(self, num_gpus, use_core_strategy): + self._run_between_graph_clients( + self._test_minimize_loss_graph, + self._cluster_spec, + num_gpus, + use_core_strategy=use_core_strategy) - def testGlobalStepIsWrapped(self): - distribution = parameter_server_strategy.ParameterServerStrategy( - num_gpus_per_worker=2) - with ops.Graph().as_default(), distribution.scope(): + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def testGlobalStepIsWrappedOnTwoGPUs(self, use_core_strategy): + strategy, _, _ = create_test_objects( + 
num_gpus=2, use_core_strategy=use_core_strategy) + with ops.Graph().as_default(), strategy.scope(): created_step = training_util.create_global_step() get_step = training_util.get_global_step() self.assertEqual(created_step, get_step, @@ -715,19 +877,55 @@ class ParameterServerStrategyWithChiefTest(ParameterServerStrategyTestBase, id(get_step), get_step.__class__.__name__))) self.assertIs(values.AggregatingVariable, type(created_step)) self.assertIs(values.AggregatingVariable, type(get_step)) + self.assertIs(strategy, created_step.distribute_strategy) + + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def testGlobalStepIsNotWrappedOnOneGPU(self, use_core_strategy): + strategy, _, _ = create_test_objects( + num_gpus=1, use_core_strategy=use_core_strategy) + with ops.Graph().as_default(), strategy.scope(): + created_step = training_util.create_global_step() + get_step = training_util.get_global_step() + self.assertEqual(created_step, get_step, + msg=('created_step %s type %s vs. get_step %s type %s' % + (id(created_step), created_step.__class__.__name__, + id(get_step), get_step.__class__.__name__))) + self.assertIs(resource_variable_ops.ResourceVariable, type(created_step)) + self.assertIs(resource_variable_ops.ResourceVariable, type(get_step)) + # All variables have an _distribute_strategy parameter. Only variable + # subclasses in distribution strategy expose it publicly. 
+ self.assertFalse(hasattr(strategy, 'distribute_strategy')) + self.assertIs(strategy, created_step._distribute_strategy) + + @combinations.generate( + combinations.combine(mode=['graph'], use_core_strategy=[True, False])) + def testValueContainer(self, use_core_strategy): + strategy, _, _ = create_test_objects( + num_gpus=2, use_core_strategy=use_core_strategy) + with ops.Graph().as_default(), strategy.scope(): - def testValueContainer(self): - distribution = parameter_server_strategy.ParameterServerStrategy( - num_gpus_per_worker=2) - with ops.Graph().as_default(), distribution.scope(): def f(): with backprop.GradientTape() as tape: v = variable_scope.get_variable('v', initializer=10.0) _ = v * v v, = tape.watched_variables() - w = distribution.extended.value_container(v) + w = strategy.extended.value_container(v) self.assertIs(values.AggregatingVariable, type(w)) - distribution.extended.call_for_each_replica(f) + + strategy.extended.call_for_each_replica(f) + + +class LocalParameterServerStrategyTest(strategy_test_lib.DistributionTestBase, + parameterized.TestCase): + + @combinations.generate(combinations.combine(mode=['graph', 'eager'], + use_core_strategy=[True, False], + required_gpus=2)) + def testNumpyIterator(self, use_core_strategy): + strategy, _, _ = create_test_objects( + num_gpus=2, use_core_strategy=use_core_strategy) + self._test_numpy_iterator(strategy) if __name__ == '__main__': diff --git a/tensorflow/contrib/distribute/python/step_fn.py b/tensorflow/contrib/distribute/python/step_fn.py index c928b6d9f1f21508edd753f94c38ab2723cc0a9f..27aad46b97195aa498d0382f08c04c312cebbe65 100644 --- a/tensorflow/contrib/distribute/python/step_fn.py +++ b/tensorflow/contrib/distribute/python/step_fn.py @@ -19,7 +19,6 @@ from __future__ import division from __future__ import print_function from tensorflow.python.eager import backprop -from tensorflow.python.eager import context from tensorflow.python.training import optimizer as optimizer_lib @@ -33,6 +32,9 @@ 
class Step(object): def distribution(self): return self._distribution + def initialize(self): + return [] + def __call__(self): """Perform one step of this training algorithm.""" raise NotImplementedError("must be implemented in descendants") @@ -50,12 +52,10 @@ class StandardInputStep(Step): def __init__(self, dataset_fn, distribution): super(StandardInputStep, self).__init__(distribution) - self._distributed_input = distribution.distribute_dataset(dataset_fn) - if context.executing_eagerly(): - self._iterator = self._distributed_input.make_one_shot_iterator() - else: - # TODO(priyag): Expose initializer via some initializer property. - self._iterator = self._distributed_input.make_initializable_iterator() + self._iterator = distribution.make_input_fn_iterator(lambda _: dataset_fn()) + + def initialize(self): + return self._iterator.initialize() class StandardSingleLossStep(StandardInputStep): @@ -99,8 +99,8 @@ class StandardSingleLossStep(StandardInputStep): gradients_fn = backprop.implicit_grad(self._loss_fn) gradients_fn = optimizer_lib.get_filtered_grad_fn(gradients_fn) - grads_and_vars = self.distribution.call_for_each_replica( - gradients_fn, args=(ctx,) + inputs) + grads_and_vars = self.distribution.extended.call_for_each_replica( + gradients_fn, args=(ctx, inputs)) # If threads use layers, then we need to run the first step # sequentially, so that layers.build() is not executed in parallel. # Otherwise, multiple sets of mirrored variables are going to be @@ -109,6 +109,6 @@ class StandardSingleLossStep(StandardInputStep): self.distribution, grads_and_vars) # TODO(priyag): Return the outputs, context, etc as well. 
- ctx = self.distribution.run_steps_on_dataset( + ctx = self.distribution.extended.experimental_run_steps_on_iterator( step_fn, self._iterator, self._iterations_per_step) return ctx.run_op diff --git a/tensorflow/contrib/distribute/python/step_fn_test.py b/tensorflow/contrib/distribute/python/step_fn_test.py index 1ff9b9ceec13351b098d47ed3ff62f689a625a31..9f48560b2666036e149a63c98b6529fb24cc5067 100644 --- a/tensorflow/contrib/distribute/python/step_fn_test.py +++ b/tensorflow/contrib/distribute/python/step_fn_test.py @@ -45,24 +45,21 @@ class SingleLossStepTest(test.TestCase, parameterized.TestCase): single_loss_step, layer = single_loss_example( optimizer_fn, distribution, use_bias=True, iterations_per_step=2) - self.evaluate(distribution.initialize()) if context.executing_eagerly(): + single_loss_step.initialize() run_step = single_loss_step else: with self.cached_session() as sess: - sess.run(single_loss_step._iterator.initializer) + sess.run(single_loss_step.initialize()) run_step = sess.make_callable(single_loss_step()) self.evaluate(variables.global_variables_initializer()) weights, biases = [], [] for _ in range(5): run_step() - weights.append(self.evaluate(layer.kernel)) biases.append(self.evaluate(layer.bias)) - self.evaluate(distribution.finalize()) - error = abs(numpy.add(numpy.squeeze(weights), numpy.squeeze(biases)) - 1) is_not_increasing = all(y <= x for x, y in zip(error, error[1:])) self.assertTrue(is_not_increasing) diff --git a/tensorflow/contrib/distribute/python/strategy_test_lib.py b/tensorflow/contrib/distribute/python/strategy_test_lib.py index d50b142c5e9ad36522b11a77219140a7b40d9bf6..90f552eda4c41742f21ca276d8a059b2b102554f 100644 --- a/tensorflow/contrib/distribute/python/strategy_test_lib.py +++ b/tensorflow/contrib/distribute/python/strategy_test_lib.py @@ -18,7 +18,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np + from tensorflow.core.protobuf import 
config_pb2 +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.distribute import distribution_strategy_context as ds_context from tensorflow.python.distribute import reduce_util from tensorflow.python.distribute import values @@ -31,6 +34,7 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.layers import core from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import init_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables @@ -41,25 +45,26 @@ class _TestException(Exception): pass -# May be the argument to either distribution.call_for_each_replica() or +# May be the argument to either distribution.extended.call_for_each_replica() or # get_replica_context().merge_call() def _raise_exception_fn(_=None): raise _TestException() -# Must be the argument to a distribution.call_for_each_replica() call, calls a -# get_replica_context().merge_call() that raises an exception. +# Must be the argument to a distribution.extended.call_for_each_replica() call, +# calls a get_replica_context().merge_call() that raises an exception. def _merge_raises_fn(): ds_context.get_replica_context().merge_call(_raise_exception_fn) # Must be the argument to a get_replica_context().merge_call() call, calls -# dist.call_for_each_replica() with a function that raises an exception. +# dist.extended.call_for_each_replica() with a function that raises an +# exception. def _call_raises_fn(dist): - dist.call_for_each_replica(_raise_exception_fn) + dist.extended.call_for_each_replica(_raise_exception_fn) -# Must be the argument to a distribution.call_for_each_replica() call, +# Must be the argument to a distribution.extended.call_for_each_replica() call, # calls a get_replica_context().merge_call() that calls a # call_for_each_replica() that raises an exception. 
def _merge_call_raises_fn(): @@ -67,15 +72,16 @@ def _merge_call_raises_fn(): # Must be the argument to a get_replica_context().merge_call() call, calls -# dist.call_for_each_replica() with a function that calls a +# dist.extended.call_for_each_replica() with a function that calls a # get_replica_context().merge_call() that raises an exception. def _call_merge_raises_fn(dist): - dist.call_for_each_replica(_merge_raises_fn) + dist.extended.call_for_each_replica(_merge_raises_fn) -# Must be the argument to a distribution.call_for_each_replica() call, calls a -# get_replica_context().merge_call() that calls a call_for_each_replica() that -# calls a get_replica_context().merge_call() that raises an exception. +# Must be the argument to a distribution.extended.call_for_each_replica() call, +# calls a get_replica_context().merge_call() that calls a +# call_for_each_replica() that calls a get_replica_context().merge_call() that +# raises an exception. def _merge_call_merge_raises_fn(): ds_context.get_replica_context().merge_call(_call_merge_raises_fn) @@ -106,21 +112,21 @@ class DistributionTestBase(test.TestCase): def step(): """Perform one optimization step.""" # Run forward & backward to get gradients, variables list. - g_v = d.call_for_each_replica(grad_fn, args=(one,)) + g_v = d.extended.call_for_each_replica(grad_fn, args=(one,)) # Update the variables using the gradients and the update() function. 
before_list = [] after_list = [] for g, v in g_v: - fetched = d.read_var(v) + fetched = d.extended.read_var(v) before_list.append(fetched) # control_dependencies irrelevant but harmless in eager execution with ops.control_dependencies([fetched]): g = d.extended.reduce_to( reduce_util.ReduceOp.SUM, g, destinations=v) - with ops.control_dependencies(d.update( - v, update, g, grouped=False)): - after_list.append(d.read_var(v)) + with ops.control_dependencies(d.extended.update( + v, update, args=(g,), group=False)): + after_list.append(d.extended.read_var(v)) return before_list, after_list for i in range(10): @@ -162,20 +168,20 @@ class DistributionTestBase(test.TestCase): def step(): """Perform one optimization step.""" # Run forward & backward to get gradients, variables list. - g_v = d.call_for_each_replica(grad_fn, args=(one,)) + g_v = d.extended.call_for_each_replica(grad_fn, args=(one,)) # Update the variables using the gradients and the update() function. before_list = [] after_list = [] for g, v in g_v: - fetched = d.read_var(v) + fetched = d.extended.read_var(v) before_list.append(fetched) with ops.control_dependencies([fetched]): g = d.extended.reduce_to( reduce_util.ReduceOp.SUM, g, destinations=v) - with ops.control_dependencies(d.update( - v, update, g, grouped=False)): - after_list.append(d.read_var(v)) + with ops.control_dependencies(d.extended.update( + v, update, args=(g,), group=False)): + after_list.append(d.extended.read_var(v)) return before_list, after_list before_out, after_out = step() @@ -202,23 +208,23 @@ class DistributionTestBase(test.TestCase): self.assertFalse(expected_devices[replica_id]) expected_devices[replica_id] = True - d.call_for_each_replica(mark_devices_fn) + d.extended.call_for_each_replica(mark_devices_fn) self.assertAllEqual(expected_devices, [True] * len(d.extended.worker_devices)) def _test_call_and_merge_exceptions(self, dist): with dist.scope(): with self.assertRaises(_TestException): - 
dist.call_for_each_replica(_raise_exception_fn) + dist.extended.call_for_each_replica(_raise_exception_fn) with self.assertRaises(_TestException): - dist.call_for_each_replica(_merge_raises_fn) + dist.extended.call_for_each_replica(_merge_raises_fn) with self.assertRaises(_TestException): - dist.call_for_each_replica(_merge_call_raises_fn) + dist.extended.call_for_each_replica(_merge_call_raises_fn) with self.assertRaises(_TestException): - dist.call_for_each_replica(_merge_call_merge_raises_fn) + dist.extended.call_for_each_replica(_merge_call_merge_raises_fn) def _input_fn_to_test_input_context(self, - dataset_fn, + dataset_or_callable_fn, expected_num_replicas_in_sync, expected_num_input_pipelines, expected_input_pipeline_id): @@ -242,33 +248,35 @@ class DistributionTestBase(test.TestCase): self.assertEqual(worker_id_counter[0], input_context.input_pipeline_id) worker_id_counter[0] += 1 - return dataset_fn() + return dataset_or_callable_fn() return _input_fn def _test_input_fn_iterator(self, iterator, devices, expected_values, - sess=None): + sess=None, test_reinitialize=True): evaluate = lambda x: sess.run(x) if sess else self.evaluate(x) evaluate(iterator.initialize()) for expected_value in expected_values: next_element = iterator.get_next() computed_value = evaluate( - [values.select_device(d, next_element) for d in devices]) + [values.select_replica(r, next_element) for r in range(len(devices))]) self.assertEqual(expected_value, computed_value) with self.assertRaises(errors.OutOfRangeError): next_element = iterator.get_next() - evaluate([values.select_device(d, next_element) for d in devices]) + evaluate( + [values.select_replica(r, next_element) for r in range(len(devices))]) # After re-initializing the iterator, should be able to iterate again. 
- evaluate(iterator.initialize()) + if test_reinitialize: + evaluate(iterator.initialize()) - for expected_value in expected_values: - next_element = iterator.get_next() - computed_value = evaluate( - [values.select_device(d, next_element) for d in devices]) - self.assertEqual(expected_value, computed_value) + for expected_value in expected_values: + next_element = iterator.get_next() + computed_value = evaluate([values.select_replica(r, next_element) + for r in range(len(devices))]) + self.assertEqual(expected_value, computed_value) def _test_global_step_update(self, strategy): with strategy.scope(): @@ -286,8 +294,195 @@ class DistributionTestBase(test.TestCase): value = global_step.read_value() return train_op, value - train_ops, value = strategy.call_for_each_replica(model_fn) + train_ops, value = strategy.extended.call_for_each_replica(model_fn) self.evaluate(strategy.group(train_ops)) global_step_tensors = strategy.unwrap(value) global_step_values = self.evaluate(global_step_tensors) - self.assertEqual([1] * len(global_step_tensors), global_step_values) + self.assertEqual((1,) * len(global_step_tensors), global_step_values) + + def _test_numpy_iterator(self, strategy): + with strategy.scope(), self.cached_session() as sess: + x = np.asarray([[1, 2], [6, 12], [2, 4], + [5, 10], [3, 6], [4, 8]]) + y = np.asarray([5, 4, 3, 2, 1, 0]) + batch_size = 6 + if not strategy.extended._global_batch_size: # pylint: disable=protected-access + batch_size = batch_size // strategy.num_replicas_in_sync + i = strategy.experimental_make_numpy_iterator( + (x, y), batch_size=batch_size, num_epochs=2, shuffle=None, + session=sess) + self.evaluate(i.initialize()) + + def run_and_concatenate(strategy, i): + x, y = strategy.experimental_run(lambda z: z, i) + x, y = self.evaluate((strategy.unwrap(x), strategy.unwrap(y))) + return np.concatenate(x), np.concatenate(y) + + x_1, y_1 = run_and_concatenate(strategy, i) + self.assertAllEqual(x, x_1) + self.assertAllEqual(y, y_1) + x_2, y_2 = 
run_and_concatenate(strategy, i) + self.assertAllEqual(x, x_2) + self.assertAllEqual(y, y_2) + with self.assertRaises(errors.OutOfRangeError): + run_and_concatenate(strategy, i) + + +class OneDeviceDistributionTestBase(test.TestCase): + """Some tests that should work with any one-device DistributionStrategy.""" + + def _test_all_reduce_sum(self, strategy): + self._test_collective_comms( + strategy, _all_sum, inputs=(4., [42., 43.]), expected=(4., [42., 43.])) + + def _test_all_reduce_sum_gradients(self, strategy): + self._test_collective_comms_gradients( + strategy, _all_sum, inputs=[4.], expected_grads=[4.]) + + def _test_all_reduce_sum_gradient_tape(self, strategy): + self._test_collective_comms_gradient_tape( + strategy, _all_sum, inputs=[4.], expected_grads=[4.]) + + def _test_all_reduce_mean(self, strategy): + self._test_collective_comms( + strategy, _all_mean, inputs=(2., [21., 22.]), expected=(2., [21., 22.])) + + def _test_all_reduce_mean_gradients(self, strategy): + self._test_collective_comms_gradients( + strategy, _all_mean, inputs=[5.], expected_grads=[5.]) + + def _test_all_reduce_mean_gradient_tape(self, strategy): + self._test_collective_comms_gradient_tape( + strategy, _all_mean, inputs=[5.], expected_grads=[5.]) + + def _test_collective_comms(self, strategy, comm_fn, inputs, expected): + inputs = strategy.make_input_fn_iterator( + lambda _: dataset_ops.Dataset.from_tensors(inputs)) + + self.evaluate(inputs.initialize()) + outputs = self.evaluate( + list(map(strategy.unwrap, strategy.experimental_run(comm_fn, inputs)))) + self.assertAllEqual([expected[0]], outputs[0]) + self.assertAllEqual([expected[1]], outputs[1]) + + def _test_collective_comms_gradients( + self, strategy, comm_fn, inputs, expected_grads): + if context.executing_eagerly(): + self.skipTest("`tf.gradients` is not supported with eager execution.") + + def step(c): + x = constant_op.constant(42.) 
+ y = comm_fn(x) * c + return gradients_impl.gradients(y, [x])[0] + + inputs = strategy.make_input_fn_iterator( + lambda _: dataset_ops.Dataset.from_tensors(inputs)) + + self.evaluate(inputs.initialize()) + self.assertAllEqual( + expected_grads, + self.evaluate(strategy.unwrap(strategy.experimental_run(step, inputs)))) + + def _test_collective_comms_gradient_tape( + self, strategy, comm_fn, inputs, expected_grads): + def step(c): + x = constant_op.constant(42.) + with backprop.GradientTape() as tape: + tape.watch(x) + y = comm_fn(x) * c + return tape.gradient(y, x) + + inputs = strategy.make_input_fn_iterator( + lambda _: dataset_ops.Dataset.from_tensors(inputs)) + + self.evaluate(inputs.initialize()) + self.assertAllEqual( + expected_grads, + self.evaluate(strategy.unwrap(strategy.experimental_run(step, inputs)))) + + +class TwoDeviceDistributionTestBase(test.TestCase): + """Some tests that should work with any two-device DistributionStrategy.""" + + def _test_all_reduce_sum(self, strategy): + self._test_collective_comms( + strategy, _all_sum, + inputs=([1., 3.], [[39., 2.], [3., 41.]]), + expected=(4., [42., 43.])) + + def _test_all_reduce_sum_gradients(self, strategy): + self._test_collective_comms_gradients( + strategy, _all_sum, inputs=[1., 3.], expected_grads=[4., 4.]) + + def _test_all_reduce_sum_gradient_tape(self, strategy): + self._test_collective_comms_gradient_tape( + strategy, _all_sum, inputs=[1., 3.], expected_grads=[4., 4.]) + + def _test_all_reduce_mean(self, strategy): + self._test_collective_comms( + strategy, _all_mean, + inputs=([1., 3.], [[39., 2.], [3., 41.]]), + expected=(2., [21., 21.5])) + + def _test_all_reduce_mean_gradients(self, strategy): + self._test_collective_comms_gradients( + strategy, _all_mean, inputs=[1., 3.], expected_grads=[2., 2.]) + + def _test_all_reduce_mean_gradient_tape(self, strategy): + self._test_collective_comms_gradient_tape( + strategy, _all_mean, inputs=[1., 3.], expected_grads=[2., 2.]) + + def 
_test_collective_comms(self, strategy, comm_fn, inputs, expected): + inputs = strategy.make_input_fn_iterator( + lambda _: dataset_ops.Dataset.from_tensor_slices(inputs)) + + self.evaluate(inputs.initialize()) + outputs = self.evaluate( + list(map(strategy.unwrap, strategy.experimental_run(comm_fn, inputs)))) + self.assertAllEqual([expected[0], expected[0]], outputs[0]) + self.assertAllEqual([expected[1], expected[1]], outputs[1]) + + def _test_collective_comms_gradients( + self, strategy, comm_fn, inputs, expected_grads): + if context.executing_eagerly(): + self.skipTest("`tf.gradients` is not supported with eager execution.") + + def step(c): + x = constant_op.constant(42.) + y = comm_fn(x) * c + return gradients_impl.gradients(y, [x])[0] + + inputs = strategy.make_input_fn_iterator( + lambda _: dataset_ops.Dataset.from_tensor_slices(inputs)) + + self.evaluate(inputs.initialize()) + self.assertAllEqual( + expected_grads, + self.evaluate(strategy.unwrap(strategy.experimental_run(step, inputs)))) + + def _test_collective_comms_gradient_tape( + self, strategy, comm_fn, inputs, expected_grads): + def step(c): + x = constant_op.constant(42.) 
+ with backprop.GradientTape() as tape: + tape.watch(x) + y = comm_fn(x) * c + return tape.gradient(y, x) + + inputs = strategy.make_input_fn_iterator( + lambda _: dataset_ops.Dataset.from_tensor_slices(inputs)) + + self.evaluate(inputs.initialize()) + self.assertAllEqual( + expected_grads, + self.evaluate(strategy.unwrap(strategy.experimental_run(step, inputs)))) + + +def _all_sum(value): + ctx = ds_context.get_replica_context() + return ctx.all_reduce(reduce_util.ReduceOp.SUM, value) + + +def _all_mean(value): + ctx = ds_context.get_replica_context() + return ctx.all_reduce(reduce_util.ReduceOp.MEAN, value) diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py index 39ed8f7cf10371c0e8dd70e2bdf53f13e8ce8383..69ce1141d8bea835cb959f503647900fba5f6e25 100644 --- a/tensorflow/contrib/distribute/python/tpu_strategy.py +++ b/tensorflow/contrib/distribute/python/tpu_strategy.py @@ -21,21 +21,29 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections import copy -import functools from tensorflow.contrib.tpu.python.ops import tpu_ops +from tensorflow.contrib.tpu.python.tpu import device_assignment as device_assignment_lib +from tensorflow.contrib.tpu.python.tpu import topology from tensorflow.contrib.tpu.python.tpu import tpu from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_metadata_lib from tensorflow.contrib.tpu.python.tpu import training_loop +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.client import session as session_lib from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib from tensorflow.python.distribute import device_util from tensorflow.python.distribute import distribute_lib +from tensorflow.python.distribute import input_lib +from tensorflow.python.distribute import numpy_dataset from tensorflow.python.distribute import reduce_util 
from tensorflow.python.distribute import values +from tensorflow.python.distribute.cluster_resolver import TPUClusterResolver from tensorflow.python.eager import context from tensorflow.python.eager import tape from tensorflow.python.framework import constant_op +from tensorflow.python.framework import device as tf_device from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util @@ -43,10 +51,31 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope as vs +from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import nest -_TPU_INITIALIZE_SYSTEM_COLLECTION = "TPU_STRATEGY_INITIALIZE" +def initialize_tpu_system(cluster_resolver=None): + """Initialize the TPU devices in a separate session and graph. + + Args: + cluster_resolver: A tf.contrib.cluster_resolver.TPUClusterResolver, + which provides information about the TPU cluster. + Returns: + The tf.contrib.tpu.Topology object for the topology of the TPU cluster. + """ + if cluster_resolver is None: + cluster_resolver = TPUClusterResolver("") + master = cluster_resolver.master() + + logging.info("Initializing the TPU system.") + session_config = config_pb2.ConfigProto(allow_soft_placement=True) + + with ops.Graph().as_default(): + with session_lib.Session(config=session_config, target=master) as sess: + serialized_topology = sess.run(tpu.initialize_system()) + logging.info("Finished initializing TPU system.") + return topology.Topology(serialized=serialized_topology) def get_tpu_system_metadata(tpu_cluster_resolver): @@ -66,13 +95,14 @@ def get_tpu_system_metadata(tpu_cluster_resolver): # TODO(jhseu): Deduplicate with MirroredStrategy? 
-def _create_tpu_mirrored_variable(devices, real_mirrored_creator, *args, - **kwargs): # pylint: disable=g-missing-docstring +def _create_tpu_mirrored_variable( # pylint: disable=missing-docstring + strategy, device_map, logical_device, real_mirrored_creator, + *args, **kwargs): # Figure out what collections this variable should be added to. # We'll add the TPUMirroredVariable to those collections instead. - collections = kwargs.pop("collections", None) - if collections is None: - collections = [ops.GraphKeys.GLOBAL_VARIABLES] + var_collections = kwargs.pop("collections", None) + if var_collections is None: + var_collections = [ops.GraphKeys.GLOBAL_VARIABLES] kwargs["collections"] = [] # TODO(jhseu): Should we have different behavior for different @@ -97,8 +127,11 @@ def _create_tpu_mirrored_variable(devices, real_mirrored_creator, *args, # was never recorded on the tape instead of having to do this manually # here. with tape.stop_recording(): - index = real_mirrored_creator(devices, *args, **kwargs) - result = values.TPUMirroredVariable(index, index[devices[0]], aggregation) + devices = device_map.logical_to_actual_devices(logical_device) + value_list = real_mirrored_creator(devices, *args, **kwargs) + result = values.TPUMirroredVariable( + strategy, device_map, value_list, aggregation, + logical_device=logical_device) if not context.executing_eagerly(): g = ops.get_default_graph() @@ -108,18 +141,22 @@ def _create_tpu_mirrored_variable(devices, real_mirrored_creator, *args, # "trainable" to False for next_creator() since that causes functions # like implicit_gradients to skip those variables. 
if kwargs.get("trainable", True): - collections.append(ops.GraphKeys.TRAINABLE_VARIABLES) + var_collections.append(ops.GraphKeys.TRAINABLE_VARIABLES) l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES) - for v in index.values(): + for v in value_list: l.remove(v) - g.add_to_collections(collections, result) + g.add_to_collections(var_collections, result) return result class TPUStrategy(distribute_lib.DistributionStrategy): """TPU distribution strategy implementation.""" - def __init__(self, tpu_cluster_resolver, steps_per_run, num_cores=None): + def __init__(self, + tpu_cluster_resolver=None, + steps_per_run=None, + device_assignment=None, + **kwargs): """Initializes the TPUStrategy object. Args: @@ -130,45 +167,151 @@ class TPUStrategy(distribute_lib.DistributionStrategy): metrics, summaries etc. This parameter is only used when Distribution Strategy is used with estimator or keras. - num_cores: Number of cores to use on the TPU. If None specified, then - auto-detect the cores and topology of the TPU system. + device_assignment: Optional `tf.contrib.tpu.DeviceAssignment` to specify + the placement of replicas on the TPU cluster. Currently only supports + the usecase of using a single core within a TPU cluster. + **kwargs: Additional experimental flags. Will be removed in future. 
""" + if len(kwargs) > 1: + raise ValueError("TPUStrategy constructor only takes one experimental " + "flag now") + elif len(kwargs) == 1 and "_disable_training_loop_on_host" not in kwargs: + raise ValueError("TPUStrategy constructor does not support arguments: " + "{}".format(kwargs)) + super(TPUStrategy, self).__init__(TPUExtended( - self, tpu_cluster_resolver, steps_per_run, num_cores)) + self, tpu_cluster_resolver, steps_per_run, device_assignment, + kwargs.get("_disable_training_loop_on_host", False))) @property def steps_per_run(self): """DEPRECATED: use .extended.steps_per_run instead.""" return self._extended.steps_per_run + # TODO(cjfj): Modify `_call_for_each_replica` in `TPUExtended` such that this + # can use the default implementation. + # This implementation runs a single step. It does not use infeed or outfeed. + def experimental_run(self, fn, input_iterator=None): + """See base class.""" + if context.executing_eagerly(): + raise NotImplementedError("Eager mode not supported in TPUStrategy.") + + if self.extended._disable_training_loop_on_host: # pylint: disable=protected-access + raise NotImplementedError( + "`experimental_run` is not compatible with " + "`_disable_training_loop_on_host=True`") + + if input_iterator is None: + inputs = [] + else: + inputs = input_iterator.get_next() + + result = [None] + def replicated_fn(replica_id, inputs): + """Wraps user function to provide replica ID and `Tensor` inputs.""" + with _TPUReplicaContext(self, replica_id_in_sync_group=replica_id): + if input_iterator is None: + result[0] = fn() + else: + result[0] = fn(inputs) + return result[0] + + replicate_inputs = [] # By replica. + for i in range(self.num_replicas_in_sync): + replicate_inputs.append( + [constant_op.constant(i, dtype=dtypes.int32), + values.select_replica(i, inputs)]) + + with self.scope(): + replicate_outputs = tpu.replicate(replicated_fn, replicate_inputs) + + # Workaround for `tpu.replicate` behaviour when single `Tensor` returned. 
+ replicate_outputs = [ + nest.pack_sequence_as(result[0], nest.flatten(replica_outputs)) + for replica_outputs in replicate_outputs] + + device_map = self.extended._device_map # pylint: disable=protected-access + return values.regroup(device_map, replicate_outputs) + class TPUExtended(distribute_lib.DistributionStrategyExtended): """Implementation of TPUStrategy.""" - def __init__(self, container_strategy, tpu_cluster_resolver, steps_per_run, - num_cores=None): + def __init__(self, + container_strategy, + tpu_cluster_resolver=None, + steps_per_run=None, + device_assignment=None, + disable_training_loop_on_host=False): super(TPUExtended, self).__init__(container_strategy) + + if tpu_cluster_resolver is None: + tpu_cluster_resolver = TPUClusterResolver("") + + if steps_per_run is None: + # TODO(frankchn): Warn when we are being used by DS/Keras and this is + # not specified. + steps_per_run = 1 + self._tpu_cluster_resolver = tpu_cluster_resolver self._tpu_metadata = get_tpu_system_metadata(self._tpu_cluster_resolver) - # TODO(sourabhbajaj): Change this from num_cores to metadata_override - self._num_cores_override = num_cores + self._device_assignment = device_assignment + self._disable_training_loop_on_host = disable_training_loop_on_host + + # Device assignment is currently only supported for 1 core case. 
+ if self._device_assignment: + assert isinstance(self._device_assignment, + device_assignment_lib.DeviceAssignment) + if self._device_assignment.num_replicas != 1: + raise ValueError("Device assignment is only supported for a single " + "core single replica case currently.") + if self._device_assignment.num_cores_per_replica != 1: + raise ValueError("Device assignment is only supported for a single " + "core single replica case currently.") + if not all(self._device_assignment.core_assignment[0][0] == [0, 0, 0]): + raise ValueError("Device assignment is only supported for a single " + "core single replica case currently.") # TODO(jhseu): Switch to DeviceAssignment to support pods and model # parallelism. - device_map = {d.name: i for i, d in enumerate(self._tpu_metadata.devices) - if "device:TPU:" in d.name} - self._device_index = values.PerReplica(device_map) + self._device_index = { + d.name: i for i, d in enumerate(self._tpu_metadata.devices) + if "device:TPU:" in d.name + } self._host_device = self.get_host_cpu_device(0) - self._tpu_devices = sorted(device_map.keys()) + self._tpu_devices = tuple(sorted(self._device_index.keys())) # Only create variables for the number of replicas we're running. self._tpu_devices = self._tpu_devices[:self._num_replicas_in_sync] + self._device_map = values.ReplicaDeviceMap(self._tpu_devices) + + # If the training loop is on the device, we must use the infeed, with input + # on the host. Otherwise, we preload the data onto the TPUs. 
+ if disable_training_loop_on_host: + input_device_map = values.ReplicaDeviceMap(tuple( + self.get_host_cpu_device(hid) for hid in range(self.num_hosts))) + worker_devices = [ + (self.get_host(hid), [self.get_host_cpu_device(hid)]) + for hid in range(self.num_hosts) + ] + self._input_workers = input_lib.InputWorkers( + input_device_map, worker_devices) + else: + input_worker_devices = collections.OrderedDict() + for tpu_device in self._tpu_devices: + host_device = _get_host_for_device(tpu_device) + input_worker_devices.setdefault(host_device, []) + input_worker_devices[host_device].append(tpu_device) + self._input_workers = input_lib.InputWorkers( + self._device_map, tuple(input_worker_devices.items())) # TODO(sourabhbajaj): Remove this once performance of running one step # at a time is comparable to multiple steps. self.steps_per_run = steps_per_run - self._require_static_shapes = True + def _validate_colocate_with_variable(self, colocate_with_variable): + values.validate_colocate_tpu_variable(colocate_with_variable, self) + def _get_enqueue_op_per_host(self, host_id, multi_worker_iterator, input_shapes, iterations): """Create an enqueue op for a single host identified using host_id. 
@@ -232,27 +375,44 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): def _make_dataset_iterator(self, dataset): """Make iterators for each of the TPU hosts.""" - - worker_devices = [ - (self.get_host(hid), [self.get_host_cpu_device(hid)]) - for hid in range(self.num_hosts) - ] - return values.DatasetIterator(dataset, worker_devices, - self._num_replicas_in_sync) - - def _distribute_dataset(self, dataset_fn): - worker_devices = [ - (self.get_host(hid), [self.get_host_cpu_device(hid)]) - for hid in range(self.num_hosts) - ] - return values.MultiWorkerDataset( - functools.partial(self._call_dataset_fn, dataset_fn), worker_devices) + return input_lib.DatasetIterator(dataset, self._input_workers, + self._num_replicas_in_sync) + + def _make_input_fn_iterator( + self, + input_fn, + replication_mode=distribute_lib.InputReplicationMode.PER_WORKER): + input_contexts = [] + num_workers = self._input_workers.num_workers + for i in range(num_workers): + input_contexts.append(distribute_lib.InputContext( + num_input_pipelines=num_workers, + input_pipeline_id=i, + num_replicas_in_sync=self._num_replicas_in_sync)) + return input_lib.InputFunctionIterator( + input_fn, self._input_workers, input_contexts) + + def _experimental_make_numpy_dataset(self, numpy_input, session): + return numpy_dataset.one_host_numpy_dataset( + numpy_input, numpy_dataset.SingleDevice(self.get_host_cpu_device(0)), + session) # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed. # TODO(sourabhbajaj): Remove the initial_loop_values parameter when we have # a mechanism to infer the outputs of `fn`. Pending b/110550782. 
def _experimental_run_steps_on_iterator( self, fn, multi_worker_iterator, iterations, initial_loop_values=None): + if self._disable_training_loop_on_host: + impl = self._run_steps_on_iterator_with_device_loop + else: + impl = self._run_steps_on_iterator_with_host_loop + + return impl( + fn=fn, multi_worker_iterator=multi_worker_iterator, + iterations=iterations, initial_loop_values=initial_loop_values) + + def _run_steps_on_iterator_with_host_loop( + self, fn, multi_worker_iterator, iterations, initial_loop_values=None): output_shapes = multi_worker_iterator.output_shapes shapes = nest.flatten(output_shapes) if any(not s.is_fully_defined() for s in shapes): @@ -260,29 +420,16 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): "TPU currently requires fully defined shapes. Either use " "set_shape() on the input tensors or use " "dataset.batch(..., drop_remainder=True).") - types = nest.flatten(multi_worker_iterator.output_types) - - enqueue_ops = [ - self._get_enqueue_op_per_host(host_id, multi_worker_iterator, shapes, - iterations) - for host_id in range(self.num_hosts)] - - def dequeue_fn(): - dequeued = tpu_ops.infeed_dequeue_tuple(dtypes=types, shapes=shapes) - return nest.pack_sequence_as(output_shapes, dequeued) # Wrap `fn` for repeat. 
if initial_loop_values is None: initial_loop_values = {} initial_loop_values = nest.flatten(initial_loop_values) - ctx = values.MultiStepContext() + ctx = input_lib.MultiStepContext() - def run_fn(): + def run_fn(inputs): """Single step on the TPU device.""" - fn_inputs = dequeue_fn() - if not isinstance(fn_inputs, tuple): - fn_inputs = (fn_inputs,) - fn_result = fn(ctx, fn_inputs) + fn_result = fn(ctx, inputs) flat_last_step_outputs = nest.flatten(ctx.last_step_outputs) if flat_last_step_outputs: with ops.control_dependencies([fn_result]): @@ -302,7 +449,14 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): def rewrite_fn(*args): """The rewritten step fn running on TPU.""" del args - replicate_inputs = [[]] * self._num_replicas_in_sync + + per_replica_inputs = multi_worker_iterator.get_next() + replicate_inputs = [] + for replica_id in range(self._num_replicas_in_sync): + select_replica = lambda x: values.select_replica(replica_id, x) # pylint: disable=cell-var-from-loop + replicate_inputs.append((nest.map_structure( + select_replica, per_replica_inputs),)) + replicate_outputs = tpu.replicate(run_fn, replicate_inputs) # If run_fn has tensor outputs, tpu.replicate returns a list of list. 
We @@ -314,8 +468,8 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): return replicate_outputs - # TODO(sourabhbajaj): The input to while loop should be based on the output - # type of the step_fn + # TODO(sourabhbajaj): The input to while loop should be based on the + # output type of the step_fn assert isinstance(initial_loop_values, list) initial_loop_values = initial_loop_values * self._num_replicas_in_sync @@ -325,7 +479,7 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): initial_loop_values) del self._outer_control_flow_context - ctx.run_op = control_flow_ops.group(replicate_outputs, enqueue_ops) + ctx.run_op = control_flow_ops.group(replicate_outputs) if isinstance(replicate_outputs, list): # Filter out any ops from the outputs, typically this would be the case @@ -350,23 +504,80 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): # no tensors returned. last_step_tensor_outputs = [] - # Convert replicate_outputs to the original dict structure of - # last_step_outputs. - last_step_tensor_outputs_dict = nest.pack_sequence_as( - ctx.last_step_outputs, last_step_tensor_outputs) - - for name, reduce_op in ctx._last_step_outputs_reduce_ops.items(): # pylint: disable=protected-access - output = last_step_tensor_outputs_dict[name] - # For outputs that have already been reduced, take the first value - # from the list as each value should be the same. Else return the full - # list of values. - # TODO(josh11b): If reduce_op is NONE, we should return a PerReplica - # value. 
- if reduce_op is not None: - # TODO(priyag): Should this return the element or a list with 1 element - last_step_tensor_outputs_dict[name] = output[0] - ctx._set_last_step_outputs(last_step_tensor_outputs_dict) # pylint: disable=protected-access + _set_last_step_outputs(ctx, last_step_tensor_outputs) + return ctx + + def _run_steps_on_iterator_with_device_loop( + self, fn, multi_worker_iterator, iterations, initial_loop_values=None): + output_shapes = multi_worker_iterator.output_shapes + shapes = nest.flatten(output_shapes) + if any(not s.is_fully_defined() for s in shapes): + raise ValueError( + "TPU currently requires fully defined shapes. Either use " + "set_shape() on the input tensors or use " + "dataset.batch(..., drop_remainder=True).") + types = nest.flatten(multi_worker_iterator.output_types) + + enqueue_ops = [ + self._get_enqueue_op_per_host(host_id, multi_worker_iterator, shapes, + iterations) + for host_id in range(self.num_hosts)] + + def dequeue_fn(): + dequeued = tpu_ops.infeed_dequeue_tuple(dtypes=types, shapes=shapes) + return nest.pack_sequence_as(output_shapes, dequeued) + + # Wrap `fn` for repeat. + if initial_loop_values is None: + initial_loop_values = {} + initial_loop_values = nest.flatten(initial_loop_values) + ctx = input_lib.MultiStepContext() + + def run_fn(*args, **kwargs): + """Single step on the TPU device.""" + del args, kwargs + fn_result = fn(ctx, dequeue_fn()) + flat_last_step_outputs = nest.flatten(ctx.last_step_outputs) + if flat_last_step_outputs: + with ops.control_dependencies([fn_result]): + return [array_ops.identity(f) for f in flat_last_step_outputs] + else: + return fn_result + + def iterate_on_tpu(): + return training_loop.repeat(iterations, run_fn, initial_loop_values) + + # We capture the control_flow_context at this point, before we run `fn` + # inside a while_loop and TPU replicate context. 
This is useful in cases + # where we might need to exit these contexts and get back to the outer + # context to do some things, for e.g. create an op which should be + # evaluated only once at the end of the loop on the host. One such usage + # is in creating metrics' value op. + self._outer_control_flow_context = ( + ops.get_default_graph()._get_control_flow_context()) # pylint: disable=protected-access + + replicate_inputs = [[]] * self._num_replicas_in_sync + replicate_outputs = tpu.replicate(iterate_on_tpu, replicate_inputs) + del self._outer_control_flow_context + ctx.run_op = control_flow_ops.group(replicate_outputs, enqueue_ops) + + # Filter out any ops from the outputs, typically this would be the case + # when there were no tensor outputs. + last_step_tensor_outputs = [x for x in replicate_outputs + if not isinstance(x, ops.Operation)] + + # Outputs are currently of the structure (grouped by device) + # [[output0_device0, output1_device0, output2_device0], + # [output0_device1, output1_device1, output2_device1]] + # Convert this to the following structure instead: (grouped by output) + # [[output0_device0, output0_device1], + # [output1_device0, output1_device1], + # [output2_device0, output2_device1]] + last_step_tensor_outputs = [list(x) for x in + zip(*last_step_tensor_outputs)] + + _set_last_step_outputs(ctx, last_step_tensor_outputs) return ctx def _call_for_each_replica(self, fn, args, kwargs): @@ -375,44 +586,34 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): with _TPUReplicaContext(self._container_strategy()): return fn(*args, **kwargs) - def _initialize(self): - if context.executing_eagerly(): - # TODO(priyag): Add appopriate call here when eager is supported for TPUs. - raise NotImplementedError("Eager mode not supported in TPUStrategy.") - else: - # TODO(jhseu): We need this hack because DistributionStrategies must be - # pickleable for copy.deepcopy(). Remove when initialize_system goes away. 
- graph = ops.get_default_graph() - tpu_init = graph.get_collection(_TPU_INITIALIZE_SYSTEM_COLLECTION) - if tpu_init: - return tpu_init - graph.add_to_collection(_TPU_INITIALIZE_SYSTEM_COLLECTION, - tpu.initialize_system()) - return graph.get_collection(_TPU_INITIALIZE_SYSTEM_COLLECTION) - - def _finalize(self): - if context.executing_eagerly(): - # TODO(priyag): Add appopriate call here when eager is supported for TPUs. - raise NotImplementedError("Eager mode not supported in TPUStrategy.") - else: - return [tpu.shutdown_system()] + def _experimental_initialize_system(self): + """Experimental method added to be used by Estimator. - def _get_devices_from(self, colocate_with=None): - # TODO(jhseu): Change this when we support model parallelism. - return self._tpu_devices + This is a private method only to be used by Estimator. Other frameworks + should directly be calling `tf.contrib.distribute.initialize_tpu_system` + """ + initialize_tpu_system(self._tpu_cluster_resolver) def _create_variable(self, next_creator, *args, **kwargs): """Create a TPUMirroredVariable. See `DistributionStrategy.scope`.""" colocate_with = kwargs.pop("colocate_with", None) - devices = self._get_devices_from(colocate_with) + if colocate_with is None: + device_map = self._device_map + logical_device = 0 # TODO(josh11b): Get logical device from scope here. 
+ elif isinstance(colocate_with, numpy_dataset.SingleDevice): + with ops.device(colocate_with.device): + return next_creator(*args, **kwargs) + else: + device_map = colocate_with.device_map + logical_device = colocate_with.logical_device def _real_mirrored_creator(devices, *args, **kwargs): # pylint: disable=g-missing-docstring - index = {} + value_list = [] for i, d in enumerate(devices): with ops.device(d): if i > 0: # Give replicas meaningful distinct names: - var0name = index[devices[0]].name.split(":")[0] + var0name = value_list[0].name.split(":")[0] # We append a / to variable names created on replicas with id > 0 to # ensure that we ignore the name scope and instead use the given # name as the absolute name of the variable. @@ -420,20 +621,21 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): # Initialize replicas with the same value: if context.executing_eagerly(): kwargs["initial_value"] = array_ops.identity( - index[devices[0]].value()) + value_list[0].value()) else: def initial_value_fn(device=d): with ops.device(device): - return array_ops.identity(index[devices[0]].initial_value) + return array_ops.identity(value_list[0].initial_value) kwargs["initial_value"] = initial_value_fn with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT): v = next_creator(*args, **kwargs) assert not isinstance(v, values.TPUMirroredVariable) - index[d] = v - return index + value_list.append(v) + return value_list - return _create_tpu_mirrored_variable(devices, _real_mirrored_creator, *args, - **kwargs) + return _create_tpu_mirrored_variable( + self._container_strategy(), device_map, logical_device, + _real_mirrored_creator, *args, **kwargs) def _reduce_to(self, reduce_op, value, destinations): if values._enclosing_tpu_context() is not None: # pylint: disable=protected-access @@ -445,6 +647,14 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): "Currently only support sum & mean in TPUStrategy.") return 
tpu_ops.cross_replica_sum(value) + if not isinstance(value, values.DistributedValues): + # This function handles reducing values that are not PerReplica or + # Mirrored values. For example, the same value could be present on all + # replicas in which case `value` would be a single value or value could + # be 0. + return cross_device_ops_lib.reduce_non_distributed_value( + reduce_op, self._device_map, value, destinations) + # Validate that the destination is same as the host device # Note we don't do this when in replicate context as the reduction is # performed on the TPU device itself. @@ -466,19 +676,19 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): if group: return fn(var, *args, **kwargs) else: - return [fn(var, *args, **kwargs)] + return (fn(var, *args, **kwargs),) # Otherwise, we revert to MirroredStrategy behavior and update each variable # directly. - updates = {} - for d, v in var._index.items(): # pylint: disable=protected-access - name = "update_%d" % self._device_index.get(d) + updates = [] + for i, (d, v) in enumerate(zip(var.devices, var.values)): + name = "update_%d" % i with ops.device(d), distribute_lib.UpdateContext(d), ops.name_scope(name): # If args and kwargs are not mirrored, the value is returned as is. - updates[d] = fn(v, - *values.select_device_mirrored(d, args), - **values.select_device_mirrored(d, kwargs)) - return values.update_regroup(self, updates, group) + updates.append(fn(v, + *values.select_device_mirrored(d, args), + **values.select_device_mirrored(d, kwargs))) + return values.update_regroup(self, self._device_map, updates, group) def read_var(self, var): assert isinstance(var, values.TPUMirroredVariable) @@ -487,13 +697,18 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): def _unwrap(self, val): if isinstance(val, values.DistributedValues): # Return in a deterministic order. 
- return [val.get(device=d) for d in sorted(val.devices)] + return tuple(val.get(device=d) for d in sorted(val.devices)) elif isinstance(val, list): # TODO(josh11b): We need to remove this case; per device values should # be represented using a PerReplica wrapper instead of a list with # one entry per device. - return val - return [val] + return tuple(val) + elif isinstance(val, values.TPUMirroredVariable): + # pylint: disable=protected-access + if values._enclosing_tpu_context() is not None: + return (val,) + return val.values + return (val,) def value_container(self, value): return value @@ -504,15 +719,34 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): @property def num_hosts(self): - return self._tpu_metadata.num_hosts + if self._device_assignment is None: + return self._tpu_metadata.num_hosts + + return len(set([self._device_assignment.host_device(r) + for r in range(self._device_assignment.num_replicas)])) @property def num_replicas_per_host(self): - return self._tpu_metadata.num_of_cores_per_host + if self._device_assignment is None: + return self._tpu_metadata.num_of_cores_per_host + + # TODO(sourabhbajaj): Remove this method we use inputs and remove infeed + # as the computation of num_replicas_per_host is not a constant + # when using device_assignment. This is a temporary workaround to support + # StatefulRNN as everything is 1 in that case. + # This method needs to take host_id as input for correct computation. 
+ max_models_per_host = (self._tpu_metadata.num_of_cores_per_host // + self._device_assignment.num_cores_per_replica) + models_per_host = min(self._device_assignment.num_replicas, + max_models_per_host) + return models_per_host * self._device_assignment.num_cores_per_replica @property def _num_replicas_in_sync(self): - return self._num_cores_override or self._tpu_metadata.num_cores + if self._device_assignment is None: + return self._tpu_metadata.num_cores + return (self._device_assignment.num_replicas * + self._device_assignment.num_cores_per_replica) @property def experimental_between_graph(self): @@ -580,23 +814,62 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): # TODO(priyag): Delete this once all strategies use global batch size. @property def _global_batch_size(self): + """`make_dataset_iterator` and `make_numpy_iterator` use global batch size. + + `make_input_fn_iterator` assumes per-replica batching. + + Returns: + Boolean. + """ return True class _TPUReplicaContext(distribute_lib.ReplicaContext): """Replication Context class for TPU Strategy.""" - # TODO(sourabhbajaj): Call for each tower should be updating this. - def __init__(self, distribution_strategy): + # TODO(sourabhbajaj): Call for each replica should be updating this. + # TODO(b/118385803): Always properly initialize replica_id. 
+ def __init__(self, strategy, replica_id_in_sync_group=None): + if replica_id_in_sync_group is None: + replica_id_in_sync_group = constant_op.constant(0, dtypes.int32) distribute_lib.ReplicaContext.__init__( - self, - distribution_strategy, - # TODO(b/118385803): properly initialize replica_id, instead of always 0 - replica_id_in_sync_group=constant_op.constant(0, dtypes.int32)) + self, strategy, replica_id_in_sync_group=replica_id_in_sync_group) @property def devices(self): distribute_lib.require_replica_context(self) - ds = self._distribution_strategy + ds = self._strategy replica_id = tensor_util.constant_value(self._replica_id_in_sync_group) - return [ds.extended.worker_devices[replica_id]] + + if replica_id is None: # Non-constant `Tensor` inside `tpu.replicate`. + # TODO(cjfj): Return other devices when model parallelism is supported. + return (tpu.core(0),) + else: + return (ds.extended.worker_devices[replica_id],) + + +def _get_host_for_device(device): + spec = tf_device.DeviceSpec.from_string(device) + return tf_device.DeviceSpec( + job=spec.job, replica=spec.replica, task=spec.task, + device_type="CPU", device_index=0).to_string() + + +def _set_last_step_outputs(ctx, last_step_tensor_outputs): + """Sets the last step outputs on the given context.""" + # Convert replicate_outputs to the original dict structure of + # last_step_outputs. + last_step_tensor_outputs_dict = nest.pack_sequence_as( + ctx.last_step_outputs, last_step_tensor_outputs) + + for name, reduce_op in ctx._last_step_outputs_reduce_ops.items(): # pylint: disable=protected-access + output = last_step_tensor_outputs_dict[name] + # For outputs that have already been reduced, take the first value + # from the list as each value should be the same. Else return the full + # list of values. + # TODO(josh11b): If reduce_op is NONE, we should return a PerReplica + # value. 
+ if reduce_op is not None: + # TODO(priyag): Should this return the element or a list with 1 element + last_step_tensor_outputs_dict[name] = output[0] + ctx._set_last_step_outputs(last_step_tensor_outputs_dict) # pylint: disable=protected-access diff --git a/tensorflow/contrib/distribute/python/values_test.py b/tensorflow/contrib/distribute/python/values_test.py index 538b859f3d1ece55b460f6dbf8f01540a6013381..51c58b0b2f3dc2ab63e22718825a471b8657f892 100644 --- a/tensorflow/contrib/distribute/python/values_test.py +++ b/tensorflow/contrib/distribute/python/values_test.py @@ -22,27 +22,20 @@ import os from absl.testing import parameterized from tensorflow.contrib.distribute.python import combinations -from tensorflow.contrib.distribute.python import multi_worker_test_base from tensorflow.core.protobuf import config_pb2 -from tensorflow.python.data.ops import dataset_ops from tensorflow.python.distribute import device_util -from tensorflow.python.distribute import distribute_lib from tensorflow.python.distribute import values from tensorflow.python.eager import context from tensorflow.python.eager import test from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.framework import constant_op -from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import random_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables as variables_lib from tensorflow.python.training import saver as saver_lib -from tensorflow.python.util import nest class DistributedValuesTest(test.TestCase): @@ -51,7 +44,8 @@ class DistributedValuesTest(test.TestCase): with ops.device("/device:CPU:0"): one = constant_op.constant(1) two = constant_op.constant(2) - v = 
values.DistributedValues({"/device:CPU:0": one, "/device:GPU:0": two}) + device_map = values.ReplicaDeviceMap(("/device:CPU:0", "/device:GPU:0")) + v = values.DistributedValues(device_map, (one, two)) self.assertEqual(two, v.get("/device:GPU:0")) self.assertEqual(one, v.get()) with self.assertRaises(ValueError): @@ -63,24 +57,26 @@ class DistributedValuesTest(test.TestCase): ops.device("/device:CPU:0"): one = constant_op.constant(1) two = constant_op.constant(2) - v = values.DistributedValues({"/device:CPU:0": one, "/device:GPU:0": two}) + device_map = values.ReplicaDeviceMap(("/device:CPU:0", "/device:GPU:0")) + v = values.DistributedValues(device_map, (one, two)) self.assertEqual(two, v.get("/device:GPU:0")) self.assertEqual(one, v.get()) with self.assertRaises(ValueError): self.assertIsNone(v.get("/device:GPU:2")) def testCanonicalization(self): - canonical_cpu = ["/job:localhost/replica:0/task:0/device:CPU:0"] - v = values.DistributedValues({"": 42}) - self.assertEqual(canonical_cpu, list(v._index.keys())) - v = values.DistributedValues({"/device:CPU:0": 42}) - self.assertEqual(canonical_cpu, list(v._index.keys())) - v = values.DistributedValues({"/cpu:0": 42}) - self.assertEqual(canonical_cpu, list(v._index.keys())) - v = values.DistributedValues({"/CPU:0": 42}) - self.assertEqual(canonical_cpu, list(v._index.keys())) + canonical_cpu = ("/job:localhost/replica:0/task:0/device:CPU:0",) + v = values.DistributedValues(values.SingleDeviceMap(""), (42,)) + self.assertEqual(canonical_cpu, v.devices) + v = values.DistributedValues(values.SingleDeviceMap("/device:CPU:0"), (42,)) + self.assertEqual(canonical_cpu, v.devices) + v = values.DistributedValues(values.SingleDeviceMap("/cpu:0"), (42,)) + self.assertEqual(canonical_cpu, v.devices) + v = values.DistributedValues(values.SingleDeviceMap("/CPU:0"), (42,)) + self.assertEqual(canonical_cpu, v.devices) with self.assertRaises(AssertionError): - v = values.DistributedValues({"/device:cpu:0": 42}) + v = 
values.DistributedValues( + values.SingleDeviceMap("/device:cpu:0"), (42,)) def testIsTensorLike(self): with context.graph_mode(), \ @@ -88,7 +84,8 @@ class DistributedValuesTest(test.TestCase): ops.device("/device:CPU:0"): one = constant_op.constant(1) two = constant_op.constant(2) - v = values.DistributedValues({"/device:CPU:0": one, "/device:GPU:0": two}) + device_map = values.ReplicaDeviceMap(("/device:CPU:0", "/device:GPU:0")) + v = values.DistributedValues(device_map, (one, two)) self.assertEqual(two, v.get("/device:GPU:0")) self.assertEqual(one, v.get()) self.assertTrue(v.is_tensor_like) @@ -100,7 +97,8 @@ class DistributedValuesTest(test.TestCase): ops.device("/device:CPU:0"): one = constant_op.constant(1) two = 2.0 - v = values.DistributedValues({"/device:CPU:0": one, "/device:GPU:0": two}) + device_map = values.ReplicaDeviceMap(("/device:CPU:0", "/device:GPU:0")) + v = values.DistributedValues(device_map, (one, two)) self.assertEqual(two, v.get("/device:GPU:0")) self.assertEqual(one, v.get()) self.assertFalse(v.is_tensor_like) @@ -118,8 +116,8 @@ class DistributedDelegateTest(test.TestCase): def __init__(self, x): self.x = x - v = values.DistributedDelegate( - {"/device:CPU:0": Foo(7), "/device:GPU:0": Foo(8)}) + device_map = values.ReplicaDeviceMap(("/device:CPU:0", "/device:GPU:0")) + v = values.DistributedDelegate(device_map, (Foo(7), Foo(8))) self.assertEqual(7, v.x) with self.assertRaises(AttributeError): _ = v.y @@ -127,7 +125,8 @@ class DistributedDelegateTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def testOperatorOverride(self): with ops.device("/device:CPU:0"): - v = values.DistributedDelegate({"/device:CPU:0": 7, "/device:GPU:0": 8}) + device_map = values.ReplicaDeviceMap(("/device:CPU:0", "/device:GPU:0")) + v = values.DistributedDelegate(device_map, (7, 8)) # v should act like int(7). 
self.assertEqual(8, v + 1) self.assertEqual(10, 3 + v) @@ -178,16 +177,15 @@ def _nested_value(d): def _make_mirrored(): v = [] - index = {} devices = ["/device:GPU:0", "/device:CPU:0"] for d, n, init in zip(devices, ["v", "v/replica"], [1., 2.]): with ops.device(d): v.append(variable_scope.get_variable( name=n, initializer=init, use_resource=True)) - index[d] = v[-1] - mirrored = values.MirroredVariable(index, v[0], + device_map = values.ReplicaDeviceMap(devices) + mirrored = values.MirroredVariable(None, device_map, v, variable_scope.VariableAggregation.SUM) - return v, devices, mirrored + return v, device_map, mirrored class RegroupAndSelectDeviceTest(test.TestCase): @@ -204,8 +202,9 @@ class RegroupAndSelectDeviceTest(test.TestCase): self.assertEqual(expected[i], result.get(_device_str(i))) def testNested(self): - result = values.regroup({_device_str(0): _nested_value("1"), - _device_str(1): _nested_value("2")}) + device_map = values.ReplicaDeviceMap((_device_str(0), _device_str(1))) + result = values.regroup(device_map, + (_nested_value("1"), _nested_value("2"))) self.assertIsInstance(result, tuple) self.assertEqual(3, len(result)) self._is_per_replica(result[0], ["a1", "a2"]) @@ -221,11 +220,11 @@ class RegroupAndSelectDeviceTest(test.TestCase): self._is_per_replica(result[1][1]["c"], ["d1", "d2"]) self._is_per_replica(result[1][1]["e"], ["f1", "f2"]) - # Also test that we can undo the merge using select_device() + # Also test that we can undo the merge using select_replica() self.assertEqual(_nested_value("1"), - values.select_device(_device_str(0), result)) + values.select_replica(0, result)) self.assertEqual(_nested_value("2"), - values.select_device(_device_str(1), result)) + values.select_replica(1, result)) # select_device_mirrored() should fail due to non-mirrored values with self.assertRaises(TypeError): values.select_device_mirrored(_device_str(0), result) @@ -235,8 +234,9 @@ class RegroupAndSelectDeviceTest(test.TestCase): def testWrapClass(self): # 
Normally a mirrored value would be the same across devices, but # for a test it is convenient to be able to tell the values apart. - result = values.regroup({_device_str(0): _nested_value("1"), - _device_str(1): _nested_value("2")}, + device_map = values.ReplicaDeviceMap((_device_str(0), _device_str(1))) + result = values.regroup(device_map, + (_nested_value("1"), _nested_value("2")), values.Mirrored) self.assertIsInstance(result, tuple) self.assertEqual(3, len(result)) @@ -253,11 +253,11 @@ class RegroupAndSelectDeviceTest(test.TestCase): self._is_per_replica(result[1][1]["c"], ["d1", "d2"], values.Mirrored) self._is_per_replica(result[1][1]["e"], ["f1", "f2"], values.Mirrored) - # Also test that we can undo the merge using select_device() + # Also test that we can undo the merge using select_replica() self.assertEqual(_nested_value("1"), - values.select_device(_device_str(0), result)) + values.select_replica(0, result)) self.assertEqual(_nested_value("2"), - values.select_device(_device_str(1), result)) + values.select_replica(1, result)) # Values are marked as mirrored, so select_device_mirrored() is allowed. 
self.assertEqual(_nested_value("1"), values.select_device_mirrored(_device_str(0), result)) @@ -267,63 +267,66 @@ class RegroupAndSelectDeviceTest(test.TestCase): def testMirroredContainer(self): if context.num_gpus() < 1 and context.executing_eagerly(): self.skipTest("A GPU is not available for this test in eager mode.") - v, devices, mirrored = _make_mirrored() - result = values.regroup(dict(zip(devices, v))) + v, device_map, mirrored = _make_mirrored() + result = values.regroup(device_map, v) self.assertIs(mirrored, result) def testSameId(self): foo = object() - result = values.regroup({_device_str(0): ("a", foo), - _device_str(1): ("b", foo)}) + device_map = values.ReplicaDeviceMap((_device_str(0), _device_str(1))) + result = values.regroup(device_map, (("a", foo), ("b", foo))) self.assertIsInstance(result, tuple) self.assertEqual(2, len(result)) self._is_per_replica(result[0], ["a", "b"]) self.assertIs(foo, result[1]) - # Test select_device(), should undo the merge done by regroup(). - result_0 = values.select_device(_device_str(0), result) + # Test select_replica(), should undo the merge done by regroup(). + result_0 = values.select_replica(0, result) self.assertIsInstance(result_0, tuple) self.assertEqual(2, len(result_0)) self.assertEqual("a", result_0[0]) self.assertIs(foo, result_0[1]) - result_1 = values.select_device(_device_str(1), result) + result_1 = values.select_replica(1, result) self.assertIsInstance(result_1, tuple) self.assertEqual(2, len(result_1)) self.assertEqual("b", result_1[0]) self.assertIs(foo, result_1[1]) def testOneDevice(self): - result = values.regroup({_device_str(0): _nested_value("1")}) - # On one device regroup() and select_device() are basically identity. + device_map = values.ReplicaDeviceMap((_device_str(0),)) + result = values.regroup(device_map, (_nested_value("1"),)) + # On one device regroup() and select_replica() are basically identity. 
self.assertEqual(_nested_value("1"), result) self.assertEqual(_nested_value("1"), - values.select_device(_device_str(0), result)) + values.select_replica(0, result)) # The one exception has to do with MirroredVariables. d = "/device:CPU:0" with ops.device(d): v = variable_scope.get_variable( name="v", initializer=1., use_resource=True) - index = {d: v} - mirrored = values.MirroredVariable(index, v, + device_map = values.ReplicaDeviceMap((d,)) + mirrored = values.MirroredVariable(None, device_map, (v,), variable_scope.VariableAggregation.SUM) - result = values.regroup(index) + result = values.regroup(device_map, (v,)) self.assertIs(mirrored, result) def testNamedTupleEstimatorSpec(self): with context.graph_mode(), ops.Graph().as_default(): - created_estimator_specs = {} - to_regroup = {} + devices = [] + created_estimator_specs = [] for device_id in range(3): spec = model_fn_lib.EstimatorSpec( mode=model_fn_lib.ModeKeys.TRAIN, loss=constant_op.constant(device_id / 2), train_op=array_ops.identity(constant_op.constant(device_id))) - created_estimator_specs[device_id] = spec - to_regroup[_device_str(device_id)] = spec + devices.append(_device_str(device_id)) + created_estimator_specs.append(spec) - merged_estimator_spec = values.regroup(to_regroup) + device_map = values.ReplicaDeviceMap(devices) + merged_estimator_spec = values.regroup( + device_map, created_estimator_specs) self.assertTrue( isinstance(merged_estimator_spec, model_fn_lib.EstimatorSpec)) @@ -337,415 +340,10 @@ class RegroupAndSelectDeviceTest(test.TestCase): # Scaffold is populated by `EstimatorSpec.__new__`. 
self.assertEqual(created_estimator_specs[device_id].scaffold, merged_estimator_spec.scaffold.get(d)) - # Also test that we can undo the merge using select_device() + # Also test that we can undo the merge using select_replica() self.assertEqual(created_estimator_specs[device_id], - values.select_device(_device_str(device_id), - merged_estimator_spec)) - - -class PerReplicaDatasetTest(test.TestCase): - - config = config_pb2.ConfigProto() - config.allow_soft_placement = True - - def _test_iterator(self, devices, dataset, expected_values): - per_replica_dataset = values.PerReplicaDataset(dataset, devices) - if context.executing_eagerly(): - iterator = per_replica_dataset.make_one_shot_iterator() - else: - iterator = per_replica_dataset.make_initializable_iterator() - self.evaluate([iterator.initializer]) - - for expected_value in expected_values: - next_element = iterator.get_next() - computed_value = self.evaluate( - [values.select_device(d, next_element) for d in devices]) - self.assertEqual(expected_value, computed_value) - - with self.assertRaises(errors.OutOfRangeError): - next_element = iterator.get_next() - self.evaluate([ - values.select_device(d, next_element) for d in devices]) - - @test_util.run_in_graph_and_eager_modes - def testOneDevice(self): - devices = ["/device:CPU:0"] - dataset = dataset_ops.Dataset.range(10) - - expected_values = [[i] for i in range(10)] - - self._test_iterator(devices, dataset, expected_values) - - @test_util.run_in_graph_and_eager_modes(config=config) - def testMultipleDevices(self): - if context.num_gpus() < 1 and context.executing_eagerly(): - self.skipTest("A GPU is not available for this test in eager mode.") - - devices = ["/device:CPU:0", "/device:GPU:0"] - dataset = dataset_ops.Dataset.range(10) - - expected_values = [[i, i+1] for i in range(0, 10, 2)] - - self._test_iterator(devices, dataset, expected_values) - - @test_util.run_in_graph_and_eager_modes(config=config) - def testTupleDataset(self): - if context.num_gpus() < 
1 and context.executing_eagerly(): - self.skipTest("A GPU is not available for this test in eager mode.") - - devices = ["/device:CPU:0", "/device:GPU:0"] - dataset1 = dataset_ops.Dataset.range(10) - dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2) - dataset = dataset_ops.Dataset.zip((dataset1, dataset2)) - - expected_values = [[(i, i**2), (i+1, (i+1)**2)] for i in range(0, 10, 2)] - - self._test_iterator(devices, dataset, expected_values) - - @test_util.run_in_graph_and_eager_modes(config=config) - def testUnevenDatasetBatches(self): - if context.num_gpus() < 1 and context.executing_eagerly(): - self.skipTest("A GPU is not available for this test in eager mode.") - - devices = ["/device:CPU:0", "/device:GPU:0"] - dataset = dataset_ops.Dataset.range(11) - - expected_values = [[i, i+1] for i in range(0, 10, 2)] - self._test_iterator(devices, dataset, expected_values) - - def testInitializableIterator(self): - with context.graph_mode(): - devices = ["/device:CPU:0"] - # Using random input since that is only allowed with initializable - # iterator. - dataset = dataset_ops.Dataset.from_tensor_slices( - random_ops.random_uniform((10,))) - - per_replica_dataset = values.PerReplicaDataset(dataset, devices) - iterator = per_replica_dataset.make_initializable_iterator() - - self.evaluate(iterator.initializer) - next_element = iterator.get_next() - for _ in range(10): - self.evaluate(next_element) - - # Should fail after the input is finished. - with self.assertRaises(errors.OutOfRangeError): - self.evaluate(next_element) - - # After re-initializing the iterator, should be able to iterate again. 
- self.evaluate(iterator.initializer) - for _ in range(10): - self.evaluate(next_element) - - -class MultiWorkerDatasetTest(multi_worker_test_base.MultiWorkerTestBase): - - def _test_iterator(self, sess, iterator, devices, expected_values): - next_element = iterator.get_next() - for device in devices: - v = values.select_device(device, next_element) - # The `v` here can be a tuple. - for element in nest.flatten(v): - self.assertTrue(element.device in device) - - for expected_value in expected_values: - actual = sess.run( - [values.select_device(d, next_element) for d in devices]) - self.assertEqual(expected_value, actual) - - with self.assertRaises(errors.OutOfRangeError): - sess.run([values.select_device(d, next_element) for d in devices]) - - def _test_dataset(self, dataset_fn, worker_devices, devices, - expected_values, auto_shard=True): - multi_worker_dataset = values.MultiWorkerDataset( - dataset_fn, worker_devices, auto_shard=auto_shard) - multi_worker_iterator = multi_worker_dataset.make_initializable_iterator() - with self.cached_session() as sess: - sess.run(multi_worker_iterator.initializer) - self._test_iterator(sess, multi_worker_iterator, devices, expected_values) - - def _cpu_devices(self): - worker_devices = [ - ("/job:worker/replica:0/task:0", - ["/job:worker/replica:0/task:0/device:CPU:0"]), - ("/job:worker/replica:0/task:1", - ["/job:worker/replica:0/task:1/device:CPU:0"])] - devices = [ - "/job:worker/replica:0/task:0/device:CPU:0", - "/job:worker/replica:0/task:1/device:CPU:0" - ] - return worker_devices, devices - - def _cpu_and_one_gpu_devices(self): - worker_devices = [ - ("/job:worker/replica:0/task:0", [ - "/job:worker/replica:0/task:0/device:GPU:0", - "/job:worker/replica:0/task:0/device:CPU:0" - ]), - ("/job:worker/replica:0/task:1", [ - "/job:worker/replica:0/task:1/device:GPU:0", - "/job:worker/replica:0/task:1/device:CPU:0" - ]) - ] - devices = [ - "/job:worker/replica:0/task:0/device:GPU:0", - 
"/job:worker/replica:0/task:0/device:CPU:0", - "/job:worker/replica:0/task:1/device:GPU:0", - "/job:worker/replica:0/task:1/device:CPU:0" - ] - return worker_devices, devices - - def testDataDistributionOneDevicePerWorker(self): - worker_devices, devices = self._cpu_devices() - with context.graph_mode(): - dataset_fn = lambda: dataset_ops.Dataset.range(8) - self._test_dataset(dataset_fn, worker_devices, devices, - [[0, 1], [2, 3], [4, 5], [6, 7]]) - - def testDataDistributionNoAutoShard(self): - worker_devices, devices = self._cpu_devices() - with context.graph_mode(): - dataset_fn = lambda: dataset_ops.Dataset.range(4) - self._test_dataset(dataset_fn, worker_devices, devices, - [[0, 0], [1, 1], [2, 2], [3, 3]], - auto_shard=False) - - def testDataDistributionTwoDevicePerWorker(self): - if context.num_gpus() < 1: - self.skipTest("A GPU is not available for this test.") - worker_devices, devices = self._cpu_and_one_gpu_devices() - with context.graph_mode(): - dataset_fn = lambda: dataset_ops.Dataset.range(8) - self._test_dataset(dataset_fn, worker_devices, devices, - [[0, 2, 1, 3], [4, 6, 5, 7]]) - - def testTupleDataset(self): - worker_devices, devices = self._cpu_devices() - - with context.graph_mode(): - - def dataset_fn(): - dataset1 = dataset_ops.Dataset.range(8) - dataset2 = dataset_ops.Dataset.range(8).map(lambda x: x**2) - return dataset_ops.Dataset.zip((dataset1, dataset2)) - - expected_values = [ - [(i, i**2), (i + 1, (i + 1)**2)] for i in range(0, 8, 2) - ] - self._test_dataset(dataset_fn, worker_devices, devices, - expected_values) - - def testInitializableIterator(self): - worker_devices, devices = self._cpu_devices() - with context.graph_mode(), self.cached_session() as sess: - dataset_fn = lambda: dataset_ops.Dataset.range(8) - multi_worker_dataset = values.MultiWorkerDataset( - dataset_fn, worker_devices, auto_shard=True) - multi_worker_iterator = multi_worker_dataset.make_initializable_iterator() - - sess.run(multi_worker_iterator.initializer) - 
self._test_iterator(sess, multi_worker_iterator, devices, - [[0, 1], [2, 3], [4, 5], [6, 7]]) - - # After re-initializing the iterator, should be able to iterate again. - sess.run(multi_worker_iterator.initializer) - self._test_iterator(sess, multi_worker_iterator, devices, - [[0, 1], [2, 3], [4, 5], [6, 7]]) - - def testValueErrorForIterator(self): - # Incompatiable arguments. - with self.assertRaises(ValueError): - values.MultiWorkerDataIterator({"w1": None}, {"w1": "d1", "w2": "d2"}) - - # Test duplicated devices under same worker. - worker_devices, _ = self._cpu_devices() - worker_devices[0][1].append("/job:worker/replica:0/task:0/device:CPU:0") - with context.graph_mode(): - dataset_fn = lambda: dataset_ops.Dataset.range(8) - multi_worker_dataset = values.MultiWorkerDataset( - dataset_fn, worker_devices, auto_shard=True) - multi_worker_iterator = multi_worker_dataset.make_initializable_iterator() - with self.assertRaises(ValueError): - multi_worker_iterator.get_next() - - -class InputIteratorTestBase(test.TestCase): - - def _test_iterator(self, input_type, dataset_fn, worker_device_pairs, - expected_values, sess=None, split_batch_by=None): - devices = nest.flatten([ds for _, ds in worker_device_pairs]) - - if input_type == "input_fn": - input_contexts = [ - distribute_lib.InputContext() for _ in worker_device_pairs] - input_fn = lambda _: dataset_fn() - iterator = values.InputFunctionIterator(input_fn, worker_device_pairs, - input_contexts) - else: - iterator = values.DatasetIterator(dataset_fn(), worker_device_pairs, - split_batch_by) - - evaluate = lambda x: sess.run(x) if sess else self.evaluate(x) - - evaluate(control_flow_ops.group(iterator.initialize())) - - for expected_value in expected_values: - next_element = iterator.get_next() - computed_value = evaluate( - [values.select_device(d, next_element) for d in devices]) - self.assertAllEqual(expected_value, computed_value) - - with self.assertRaises(errors.OutOfRangeError): - next_element = 
iterator.get_next() - evaluate([values.select_device(d, next_element) for d in devices]) - - # After re-initializing the iterator, should be able to iterate again. - evaluate(control_flow_ops.group(iterator.initialize())) - - for expected_value in expected_values: - next_element = iterator.get_next() - computed_value = evaluate( - [values.select_device(d, next_element) for d in devices]) - self.assertAllEqual(expected_value, computed_value) - - -class InputIteratorSingleWorkerTest(InputIteratorTestBase, - parameterized.TestCase): - - @combinations.generate(combinations.combine( - mode=["graph", "eager"], - input_type=["input_fn", "dataset"])) - def testOneDeviceCPU(self, input_type): - worker_device_pairs = [("", ["/device:CPU:0"])] - dataset_fn = lambda: dataset_ops.Dataset.range(10) - - expected_values = [[i] for i in range(10)] - - self._test_iterator(input_type, dataset_fn, worker_device_pairs, - expected_values) - - @combinations.generate(combinations.combine( - mode=["graph", "eager"], - input_type=["input_fn", "dataset"], - required_gpus=1)) - def testTwoDevicesOneGPUOneCPU(self, input_type): - worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] - dataset_fn = lambda: dataset_ops.Dataset.range(10) - - expected_values = [[i, i+1] for i in range(0, 10, 2)] - - self._test_iterator(input_type, dataset_fn, worker_device_pairs, - expected_values) - - @combinations.generate(combinations.combine( - mode=["graph", "eager"], - input_type=["input_fn", "dataset"], - required_gpus=1)) - def testTupleDataset(self, input_type): - worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] - def dataset_fn(): - dataset1 = dataset_ops.Dataset.range(10) - dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2) - return dataset_ops.Dataset.zip((dataset1, dataset2)) - - expected_values = [[(i, i**2), (i+1, (i+1)**2)] for i in range(0, 10, 2)] - - self._test_iterator(input_type, dataset_fn, worker_device_pairs, - expected_values) - - 
@combinations.generate(combinations.combine( - mode=["graph", "eager"], - input_type=["input_fn", "dataset"], - required_gpus=1)) - def testUnevenDatasetBatches(self, input_type): - worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] - dataset_fn = lambda: dataset_ops.Dataset.range(11) - - expected_values = [[i, i+1] for i in range(0, 10, 2)] - self._test_iterator(input_type, dataset_fn, worker_device_pairs, - expected_values) - - @combinations.generate(combinations.combine( - mode=["graph", "eager"], - input_type=["dataset"], - split_batch_by=[None, 2], - required_gpus=1)) - def testBatchSplitting(self, input_type, split_batch_by): - worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] - batch_size = 10 - dataset_fn = lambda: dataset_ops.Dataset.range(100).batch(batch_size) - - updated_batch_size = ( - batch_size // split_batch_by if split_batch_by else batch_size) - expected_values = [[range(i, i+updated_batch_size), - range(i+updated_batch_size, i+2*updated_batch_size)] - for i in range(0, 100, updated_batch_size*2)] - - self._test_iterator(input_type, dataset_fn, worker_device_pairs, - expected_values, sess=None, - split_batch_by=split_batch_by) - - -class InputIteratorMultiWorkerTest( - multi_worker_test_base.MultiWorkerTestBase, InputIteratorTestBase, - parameterized.TestCase): - - def _cpu_devices(self): - return [ - ("/job:worker/replica:0/task:0", - ["/job:worker/replica:0/task:0/device:CPU:0"]), - ("/job:worker/replica:0/task:1", - ["/job:worker/replica:0/task:1/device:CPU:0"])] - - def _cpu_and_one_gpu_devices(self): - return [ - ("/job:worker/replica:0/task:0", [ - "/job:worker/replica:0/task:0/device:GPU:0", - "/job:worker/replica:0/task:0/device:CPU:0" - ]), - ("/job:worker/replica:0/task:1", [ - "/job:worker/replica:0/task:1/device:GPU:0", - "/job:worker/replica:0/task:1/device:CPU:0" - ]) - ] - - @combinations.generate(combinations.combine( - mode=["graph"], - input_type=["input_fn", "dataset"])) - def 
testOneDevicePerWorker(self, input_type): - worker_devices = self._cpu_devices() - with context.graph_mode(), self.cached_session() as sess: - dataset_fn = lambda: dataset_ops.Dataset.range(4) - self._test_iterator(input_type, dataset_fn, worker_devices, - [[0, 0], [1, 1], [2, 2], [3, 3]], sess) - - @combinations.generate(combinations.combine( - mode=["graph"], - input_type=["input_fn", "dataset"], - required_gpus=1)) - def testTwoDevicesPerWorker(self, input_type): - worker_devices = self._cpu_and_one_gpu_devices() - with context.graph_mode(), self.cached_session() as sess: - dataset_fn = lambda: dataset_ops.Dataset.range(4) - self._test_iterator(input_type, dataset_fn, worker_devices, - [[0, 1, 0, 1], [2, 3, 2, 3]], sess) - - @combinations.generate(combinations.combine( - mode=["graph"], - input_type=["input_fn", "dataset"])) - def testTupleDataset(self, input_type): - worker_devices = self._cpu_devices() - with context.graph_mode(), self.cached_session() as sess: - def dataset_fn(): - dataset1 = dataset_ops.Dataset.range(4) - dataset2 = dataset_ops.Dataset.range(4).map(lambda x: x**2) - return dataset_ops.Dataset.zip((dataset1, dataset2)) - - expected_values = [[(i, i**2), (i, i**2)] for i in range(0, 4)] - self._test_iterator(input_type, dataset_fn, worker_devices, - expected_values, sess) + values.select_replica(device_id, + merged_estimator_spec)) class MirroredVariableTest(test.TestCase, parameterized.TestCase): @@ -768,8 +366,8 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase): def testVariableOnAnotherDevice(self): v = variable_scope.get_variable( name="v", initializer=[1.], use_resource=True) - index = {"/job:foo/device:CPU:0": v} - mirrored = values.MirroredVariable(index, v, + device_map = values.ReplicaDeviceMap(("/job:foo/device:CPU:0",)) + mirrored = values.MirroredVariable(None, device_map, (v,), variable_scope.VariableAggregation.MEAN) self.assertEqual(v.name, mirrored.name) @@ -797,7 +395,8 @@ class 
MirroredVariableTest(test.TestCase, parameterized.TestCase): self.skipTest("A GPU is not available for this test in eager mode.") with self.cached_session(config=self.config) as sess: - v, devices, mirrored = _make_mirrored() + v, device_map, mirrored = _make_mirrored() + devices = device_map.all_devices # Overwrite the initial values. self._assign_mirrored(devices, v, [3., 4.]) @@ -815,7 +414,8 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase): def _save_mirrored(self): """Save variables with mirroring, returns save_path.""" with self.session(graph=ops.Graph()) as sess: - v, devices, mirrored = _make_mirrored() + v, device_map, mirrored = _make_mirrored() + devices = device_map.all_devices # Overwrite the initial values. self._assign_mirrored(devices, v, [3., 4.]) @@ -860,7 +460,8 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase): def _restore_mirrored(self, save_path): """Restore to variables with mirroring in a fresh graph.""" with self.session(graph=ops.Graph()) as sess: - v, devices, mirrored = _make_mirrored() + v, device_map, mirrored = _make_mirrored() + devices = device_map.all_devices # Overwrite the initial values. 
self._assign_mirrored(devices, v, [7., 8.]) @@ -904,25 +505,24 @@ class MirroredVariableTest(test.TestCase, parameterized.TestCase): with ops.device("/device:GPU:0"): v = variable_scope.get_variable( name="v", initializer=1., use_resource=True) - mirrored = values.MirroredVariable({ - "/device:GPU:0": v - }, v, variable_scope.VariableAggregation.MEAN) + mirrored = values.MirroredVariable( + distribution, values.ReplicaDeviceMap(("/device:GPU:0",)), (v,), + variable_scope.VariableAggregation.MEAN) sess.run(variables_lib.global_variables_initializer()) sess.run({"complicated": mirrored}) -_devices = ["/device:GPU:0", "/device:CPU:0"] +_devices = ("/device:GPU:0", "/device:CPU:0") -def _make_replica_local(method): +def _make_replica_local(method, strategy=None): + device_map = values.ReplicaDeviceMap(_devices) v = [] - index = {} for d, n, init in zip(_devices, ["v", "v/replica"], [1., 2.]): with ops.device(d): v.append(variable_scope.get_variable( name=n, initializer=init, use_resource=True)) - index[d] = v[-1] - replica_local = values.ReplicaLocalVariable(index, v[0], method) + replica_local = values.ReplicaLocalVariable(strategy, device_map, v, method) return v, replica_local @@ -948,9 +548,9 @@ class ReplicaLocalVariablePropertiesTest(test.TestCase): def testVariableOnAnotherDevice(self): v = variable_scope.get_variable( name="v", initializer=[1.], use_resource=True) - index = {"/job:foo/device:CPU:0": v} + device_map = values.ReplicaDeviceMap(("/job:foo/device:CPU:0",)) replica_local = values.ReplicaLocalVariable( - index, v, variable_scope.VariableAggregation.MEAN) + None, device_map, (v,), variable_scope.VariableAggregation.MEAN) self.assertEqual(v.name, replica_local.name) self.assertEqual(v.dtype, replica_local.dtype) @@ -997,7 +597,7 @@ class ReplicaLocalVariableTest(test.TestCase, parameterized.TestCase): def testSaveAndRestoreReplicaLocalSumOneGraph(self, distribution): with self.cached_session() as sess: v, replica_local = _make_replica_local( - 
variable_scope.VariableAggregation.SUM) + variable_scope.VariableAggregation.SUM, distribution) # Overwrite the initial values. self._assign_replica_local(_devices, v, [3., 4.]) @@ -1020,7 +620,7 @@ class ReplicaLocalVariableTest(test.TestCase, parameterized.TestCase): with self.cached_session() as sess: v, replica_local = _make_replica_local( - variable_scope.VariableAggregation.MEAN) + variable_scope.VariableAggregation.MEAN, distribution) # Overwrite the initial values. self._assign_replica_local(_devices, v, [3., 4.]) @@ -1040,7 +640,7 @@ class ReplicaLocalVariableTest(test.TestCase, parameterized.TestCase): """Save variables with mirroring, returns save_path.""" with self.session(graph=ops.Graph()) as sess: v, replica_local = _make_replica_local( - variable_scope.VariableAggregation.MEAN) + variable_scope.VariableAggregation.MEAN, distribution) # Overwrite the initial values. self._assign_replica_local(_devices, v, [3., 4.]) @@ -1056,7 +656,7 @@ class ReplicaLocalVariableTest(test.TestCase, parameterized.TestCase): def _save_replica_local_sum(self, distribution): """Save variables with mirroring, returns save_path.""" with self.session(graph=ops.Graph()) as sess: - v, replica_local = _make_replica_local("sum") + v, replica_local = _make_replica_local("sum", distribution) # Overwrite the initial values. self._assign_replica_local(_devices, v, [1.5, 2.]) @@ -1103,7 +703,7 @@ class ReplicaLocalVariableTest(test.TestCase, parameterized.TestCase): """Restore to variables with mirroring in a fresh graph.""" with self.session(graph=ops.Graph()) as sess: v, replica_local = _make_replica_local( - variable_scope.VariableAggregation.MEAN) + variable_scope.VariableAggregation.MEAN, distribution) # Overwrite the initial values. 
self._assign_replica_local(_devices, v, [7., 8.]) @@ -1118,7 +718,7 @@ class ReplicaLocalVariableTest(test.TestCase, parameterized.TestCase): """Restore to variables with mirroring in a fresh graph.""" with self.session(graph=ops.Graph()) as sess: v, replica_local = _make_replica_local( - variable_scope.VariableAggregation.SUM) + variable_scope.VariableAggregation.SUM, distribution) # Overwrite the initial values. self._assign_replica_local(_devices, v, [7., 8.]) diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD index 3079175015a9aee1625404902070df8f13b2089c..c2300286d3be4bb757dac588623c47044a1a9db5 100644 --- a/tensorflow/contrib/distributions/BUILD +++ b/tensorflow/contrib/distributions/BUILD @@ -822,7 +822,7 @@ cuda_py_test( cuda_py_test( name = "affine_test", - size = "large", + size = "medium", srcs = ["python/kernel_tests/bijectors/affine_test.py"], additional_deps = [ ":bijectors_py", @@ -837,7 +837,7 @@ cuda_py_test( "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], - shard_count = 5, + shard_count = 10, tags = ["noasan"], # times out b/63678675 ) diff --git a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py index 452628257ea96713453bf2aa32b5baa9d6d0cb86..1006dfac49f36baa7cf5136f6f2982e3fd965298 100644 --- a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py +++ b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py @@ -249,9 +249,9 @@ class InverseGamma(distribution.Distribution): `self.allow_nan_stats` is `False`, an exception will be raised rather than returning `NaN`.""") def _variance(self): - var = (math_ops.square(self.rate) - / math_ops.square(self.concentration - 1.) - / (self.concentration - 2.)) + var = ( + math_ops.square(self.rate) / math_ops.squared_difference( + self.concentration, 1.) 
/ (self.concentration - 2.)) if self.allow_nan_stats: nan = array_ops.fill( self.batch_shape_tensor(), diff --git a/tensorflow/contrib/distributions/python/ops/sample_stats.py b/tensorflow/contrib/distributions/python/ops/sample_stats.py index 978e627d6638ddeea9df288d389354f0ac53d115..19e99e03803e7f4cdfdb023feb04daaba68eceed 100644 --- a/tensorflow/contrib/distributions/python/ops/sample_stats.py +++ b/tensorflow/contrib/distributions/python/ops/sample_stats.py @@ -300,7 +300,7 @@ def percentile(x, raise ValueError("Argument 'interpolation' must be in %s. Found %s" % (allowed_interpolations, interpolation)) - with ops.name_scope(name, [x, q]): + with ops.name_scope(name, values=[x, q]): x = ops.convert_to_tensor(x, name="x") # Double is needed here and below, else we get the wrong index if the array # is huge along axis. diff --git a/tensorflow/contrib/eager/python/BUILD b/tensorflow/contrib/eager/python/BUILD index 77052a75a70bec1162feb2b126d247924b3a2e36..8966a9befcd3db4a3f397b319e80f37f84ad236b 100644 --- a/tensorflow/contrib/eager/python/BUILD +++ b/tensorflow/contrib/eager/python/BUILD @@ -15,7 +15,6 @@ py_library( ":metrics", ":network", ":parameter_server", - ":remote", ":saver", "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", @@ -31,6 +30,7 @@ py_library( "//tensorflow/python/eager:def_function", "//tensorflow/python/eager:execution_callbacks", "//tensorflow/python/eager:function", + "//tensorflow/python/eager:remote", ], ) @@ -238,24 +238,12 @@ py_test( ], ) -py_library( - name = "remote", - srcs = ["remote.py"], - srcs_version = "PY2AND3", - visibility = ["//tensorflow:internal"], - deps = [ - "//tensorflow/core:protos_all_py", - "//tensorflow/python:platform", - "//tensorflow/python/eager:context", - ], -) - cuda_py_test( name = "remote_test", srcs = ["remote_test.py"], additional_deps = [ ":parameter_server", - ":remote", + "//tensorflow/python/eager:remote", "//tensorflow/contrib/eager/python:tfe", 
"//tensorflow/python:array_ops", "//tensorflow/python:client", diff --git a/tensorflow/contrib/eager/python/datasets_test.py b/tensorflow/contrib/eager/python/datasets_test.py index 257d02057ae0d280074559aa9e97725bf5cc3fd0..78ab155896cfeda4dd259a8529f4b1f77a12cf0b 100644 --- a/tensorflow/contrib/eager/python/datasets_test.py +++ b/tensorflow/contrib/eager/python/datasets_test.py @@ -200,13 +200,6 @@ class IteratorTest(test.TestCase): y = math_ops.add(x, x) self.assertAllEqual([0., 2.], y.numpy()) - def testGpuDefinedDataset(self): - with ops.device(test.gpu_device_name()): - ds = Dataset.from_tensors([0., 1.]) - for x in ds: - y = math_ops.add(x, x) - self.assertAllEqual([0., 2.], y.numpy()) - def testOverrideThreadPool(self): def get_thread_id(_): diff --git a/tensorflow/contrib/eager/python/evaluator.py b/tensorflow/contrib/eager/python/evaluator.py index 7949a3f6da293abdd85512209242bae76ab4d816..51443d24829bdc31a41813e0ff50ad7102422112 100644 --- a/tensorflow/contrib/eager/python/evaluator.py +++ b/tensorflow/contrib/eager/python/evaluator.py @@ -22,6 +22,7 @@ import six from tensorflow.contrib.eager.python import datasets from tensorflow.contrib.eager.python import metrics +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import context from tensorflow.python.eager import function from tensorflow.python.framework import errors_impl @@ -164,8 +165,8 @@ class Evaluator(object): self.__call__(example, *args, **kwargs) return self.all_metric_results(summary_logdir) # Graph construction - call_op = self.__call__(dataset.make_one_shot_iterator().get_next(), *args, - **kwargs) + call_op = self.__call__( + dataset_ops.make_one_shot_iterator(dataset).get_next(), *args, **kwargs) init_op = self.init_variables() results_op = self.all_metric_results(summary_logdir) return (init_op, call_op, results_op) diff --git a/tensorflow/contrib/eager/python/examples/BUILD b/tensorflow/contrib/eager/python/examples/BUILD index 
97c299a911c9180bf69faa0fa46527e80eada790..3e0881754c750f4d36e2e4dd8b80835b031c658c 100644 --- a/tensorflow/contrib/eager/python/examples/BUILD +++ b/tensorflow/contrib/eager/python/examples/BUILD @@ -6,16 +6,16 @@ package(default_visibility = ["//tensorflow:internal"]) py_library( name = "examples_pip", deps = [ - "//tensorflow/contrib/eager/python/examples/densenet", - "//tensorflow/contrib/eager/python/examples/gan:mnist", + "//tensorflow/contrib/eager/python/examples/densenet:densenet_lib", + "//tensorflow/contrib/eager/python/examples/gan:mnist_lib", "//tensorflow/contrib/eager/python/examples/l2hmc", "//tensorflow/contrib/eager/python/examples/l2hmc:neural_nets", - "//tensorflow/contrib/eager/python/examples/linear_regression", + "//tensorflow/contrib/eager/python/examples/linear_regression:linear_regression_lib", "//tensorflow/contrib/eager/python/examples/resnet50", "//tensorflow/contrib/eager/python/examples/revnet", "//tensorflow/contrib/eager/python/examples/revnet:config", - "//tensorflow/contrib/eager/python/examples/rnn_colorbot", - "//tensorflow/contrib/eager/python/examples/rnn_ptb", + "//tensorflow/contrib/eager/python/examples/rnn_colorbot:rnn_colorbot_lib", + "//tensorflow/contrib/eager/python/examples/rnn_ptb:rnn_ptb_lib", "//tensorflow/contrib/eager/python/examples/spinn:data", ], ) diff --git a/tensorflow/contrib/eager/python/examples/densenet/BUILD b/tensorflow/contrib/eager/python/examples/densenet/BUILD index 2dc196f550a10367066730f6f042c4ed69533ec3..fbb5daf230bb79f08a3d071062ddc0e8507ab324 100644 --- a/tensorflow/contrib/eager/python/examples/densenet/BUILD +++ b/tensorflow/contrib/eager/python/examples/densenet/BUILD @@ -3,11 +3,19 @@ licenses(["notice"]) # Apache 2.0 package(default_visibility = ["//tensorflow:internal"]) load("//tensorflow:tensorflow.bzl", "cuda_py_test") +load("//tensorflow:tensorflow.bzl", "py_binary") py_binary( name = "densenet", srcs = ["densenet.py"], srcs_version = "PY2AND3", + deps = [":densenet_lib"], +) + 
+py_library( + name = "densenet_lib", + srcs = ["densenet.py"], + srcs_version = "PY2AND3", deps = [ "//tensorflow:tensorflow_py", "//tensorflow/contrib/eager/python:tfe", @@ -16,33 +24,37 @@ py_binary( cuda_py_test( name = "densenet_test", - size = "large", + size = "medium", srcs = ["densenet_test.py"], additional_deps = [ ":densenet", "//tensorflow/contrib/eager/python:tfe", "//tensorflow:tensorflow_py", ], + shard_count = 4, tags = [ "no_pip", "optonly", + "oss_serial", ], ) cuda_py_test( name = "densenet_graph_test", - size = "large", + size = "medium", srcs = ["densenet_graph_test.py"], additional_deps = [ ":densenet", "//third_party/py/numpy", "//tensorflow:tensorflow_py", ], + shard_count = 4, tags = [ "no_pip", "noasan", "nomsan", "notsan", "optonly", + "oss_serial", ], ) diff --git a/tensorflow/contrib/eager/python/examples/densenet/densenet_graph_test.py b/tensorflow/contrib/eager/python/examples/densenet/densenet_graph_test.py index 4b3cb624bc947a1d1956eff6accb6d4da3bf3b87..24f6b007b526b29157011f3b1e9abdbd50bacc8e 100644 --- a/tensorflow/contrib/eager/python/examples/densenet/densenet_graph_test.py +++ b/tensorflow/contrib/eager/python/examples/densenet/densenet_graph_test.py @@ -119,7 +119,8 @@ class DensenetBenchmark(tf.test.Benchmark): with tf.Graph().as_default(): np_images, np_labels = random_batch(batch_size) dataset = tf.data.Dataset.from_tensors((np_images, np_labels)).repeat() - (images, labels) = dataset.make_one_shot_iterator().get_next() + (images, labels) = tf.compat.v1.data.make_one_shot_iterator( + dataset).get_next() model = densenet.DenseNet(self.depth, self.growth_rate, self.num_blocks, self.output_classes, diff --git a/tensorflow/contrib/eager/python/examples/gan/BUILD b/tensorflow/contrib/eager/python/examples/gan/BUILD index d64c8eb9ce122fa277567b2fbc632abfbc72df64..d99a519112787bad664232983208279cfb4d0036 100644 --- a/tensorflow/contrib/eager/python/examples/gan/BUILD +++ b/tensorflow/contrib/eager/python/examples/gan/BUILD @@ -9,6 
+9,13 @@ py_binary( name = "mnist", srcs = ["mnist.py"], srcs_version = "PY2AND3", + deps = [":mnist_lib"], +) + +py_library( + name = "mnist_lib", + srcs = ["mnist.py"], + srcs_version = "PY2AND3", deps = [ "//tensorflow:tensorflow_py", "//tensorflow/contrib/eager/python:tfe", @@ -20,7 +27,7 @@ cuda_py_test( name = "mnist_test", srcs = ["mnist_test.py"], additional_deps = [ - ":mnist", + ":mnist_lib", "//tensorflow/contrib/eager/python:tfe", "//tensorflow:tensorflow_py", ], @@ -30,7 +37,7 @@ cuda_py_test( name = "mnist_graph_test", srcs = ["mnist_graph_test.py"], additional_deps = [ - ":mnist", + ":mnist_lib", "//third_party/py/numpy", "//tensorflow:tensorflow_py", ], diff --git a/tensorflow/contrib/eager/python/examples/gan/mnist_graph_test.py b/tensorflow/contrib/eager/python/examples/gan/mnist_graph_test.py index 12b39b0cde49d4c017acfa74572c725036c54eff..e73841fbf724e05eaa3be90cc8650f795d3e1ccf 100644 --- a/tensorflow/contrib/eager/python/examples/gan/mnist_graph_test.py +++ b/tensorflow/contrib/eager/python/examples/gan/mnist_graph_test.py @@ -42,7 +42,8 @@ class MnistGraphGanBenchmark(tf.test.Benchmark): # Generate some random data. 
images_data = np.random.randn(batch_size, 784).astype(np.float32) dataset = tf.data.Dataset.from_tensors(images_data) - images = dataset.repeat().make_one_shot_iterator().get_next() + images = tf.compat.v1.data.make_one_shot_iterator( + dataset.repeat()).get_next() # Create the models and optimizers generator = mnist.Generator(data_format()) diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb index ca27a85a229d41a85fa26ecdc982da478fe9e202..e1a02db76f705414a34d232022f50124a5a6a3ed 100644 --- a/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb +++ b/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb @@ -13,11 +13,13 @@ "\n", "# Convolutional VAE: An example with tf.keras and eager\n", "\n", + "This example has moved:\n", + "\n", "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\u003ctd\u003e\n", - "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb\"\u003e\n", + "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/r2/tutorials/generative/cvae.ipynb\"\u003e\n", " \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e \n", "\u003c/td\u003e\u003ctd\u003e\n", - "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e" + "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/r2/tutorials/generative/cvae.ipynb\"\u003e\u003cimg width=32px 
src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e" ] }, { @@ -28,604 +30,14 @@ }, "source": [ "![evolution of output during training](https://tensorflow.org/images/autoencoders/cvae.gif)\n", - "\n", - "This notebook demonstrates how to generate images of handwritten digits using [tf.keras](https://www.tensorflow.org/programmers_guide/keras) and [eager execution](https://www.tensorflow.org/programmers_guide/eager) by training a Variational Autoencoder. (VAE, [[1]](https://arxiv.org/abs/1312.6114), [[2]](https://arxiv.org/abs/1401.4082)).\n", "\n" ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "P-JuIu2N_SQf" - }, - "outputs": [], - "source": [ - "# to generate gifs\n", - "!pip install imageio" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "e1_Y75QXJS6h" - }, - "source": [ - "## Import TensorFlow and enable Eager execution" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "YfIk2es3hJEd" - }, - "outputs": [], - "source": [ - "from __future__ import absolute_import, division, print_function\n", - "\n", - "# Import TensorFlow \u003e= 1.9 and enable eager execution\n", - "import tensorflow as tf\n", - "tfe = tf.contrib.eager\n", - "tf.enable_eager_execution()\n", - "\n", - "import os\n", - "import time\n", - "import numpy as np\n", - "import glob\n", - "import matplotlib.pyplot as plt\n", - "import PIL\n", - "import imageio\n", - "from IPython import display" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "iYn4MdZnKCey" - }, - "source": [ - "## Load the MNIST dataset\n", - "Each MNIST image is originally a vector of 784 integers, 
each of which is between 0-255 and represents the intensity of a pixel. We model each pixel with a Bernoulli distribution in our model, and we statically binarize the dataset." - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "a4fYMGxGhrna" - }, - "outputs": [], - "source": [ - "(train_images, _), (test_images, _) = tf.keras.datasets.mnist.load_data()" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "NFC2ghIdiZYE" - }, - "outputs": [], - "source": [ - "train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')\n", - "test_images = test_images.reshape(test_images.shape[0], 28, 28, 1).astype('float32')\n", - "\n", - "# Normalizing the images to the range of [0., 1.]\n", - "train_images /= 255.\n", - "test_images /= 255.\n", - "\n", - "# Binarization\n", - "train_images[train_images \u003e= .5] = 1.\n", - "train_images[train_images \u003c .5] = 0.\n", - "test_images[test_images \u003e= .5] = 1.\n", - "test_images[test_images \u003c .5] = 0." 
- ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "S4PIDhoDLbsZ" - }, - "outputs": [], - "source": [ - "TRAIN_BUF = 60000\n", - "BATCH_SIZE = 100\n", - "\n", - "TEST_BUF = 10000" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "PIGN6ouoQxt3" - }, - "source": [ - "## Use *tf.data* to create batches and shuffle the dataset" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "-yKCCQOoJ7cn" - }, - "outputs": [], - "source": [ - "train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(TRAIN_BUF).batch(BATCH_SIZE)\n", - "test_dataset = tf.data.Dataset.from_tensor_slices(test_images).shuffle(TEST_BUF).batch(BATCH_SIZE)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "THY-sZMiQ4UV" - }, - "source": [ - "## Wire up the generative and inference network with *tf.keras.Sequential*\n", - "\n", - "In our VAE example, we use two small ConvNets for the generative and inference network. Since these neural nets are small, we use `tf.keras.Sequential` to simplify our code. Let $x$ and $z$ denote the observation and latent variable respectively in the following descriptions. \n", - "\n", - "### Generative Network\n", - "This defines the generative model which takes a latent encoding as input, and outputs the parameters for a conditional distribution of the observation, i.e. $p(x|z)$. Additionally, we use a unit Gaussian prior $p(z)$ for the latent variable.\n", - "\n", - "### Inference Network\n", - "This defines an approximate posterior distribution $q(z|x)$, which takes as input an observation and outputs a set of parameters for the conditional distribution of the latent representation. 
In this example, we simply model this distribution as a diagonal Gaussian. In this case, the inference network outputs the mean and log-variance parameters of a factorized Gaussian (log-variance instead of the variance directly is for numerical stability).\n", - "\n", - "### Reparameterization Trick\n", - "During optimization, we can sample from $q(z|x)$ by first sampling from a unit Gaussian, and then multiplying by the standard deviation and adding the mean. This ensures the gradients could pass through the sample to the inference network parameters.\n", - "\n", - "### Network architecture\n", - "For the inference network, we use two convolutional layers followed by a fully-connected layer. In the generative network, we mirror this architecture by using a fully-connected layer followed by three convolution transpose layers (a.k.a. deconvolutional layers in some contexts). Note, it's common practice to avoid using batch normalization when training VAEs, since the additional stochasticity due to using mini-batches may aggravate instability on top of the stochasticity from sampling." 
- ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "VGLbvBEmjK0a" - }, - "outputs": [], - "source": [ - "class CVAE(tf.keras.Model):\n", - " def __init__(self, latent_dim):\n", - " super(CVAE, self).__init__()\n", - " self.latent_dim = latent_dim\n", - " self.inference_net = tf.keras.Sequential(\n", - " [\n", - " tf.keras.layers.InputLayer(input_shape=(28, 28, 1)),\n", - " tf.keras.layers.Conv2D(\n", - " filters=32, kernel_size=3, strides=(2, 2), activation=tf.nn.relu),\n", - " tf.keras.layers.Conv2D(\n", - " filters=64, kernel_size=3, strides=(2, 2), activation=tf.nn.relu),\n", - " tf.keras.layers.Flatten(),\n", - " # No activation\n", - " tf.keras.layers.Dense(latent_dim + latent_dim),\n", - " ]\n", - " )\n", - "\n", - " self.generative_net = tf.keras.Sequential(\n", - " [\n", - " tf.keras.layers.InputLayer(input_shape=(latent_dim,)),\n", - " tf.keras.layers.Dense(units=7*7*32, activation=tf.nn.relu),\n", - " tf.keras.layers.Reshape(target_shape=(7, 7, 32)),\n", - " tf.keras.layers.Conv2DTranspose(\n", - " filters=64,\n", - " kernel_size=3,\n", - " strides=(2, 2),\n", - " padding=\"SAME\",\n", - " activation=tf.nn.relu),\n", - " tf.keras.layers.Conv2DTranspose(\n", - " filters=32,\n", - " kernel_size=3,\n", - " strides=(2, 2),\n", - " padding=\"SAME\",\n", - " activation=tf.nn.relu),\n", - " # No activation\n", - " tf.keras.layers.Conv2DTranspose(\n", - " filters=1, kernel_size=3, strides=(1, 1), padding=\"SAME\"),\n", - " ]\n", - " )\n", - "\n", - " def sample(self, eps=None):\n", - " if eps is None:\n", - " eps = tf.random_normal(shape=(100, self.latent_dim))\n", - " return self.decode(eps, apply_sigmoid=True)\n", - "\n", - " def encode(self, x):\n", - " mean, logvar = tf.split(self.inference_net(x), num_or_size_splits=2, axis=1)\n", - " return mean, logvar\n", - "\n", - " def reparameterize(self, mean, logvar):\n", - " eps = 
tf.random_normal(shape=mean.shape)\n", - " return eps * tf.exp(logvar * .5) + mean\n", - "\n", - " def decode(self, z, apply_sigmoid=False):\n", - " logits = self.generative_net(z)\n", - " if apply_sigmoid:\n", - " probs = tf.sigmoid(logits)\n", - " return probs\n", - "\n", - " return logits" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "0FMYgY_mPfTi" - }, - "source": [ - "## Define the loss function and the optimizer\n", - "\n", - "VAEs train by maximizing the evidence lower bound (ELBO) on the marginal log-likelihood:\n", - "\n", - "$$\\log p(x) \\ge \\text{ELBO} = \\mathbb{E}_{q(z|x)}\\left[\\log \\frac{p(x, z)}{q(z|x)}\\right].$$\n", - "\n", - "In practice, we optimize the single sample Monte Carlo estimate of this expectation:\n", - "\n", - "$$\\log p(x| z) + \\log p(z) - \\log q(z|x),$$\n", - "where $z$ is sampled from $q(z|x)$.\n", - "\n", - "**Note**: we could also analytically compute the KL term, but here we incorporate all three terms in the Monte Carlo estimator for simplicity." - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "iWCn_PVdEJZ7" - }, - "outputs": [], - "source": [ - "def log_normal_pdf(sample, mean, logvar, raxis=1):\n", - " log2pi = tf.log(2. * np.pi)\n", - " return tf.reduce_sum(\n", - " -.5 * ((sample - mean) ** 2. 
* tf.exp(-logvar) + logvar + log2pi),\n", - " axis=raxis)\n", - "\n", - "def compute_loss(model, x):\n", - " mean, logvar = model.encode(x)\n", - " z = model.reparameterize(mean, logvar)\n", - " x_logit = model.decode(z)\n", - "\n", - " cross_ent = tf.nn.sigmoid_cross_entropy_with_logits(logits=x_logit, labels=x)\n", - " logpx_z = -tf.reduce_sum(cross_ent, axis=[1, 2, 3])\n", - " logpz = log_normal_pdf(z, 0., 0.)\n", - " logqz_x = log_normal_pdf(z, mean, logvar)\n", - " return -tf.reduce_mean(logpx_z + logpz - logqz_x)\n", - "\n", - "def compute_gradients(model, x):\n", - " with tf.GradientTape() as tape:\n", - " loss = compute_loss(model, x)\n", - " return tape.gradient(loss, model.trainable_variables), loss\n", - "\n", - "optimizer = tf.train.AdamOptimizer(1e-4)\n", - "def apply_gradients(optimizer, gradients, variables, global_step=None):\n", - " optimizer.apply_gradients(zip(gradients, variables), global_step=global_step)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "Rw1fkAczTQYh" - }, - "source": [ - "## Training\n", - "\n", - "* We start by iterating over the dataset\n", - "* During each iteration, we pass the image to the encoder to obtain a set of mean and log-variance parameters of the approximate posterior $q(z|x)$\n", - "* We then apply the *reparameterization trick* to sample from $q(z|x)$\n", - "* Finally, we pass the reparameterized samples to the decoder to obtain the logits of the generative distribution $p(x|z)$\n", - "* **Note:** Since we use the dataset loaded by keras with 60k datapoints in the training set and 10k datapoints in the test set, our resulting ELBO on the test set is slightly higher than reported results in the literature which uses dynamic binarization of Larochelle's MNIST.\n", - "\n", - "## Generate Images\n", - "\n", - "* After training, it is time to generate some images\n", - "* We start by sampling a set of latent vectors from the unit Gaussian prior distribution $p(z)$\n", - "* The 
generator will then convert the latent sample $z$ to logits of the observation, giving a distribution $p(x|z)$\n", - "* Here we plot the probabilities of Bernoulli distributions\n" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "NS2GWywBbAWo" - }, - "outputs": [], - "source": [ - "epochs = 100\n", - "latent_dim = 50\n", - "num_examples_to_generate = 16\n", - "\n", - "# keeping the random vector constant for generation (prediction) so\n", - "# it will be easier to see the improvement.\n", - "random_vector_for_generation = tf.random_normal(\n", - " shape=[num_examples_to_generate, latent_dim])\n", - "model = CVAE(latent_dim)" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "RmdVsmvhPxyy" - }, - "outputs": [], - "source": [ - "def generate_and_save_images(model, epoch, test_input):\n", - " predictions = model.sample(test_input)\n", - " fig = plt.figure(figsize=(4,4))\n", - "\n", - " for i in range(predictions.shape[0]):\n", - " plt.subplot(4, 4, i+1)\n", - " plt.imshow(predictions[i, :, :, 0], cmap='gray')\n", - " plt.axis('off')\n", - "\n", - " # tight_layout minimizes the overlap between 2 sub-plots\n", - " plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))\n", - " plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "2M7LmLtGEMQJ" - }, - "outputs": [], - "source": [ - "generate_and_save_images(model, 0, random_vector_for_generation)\n", - "\n", - "for epoch in range(1, epochs + 1):\n", - " start_time = time.time()\n", - " for train_x in train_dataset:\n", - " gradients, loss = compute_gradients(model, train_x)\n", - " apply_gradients(optimizer, 
gradients, model.trainable_variables)\n", - " end_time = time.time()\n", - "\n", - " if epoch % 1 == 0:\n", - " loss = tfe.metrics.Mean()\n", - " for test_x in test_dataset.make_one_shot_iterator():\n", - " loss(compute_loss(model, test_x))\n", - " elbo = -loss.result()\n", - " display.clear_output(wait=False)\n", - " print('Epoch: {}, Test set ELBO: {}, '\n", - " 'time elapse for current epoch {}'.format(epoch,\n", - " elbo,\n", - " end_time - start_time))\n", - " generate_and_save_images(\n", - " model, epoch, random_vector_for_generation)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "P4M_vIbUi7c0" - }, - "source": [ - "### Display an image using the epoch number" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "WfO5wCdclHGL" - }, - "outputs": [], - "source": [ - "def display_image(epoch_no):\n", - " return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "5x3q9_Oe5q0A" - }, - "outputs": [], - "source": [ - "display_image(epochs) # Display images" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "NywiH3nL8guF" - }, - "source": [ - "### Generate a GIF of all the saved images." 
- ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "IGKQgENQ8lEI" - }, - "outputs": [], - "source": [ - "with imageio.get_writer('cvae.gif', mode='I') as writer:\n", - " filenames = glob.glob('image*.png')\n", - " filenames = sorted(filenames)\n", - " last = -1\n", - " for i,filename in enumerate(filenames):\n", - " frame = 2*(i**0.5)\n", - " if round(frame) \u003e round(last):\n", - " last = frame\n", - " else:\n", - " continue\n", - " image = imageio.imread(filename)\n", - " writer.append_data(image)\n", - " image = imageio.imread(filename)\n", - " writer.append_data(image)\n", - " \n", - "# this is a hack to display the gif inside the notebook\n", - "os.system('cp cvae.gif cvae.gif.png')" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "uV0yiKpzNP1b" - }, - "outputs": [], - "source": [ - "display.Image(filename=\"cvae.gif.png\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "yQXO_dlXkKsT" - }, - "source": [ - "To downlod the animation from Colab uncomment the code below:" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "4fSJS3m5HLFM" - }, - "outputs": [], - "source": [ - "#from google.colab import files\n", - "#files.download('cvae.gif')" - ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], - "default_view": {}, "name": "cvae.ipynb", "private_outputs": true, "provenance": [ @@ -635,8 +47,7 @@ } ], "toc_visible": true, - "version": "0.3.2", - "views": {} + "version": "0.3.2" }, "kernelspec": { "display_name": "Python 3", diff --git 
a/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb index 78fcd397087fd1fd64aebed7ac3b5c6b2f45c450..53767058838459e56215d286e9f8f8eb66287147 100644 --- a/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb +++ b/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb @@ -1,26 +1,11 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "name": "dcgan.ipynb", - "version": "0.3.2", - "provenance": [], - "collapsed_sections": [] - }, - "kernelspec": { - "name": "python2", - "display_name": "Python 2" - }, - "accelerator": "GPU" - }, "cells": [ { + "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "0TD5ZrvEMbhZ" }, - "cell_type": "markdown", "source": [ "**Copyright 2018 The TensorFlow Authors**.\n", "\n", @@ -28,851 +13,39 @@ "\n", "# Generating Handwritten Digits with DCGAN\n", "\n", - "
\n", - "\n", - " Run in Google Colab \n", - "\n", - "View source on GitHub
" - ] - }, - { - "metadata": { - "colab_type": "text", - "id": "ITZuApL56Mny" - }, - "cell_type": "markdown", - "source": [ - "This tutorial demonstrates how to generate images of handwritten digits using a Deep Convolutional Generative Adversarial Network ([DCGAN](https://arxiv.org/pdf/1511.06434.pdf)). The code is written in [tf.keras](https://www.tensorflow.org/programmers_guide/keras) with [eager execution](https://www.tensorflow.org/programmers_guide/eager) enabled. " - ] - }, - { - "metadata": { - "colab_type": "toc", - "id": "x2McrO9bMyLN" - }, - "cell_type": "markdown", - "source": [ - ">[Generating Handwritten Digits with DCGAN](#scrollTo=0TD5ZrvEMbhZ)\n", - "\n", - ">>[What are GANs?](#scrollTo=2MbKJY38Puy9)\n", - "\n", - ">>>[Import TensorFlow and enable eager execution](#scrollTo=e1_Y75QXJS6h)\n", - "\n", - ">>>[Load the dataset](#scrollTo=iYn4MdZnKCey)\n", - "\n", - ">>>[Use tf.data to create batches and shuffle the dataset](#scrollTo=PIGN6ouoQxt3)\n", - "\n", - ">>[Create the models](#scrollTo=THY-sZMiQ4UV)\n", - "\n", - ">>>[The Generator Model](#scrollTo=-tEyxE-GMC48)\n", - "\n", - ">>>[The Discriminator model](#scrollTo=D0IKnaCtg6WE)\n", - "\n", - ">>[Define the loss functions and the optimizer](#scrollTo=0FMYgY_mPfTi)\n", - "\n", - ">>>[Generator loss](#scrollTo=Jd-3GCUEiKtv)\n", - "\n", - ">>>[Discriminator loss](#scrollTo=PKY_iPSPNWoj)\n", - "\n", - ">>[Set up GANs for Training](#scrollTo=Rw1fkAczTQYh)\n", - "\n", - ">>[Train the GANs](#scrollTo=dZrd4CdjR-Fp)\n", - "\n", - ">>[Generated images](#scrollTo=P4M_vIbUi7c0)\n", + "This example has moved.\n", "\n", - ">>[Learn more about GANs](#scrollTo=k6qC-SbjK0yW)\n", - "\n" + "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\u003ctd\u003e\n", + "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/r2/tutorials/generative/dcgan.ipynb\"\u003e\n", + " \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" 
/\u003eRun in Google Colab\u003c/a\u003e \n", + "\u003c/td\u003e\u003ctd\u003e\n", + "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/r2/tutorials/generative/dcgan.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e" ] }, { + "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "2MbKJY38Puy9" }, - "cell_type": "markdown", "source": [ - "## What are GANs?\n", - "GANs, or [Generative Adversarial Networks](https://arxiv.org/abs/1406.2661), are a framework for estimating generative models. Two models are trained simultaneously by an adversarial process: a Generator, which is responsible for generating data (say, images), and a Discriminator, which is responsible for estimating the probability that an image was drawn from the training data (the image is real), or was produced by the Generator (the image is fake). During training, the Generator becomes progressively better at generating images, until the Discriminator is no longer able to distinguish real images from fake. \n", - "\n", - "![alt text](https://github.com/margaretmz/tensorflow/blob/margaret-dcgan/tensorflow/contrib/eager/python/examples/generative_examples/gans_diagram.png?raw=1)\n", - "\n", - "We will demonstrate this process end-to-end on MNIST. Below is an animation that shows a series of images produced by the Generator as it was trained for 50 epochs. Overtime, the generated images become increasingly difficult to distinguish from the training set.\n", - "\n", - "To learn more about GANs, we recommend MIT's [Intro to Deep Learning](http://introtodeeplearning.com/) course, which includes a lecture on Deep Generative Models ([video](https://youtu.be/JVb54xhEw6Y) | [slides](http://introtodeeplearning.com/materials/2018_6S191_Lecture4.pdf)). 
Now, let's head to the code!\n", - "\n", "![sample output](https://tensorflow.org/images/gan/dcgan.gif)" ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "name": "dcgan.ipynb", + "provenance": [], + "version": "0.3.2" }, - { - "metadata": { - "colab_type": "code", - "id": "u_2z-B3piVsw", - "colab": {} - }, - "cell_type": "code", - "source": [ - "# Install imgeio in order to generate an animated gif showing the image generating process\n", - "!pip install imageio" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "e1_Y75QXJS6h" - }, - "cell_type": "markdown", - "source": [ - "### Import TensorFlow and enable eager execution" - ] - }, - { - "metadata": { - "colab_type": "code", - "id": "YfIk2es3hJEd", - "colab": {} - }, - "cell_type": "code", - "source": [ - "import tensorflow as tf\n", - "tf.enable_eager_execution()\n", - "\n", - "import glob\n", - "import imageio\n", - "import matplotlib.pyplot as plt\n", - "import numpy as np\n", - "import os\n", - "import PIL\n", - "import time\n", - "\n", - "from IPython import display" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "iYn4MdZnKCey" - }, - "cell_type": "markdown", - "source": [ - "### Load the dataset\n", - "\n", - "We are going to use the MNIST dataset to train the generator and the discriminator. The generator will generate handwritten digits resembling the MNIST data." 
- ] - }, - { - "metadata": { - "colab_type": "code", - "id": "a4fYMGxGhrna", - "colab": {} - }, - "cell_type": "code", - "source": [ - "(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "code", - "id": "NFC2ghIdiZYE", - "colab": {} - }, - "cell_type": "code", - "source": [ - "train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')\n", - "train_images = (train_images - 127.5) / 127.5 # Normalize the images to [-1, 1]" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "code", - "id": "S4PIDhoDLbsZ", - "colab": {} - }, - "cell_type": "code", - "source": [ - "BUFFER_SIZE = 60000\n", - "BATCH_SIZE = 256" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "PIGN6ouoQxt3" - }, - "cell_type": "markdown", - "source": [ - "### Use tf.data to create batches and shuffle the dataset" - ] - }, - { - "metadata": { - "colab_type": "code", - "id": "-yKCCQOoJ7cn", - "colab": {} - }, - "cell_type": "code", - "source": [ - "train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "THY-sZMiQ4UV" - }, - "cell_type": "markdown", - "source": [ - "## Create the models\n", - "\n", - "We will use tf.keras [Sequential API](https://www.tensorflow.org/guide/keras#sequential_model) to define the generator and discriminator models." - ] - }, - { - "metadata": { - "colab_type": "text", - "id": "-tEyxE-GMC48" - }, - "cell_type": "markdown", - "source": [ - "### The Generator Model\n", - "\n", - "The generator is responsible for creating convincing images that are good enough to fool the discriminator. 
The network architecture for the generator consists of [Conv2DTranspose](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Conv2DTranspose) (Upsampling) layers. We start with a fully connected layer and upsample the image two times in order to reach the desired image size of 28x28x1. We increase the width and height, and reduce the depth as we move through the layers in the network. We use [Leaky ReLU](https://www.tensorflow.org/api_docs/python/tf/keras/layers/LeakyReLU) activation for each layer except for the last one where we use a tanh activation." - ] - }, - { - "metadata": { - "id": "6bpTcDqoLWjY", - "colab_type": "code", - "colab": {} - }, - "cell_type": "code", - "source": [ - "def make_generator_model():\n", - " model = tf.keras.Sequential()\n", - " model.add(tf.keras.layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))\n", - " model.add(tf.keras.layers.BatchNormalization())\n", - " model.add(tf.keras.layers.LeakyReLU())\n", - " \n", - " model.add(tf.keras.layers.Reshape((7, 7, 256)))\n", - " assert model.output_shape == (None, 7, 7, 256) # Note: None is the batch size\n", - " \n", - " model.add(tf.keras.layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))\n", - " assert model.output_shape == (None, 7, 7, 128) \n", - " model.add(tf.keras.layers.BatchNormalization())\n", - " model.add(tf.keras.layers.LeakyReLU())\n", - "\n", - " model.add(tf.keras.layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))\n", - " assert model.output_shape == (None, 14, 14, 64) \n", - " model.add(tf.keras.layers.BatchNormalization())\n", - " model.add(tf.keras.layers.LeakyReLU())\n", - "\n", - " model.add(tf.keras.layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))\n", - " assert model.output_shape == (None, 28, 28, 1)\n", - " \n", - " return model" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": 
"D0IKnaCtg6WE" - }, - "cell_type": "markdown", - "source": [ - "### The Discriminator model\n", - "\n", - "The discriminator is responsible for distinguishing fake images from real images. It's similar to a regular CNN-based image classifier." - ] - }, - { - "metadata": { - "id": "dw2tPLmk2pEP", - "colab_type": "code", - "colab": {} - }, - "cell_type": "code", - "source": [ - "def make_discriminator_model():\n", - " model = tf.keras.Sequential()\n", - " model.add(tf.keras.layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same'))\n", - " model.add(tf.keras.layers.LeakyReLU())\n", - " model.add(tf.keras.layers.Dropout(0.3))\n", - " \n", - " model.add(tf.keras.layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))\n", - " model.add(tf.keras.layers.LeakyReLU())\n", - " model.add(tf.keras.layers.Dropout(0.3))\n", - " \n", - " model.add(tf.keras.layers.Flatten())\n", - " model.add(tf.keras.layers.Dense(1))\n", - " \n", - " return model" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "code", - "id": "gDkA05NE6QMs", - "colab": {} - }, - "cell_type": "code", - "source": [ - "generator = make_generator_model()\n", - "discriminator = make_discriminator_model()" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "0FMYgY_mPfTi" - }, - "cell_type": "markdown", - "source": [ - "## Define the loss functions and the optimizer\n", - "\n", - "Let's define the loss functions and the optimizers for the generator and the discriminator.\n" - ] - }, - { - "metadata": { - "colab_type": "text", - "id": "Jd-3GCUEiKtv" - }, - "cell_type": "markdown", - "source": [ - "### Generator loss\n", - "The generator loss is a sigmoid cross entropy loss of the generated images and an array of ones, since the generator is trying to generate fake images that resemble the real images." 
- ] - }, - { - "metadata": { - "colab_type": "code", - "id": "90BIcCKcDMxz", - "colab": {} - }, - "cell_type": "code", - "source": [ - "def generator_loss(generated_output):\n", - " return tf.losses.sigmoid_cross_entropy(tf.ones_like(generated_output), generated_output)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "PKY_iPSPNWoj" - }, - "cell_type": "markdown", - "source": [ - "### Discriminator loss\n", - "\n", - "The discriminator loss function takes two inputs: real images, and generated images. Here is how to calculate the discriminator loss:\n", - "1. Calculate real_loss which is a sigmoid cross entropy loss of the real images and an array of ones (since these are the real images).\n", - "2. Calculate generated_loss which is a sigmoid cross entropy loss of the generated images and an array of zeros (since these are the fake images).\n", - "3. Calculate the total_loss as the sum of real_loss and generated_loss." - ] - }, - { - "metadata": { - "colab_type": "code", - "id": "wkMNfBWlT-PV", - "colab": {} - }, - "cell_type": "code", - "source": [ - "def discriminator_loss(real_output, generated_output):\n", - " # [1,1,...,1] with real output since it is true and we want our generated examples to look like it\n", - " real_loss = tf.losses.sigmoid_cross_entropy(multi_class_labels=tf.ones_like(real_output), logits=real_output)\n", - "\n", - " # [0,0,...,0] with generated images since they are fake\n", - " generated_loss = tf.losses.sigmoid_cross_entropy(multi_class_labels=tf.zeros_like(generated_output), logits=generated_output)\n", - "\n", - " total_loss = real_loss + generated_loss\n", - "\n", - " return total_loss" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "MgIc7i0th_Iu" - }, - "cell_type": "markdown", - "source": [ - "The discriminator and the generator optimizers are different since we will train two networks separately." 
- ] - }, - { - "metadata": { - "colab_type": "code", - "id": "iWCn_PVdEJZ7", - "colab": {} - }, - "cell_type": "code", - "source": [ - "generator_optimizer = tf.train.AdamOptimizer(1e-4)\n", - "discriminator_optimizer = tf.train.AdamOptimizer(1e-4)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "mWtinsGDPJlV" - }, - "cell_type": "markdown", - "source": [ - "**Checkpoints (Object-based saving)**" - ] - }, - { - "metadata": { - "colab_type": "code", - "id": "CA1w-7s2POEy", - "colab": {} - }, - "cell_type": "code", - "source": [ - "checkpoint_dir = './training_checkpoints'\n", - "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt\")\n", - "checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,\n", - " discriminator_optimizer=discriminator_optimizer,\n", - " generator=generator,\n", - " discriminator=discriminator)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "Rw1fkAczTQYh" - }, - "cell_type": "markdown", - "source": [ - "## Set up GANs for Training\n", - "\n" - ] - }, - { - "metadata": { - "colab_type": "text", - "id": "5QC5BABamh_c" - }, - "cell_type": "markdown", - "source": [ - "Now it's time to put together the generator and discriminator to set up the Generative Adversarial Networks, as you see in the diagam at the beginning of the tutorial." 
- ] - }, - { - "metadata": { - "colab_type": "text", - "id": "Ff6oN6PZX27n" - }, - "cell_type": "markdown", - "source": [ - "**Define training parameters**" - ] - }, - { - "metadata": { - "colab_type": "code", - "id": "NS2GWywBbAWo", - "colab": {} - }, - "cell_type": "code", - "source": [ - "EPOCHS = 50\n", - "noise_dim = 100\n", - "num_examples_to_generate = 16\n", - "\n", - "# We'll re-use this random vector used to seed the generator so\n", - "# it will be easier to see the improvement over time.\n", - "random_vector_for_generation = tf.random_normal([num_examples_to_generate,\n", - " noise_dim])" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "jylSonrqSWfi" - }, - "cell_type": "markdown", - "source": [ - "**Define training method**\n", - "\n", - "We start by iterating over the dataset. The generator is given a random vector as an input which is processed to output an image looking like a handwritten digit. The discriminator is then shown the real MNIST images as well as the generated images.\n", - "\n", - "Next, we calculate the generator and the discriminator loss. Then, we calculate the gradients of loss with respect to both the generator and the discriminator variables." 
- ] - }, - { - "metadata": { - "id": "3t5ibNo05jCB", - "colab_type": "code", - "colab": {} - }, - "cell_type": "code", - "source": [ - "def train_step(images):\n", - " # generating noise from a normal distribution\n", - " noise = tf.random_normal([BATCH_SIZE, noise_dim])\n", - " \n", - " with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:\n", - " generated_images = generator(noise, training=True)\n", - " \n", - " real_output = discriminator(images, training=True)\n", - " generated_output = discriminator(generated_images, training=True)\n", - " \n", - " gen_loss = generator_loss(generated_output)\n", - " disc_loss = discriminator_loss(real_output, generated_output)\n", - " \n", - " gradients_of_generator = gen_tape.gradient(gen_loss, generator.variables)\n", - " gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.variables)\n", - " \n", - " generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.variables))\n", - " discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.variables))" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "6TSZgwc2BUQ-" - }, - "cell_type": "markdown", - "source": [ - "\n", - "This model takes about ~30 seconds per epoch to train on a single Tesla K80 on Colab, as of October 2018. \n", - "\n", - "Eager execution can be slower than executing the equivalent graph as it can't benefit from whole-program optimizations on the graph, and also incurs overheads of interpreting Python code. By using [tf.contrib.eager.defun](https://www.tensorflow.org/api_docs/python/tf/contrib/eager/defun) to create graph functions, we get a ~20 secs/epoch performance boost (from ~50 secs/epoch down to ~30 secs/epoch). This way we get the best of both eager execution (easier for debugging) and graph mode (better performance)." 
- ] - }, - { - "metadata": { - "id": "Iwya07_j5p2A", - "colab_type": "code", - "colab": {} - }, - "cell_type": "code", - "source": [ - "train_step = tf.contrib.eager.defun(train_step)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "code", - "id": "2M7LmLtGEMQJ", - "colab": {} - }, - "cell_type": "code", - "source": [ - "def train(dataset, epochs): \n", - " for epoch in range(epochs):\n", - " start = time.time()\n", - " \n", - " for images in dataset:\n", - " train_step(images)\n", - "\n", - " display.clear_output(wait=True)\n", - " generate_and_save_images(generator,\n", - " epoch + 1,\n", - " random_vector_for_generation)\n", - " \n", - " # saving (checkpoint) the model every 15 epochs\n", - " if (epoch + 1) % 15 == 0:\n", - " checkpoint.save(file_prefix = checkpoint_prefix)\n", - " \n", - " print ('Time taken for epoch {} is {} sec'.format(epoch + 1,\n", - " time.time()-start))\n", - " # generating after the final epoch\n", - " display.clear_output(wait=True)\n", - " generate_and_save_images(generator,\n", - " epochs,\n", - " random_vector_for_generation)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "2aFF7Hk3XdeW" - }, - "cell_type": "markdown", - "source": [ - "**Generate and save images**\n", - "\n" - ] - }, - { - "metadata": { - "colab_type": "code", - "id": "RmdVsmvhPxyy", - "colab": {} - }, - "cell_type": "code", - "source": [ - "def generate_and_save_images(model, epoch, test_input):\n", - " # make sure the training parameter is set to False because we\n", - " # don't want to train the batchnorm layer when doing inference.\n", - " predictions = model(test_input, training=False)\n", - "\n", - " fig = plt.figure(figsize=(4,4))\n", - " \n", - " for i in range(predictions.shape[0]):\n", - " plt.subplot(4, 4, i+1)\n", - " plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')\n", - " plt.axis('off')\n", - " \n", - " 
plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))\n", - " plt.show()" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "dZrd4CdjR-Fp" - }, - "cell_type": "markdown", - "source": [ - "## Train the GANs\n", - "We will call the train() method defined above to train the generator and discriminator simultaneously. Note, training GANs can be tricky. It's important that the generator and discriminator do not overpower each other (e.g., that they train at a similar rate).\n", - "\n", - "At the beginning of the training, the generated images look like random noise. As training progresses, you can see the generated digits look increasingly real. After 50 epochs, they look very much like the MNIST digits." - ] - }, - { - "metadata": { - "colab_type": "code", - "id": "Ly3UN0SLLY2l", - "colab": {} - }, - "cell_type": "code", - "source": [ - "%%time\n", - "train(train_dataset, EPOCHS)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "rfM4YcPVPkNO" - }, - "cell_type": "markdown", - "source": [ - "**Restore the latest checkpoint**" - ] - }, - { - "metadata": { - "colab_type": "code", - "id": "XhXsd0srPo8c", - "colab": {} - }, - "cell_type": "code", - "source": [ - "# restoring the latest checkpoint in checkpoint_dir\n", - "checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "P4M_vIbUi7c0" - }, - "cell_type": "markdown", - "source": [ - "## Generated images \n" - ] - }, - { - "metadata": { - "colab_type": "text", - "id": "mLskt7EfXAjr" - }, - "cell_type": "markdown", - "source": [ - "\n", - "After training, its time to generate some images! 
\n", - "The last step is to plot the generated images and voila!\n" - ] - }, - { - "metadata": { - "colab_type": "code", - "id": "WfO5wCdclHGL", - "colab": {} - }, - "cell_type": "code", - "source": [ - "# Display a single image using the epoch number\n", - "def display_image(epoch_no):\n", - " return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "code", - "id": "5x3q9_Oe5q0A", - "colab": {} - }, - "cell_type": "code", - "source": [ - "display_image(EPOCHS)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "NywiH3nL8guF" - }, - "cell_type": "markdown", - "source": [ - "**Generate a GIF of all the saved images**\n", - "\n", - "We will use imageio to create an animated gif using all the images saved during training." - ] - }, - { - "metadata": { - "colab_type": "code", - "id": "IGKQgENQ8lEI", - "colab": {} - }, - "cell_type": "code", - "source": [ - "with imageio.get_writer('dcgan.gif', mode='I') as writer:\n", - " filenames = glob.glob('image*.png')\n", - " filenames = sorted(filenames)\n", - " last = -1\n", - " for i,filename in enumerate(filenames):\n", - " frame = 2*(i**0.5)\n", - " if round(frame) > round(last):\n", - " last = frame\n", - " else:\n", - " continue\n", - " image = imageio.imread(filename)\n", - " writer.append_data(image)\n", - " image = imageio.imread(filename)\n", - " writer.append_data(image)\n", - " \n", - "# this is a hack to display the gif inside the notebook\n", - "os.system('cp dcgan.gif dcgan.gif.png')" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "cGhC3-fMWSwl" - }, - "cell_type": "markdown", - "source": [ - "Display the animated gif with all the mages generated during the training of GANs." 
- ] - }, - { - "metadata": { - "colab_type": "code", - "id": "uV0yiKpzNP1b", - "colab": {} - }, - "cell_type": "code", - "source": [ - "display.Image(filename=\"dcgan.gif.png\")" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "6EEG-wePkmJQ" - }, - "cell_type": "markdown", - "source": [ - "**Download the animated gif**\n", - "\n", - "Uncomment the code below to download an animated gif from Colab." - ] - }, - { - "metadata": { - "colab_type": "code", - "id": "4UJjSnIMOzOJ", - "colab": {} - }, - "cell_type": "code", - "source": [ - "#from google.colab import files\n", - "#files.download('dcgan.gif')" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "colab_type": "text", - "id": "k6qC-SbjK0yW" - }, - "cell_type": "markdown", - "source": [ - "## Learn more about GANs\n" - ] - }, - { - "metadata": { - "colab_type": "text", - "id": "xjjkT9KAK6H7" - }, - "cell_type": "markdown", - "source": [ - "We hope this tutorial was helpful! As a next step, you might like to experiment with a different dataset, for example the Large-scale Celeb Faces Attributes (CelebA) dataset [available on Kaggle](https://www.kaggle.com/jessicali9530/celeba-dataset/home).\n", - "\n", - "To learn more about GANs:\n", - "\n", - "* Check out MIT's lecture (linked above), or [this](http://cs231n.stanford.edu/slides/2018/cs231n_2018_lecture12.pdf) lecture form Stanford's CS231n. 
\n", - "\n", - "* We also recommend the [CVPR 2018 Tutorial on GANs](https://sites.google.com/view/cvpr2018tutorialongans/), and the [NIPS 2016 Tutorial: Generative Adversarial Networks](https://arxiv.org/abs/1701.00160).\n" - ] + "kernelspec": { + "display_name": "Python 2", + "name": "python2" } - ] -} \ No newline at end of file + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/gans_diagram.png b/tensorflow/contrib/eager/python/examples/generative_examples/gans_diagram.png deleted file mode 100644 index b715bd83ef117641c6429e0ac173dbe9b8d5fd88..0000000000000000000000000000000000000000 Binary files a/tensorflow/contrib/eager/python/examples/generative_examples/gans_diagram.png and /dev/null differ diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb index 3acecd283cda83992bab0c37cf0b8037ed2cf27a..979772acd3f823a8cc53ab5e026946ad3bb19353 100644 --- a/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb +++ b/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb @@ -1,36 +1,11 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "name": "image_captioning_with_attention.ipynb", - "version": "0.3.2", - "views": {}, - "default_view": {}, - "provenance": [ - { - "file_id": "1HI8OK2sMjcx9CTWVn0122QAHOuXaOaMg", - "timestamp": 1530222436922 - } - ], - "private_outputs": true, - "collapsed_sections": [], - "toc_visible": true - }, - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "accelerator": "GPU" - }, "cells": [ { + "cell_type": "markdown", "metadata": { - "id": "K2s1A9eLRPEj", - "colab_type": "text" + "colab_type": "text", + "id": "K2s1A9eLRPEj" }, - "cell_type": "markdown", 
"source": [ "##### Copyright 2018 The TensorFlow Authors.\n", "\n", @@ -38,1147 +13,59 @@ ] }, { - "metadata": { - "id": "Cffg2i257iMS", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "# Image Captioning with Attention\n", - "\n", - "
\n", - "\n", - " Run in Google Colab \n", - "\n", - "View source on GitHub
" - ] - }, - { - "metadata": { - "id": "QASbY_HGo4Lq", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "Image captioning is the task of generating a caption for an image. Given an image like this:\n", - "\n", - "![Man Surfing](https://tensorflow.org/images/surf.jpg) \n", - "\n", - "[Image Source](https://commons.wikimedia.org/wiki/Surfing#/media/File:Surfing_in_Hawaii.jpg), License: Public Domain\n", - "\n", - "Our goal is to generate a caption, such as \"a surfer riding on a wave\". Here, we'll use an attention-based model. This enables us to see which parts of the image the model focuses on as it generates a caption.\n", - "\n", - "![Prediction](https://tensorflow.org/images/imcap_prediction.png)\n", - "\n", - "This model architecture below is similar to [Show, Attend and Tell: Neural Image Caption Generation with Visual Attention](https://arxiv.org/abs/1502.03044). \n", - "\n", - "The code uses [tf.keras](https://www.tensorflow.org/programmers_guide/keras) and [eager execution](https://www.tensorflow.org/programmers_guide/eager), which you can learn more about in the linked guides.\n", - "\n", - "This notebook is an end-to-end example. If you run it, it will download the [MS-COCO](http://cocodataset.org/#home) dataset, preprocess and cache a subset of the images using Inception V3, train an encoder-decoder model, and use it to generate captions on new images.\n", - "\n", - "The code requires TensorFlow version >=1.9. If you're running this in [Colab]()\n", - "\n", - "In this example, we're training on a relatively small amount of data as an example. On a single P100 GPU, this example will take about ~2 hours to train. 
We train on the first 30,000 captions (corresponding to about ~20,000 images depending on shuffling, as there are multiple captions per image in the dataset)\n" - ] - }, - { - "metadata": { - "id": "U8l4RJ0XRPEm", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# Import TensorFlow and enable eager execution\n", - "# This code requires TensorFlow version >=1.9\n", - "import tensorflow as tf\n", - "tf.enable_eager_execution()\n", - "\n", - "# We'll generate plots of attention in order to see which parts of an image\n", - "# our model focuses on during captioning\n", - "import matplotlib.pyplot as plt\n", - "\n", - "# Scikit-learn includes many helpful utilities\n", - "from sklearn.model_selection import train_test_split\n", - "from sklearn.utils import shuffle\n", - "\n", - "import re\n", - "import numpy as np\n", - "import os\n", - "import time\n", - "import json\n", - "from glob import glob\n", - "from PIL import Image\n", - "import pickle" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "b6qbGw8MRPE5", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "## Download and prepare the MS-COCO dataset\n", - "\n", - "We will use the [MS-COCO dataset](http://cocodataset.org/#home) to train our model. This dataset contains >82,000 images, each of which has been annotated with at least 5 different captions. The code below will download and extract the dataset automatically. \n", - "\n", - "**Caution: large download ahead**. We'll use the training set, it's a 13GB file." 
- ] - }, - { - "metadata": { - "id": "krQuPYTtRPE7", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "annotation_zip = tf.keras.utils.get_file('captions.zip', \n", - " cache_subdir=os.path.abspath('.'),\n", - " origin = 'http://images.cocodataset.org/annotations/annotations_trainval2014.zip',\n", - " extract = True)\n", - "annotation_file = os.path.dirname(annotation_zip)+'/annotations/captions_train2014.json'\n", - "\n", - "name_of_zip = 'train2014.zip'\n", - "if not os.path.exists(os.path.abspath('.') + '/' + name_of_zip):\n", - " image_zip = tf.keras.utils.get_file(name_of_zip, \n", - " cache_subdir=os.path.abspath('.'),\n", - " origin = 'http://images.cocodataset.org/zips/train2014.zip',\n", - " extract = True)\n", - " PATH = os.path.dirname(image_zip)+'/train2014/'\n", - "else:\n", - " PATH = os.path.abspath('.')+'/train2014/'" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "aANEzb5WwSzg", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "## Optionally, limit the size of the training set for faster training\n", - "For this example, we'll select a subset of 30,000 captions and use these and the corresponding images to train our model. As always, captioning quality will improve if you choose to use more data." 
- ] - }, - { - "metadata": { - "id": "4G3b8x8_RPFD", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# read the json file\n", - "with open(annotation_file, 'r') as f:\n", - " annotations = json.load(f)\n", - "\n", - "# storing the captions and the image name in vectors\n", - "all_captions = []\n", - "all_img_name_vector = []\n", - "\n", - "for annot in annotations['annotations']:\n", - " caption = ' ' + annot['caption'] + ' '\n", - " image_id = annot['image_id']\n", - " full_coco_image_path = PATH + 'COCO_train2014_' + '%012d.jpg' % (image_id)\n", - " \n", - " all_img_name_vector.append(full_coco_image_path)\n", - " all_captions.append(caption)\n", - "\n", - "# shuffling the captions and image_names together\n", - "# setting a random state\n", - "train_captions, img_name_vector = shuffle(all_captions,\n", - " all_img_name_vector,\n", - " random_state=1)\n", - "\n", - "# selecting the first 30000 captions from the shuffled set\n", - "num_examples = 30000\n", - "train_captions = train_captions[:num_examples]\n", - "img_name_vector = img_name_vector[:num_examples]" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "mPBMgK34RPFL", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "len(train_captions), len(all_captions)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "8cSW4u-ORPFQ", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "## Preprocess the images using InceptionV3\n", - "Next, we will use InceptionV3 (pretrained on Imagenet) to classify each image. We will extract features from the last convolutional layer. 
\n", - "\n", - "First, we will need to convert the images into the format inceptionV3 expects by:\n", - "* Resizing the image to (299, 299)\n", - "* Using the [preprocess_input](https://www.tensorflow.org/api_docs/python/tf/keras/applications/inception_v3/preprocess_input) method to place the pixels in the range of -1 to 1 (to match the format of the images used to train InceptionV3)." - ] - }, - { - "metadata": { - "id": "zXR0217aRPFR", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "def load_image(image_path):\n", - " img = tf.read_file(image_path)\n", - " img = tf.image.decode_jpeg(img, channels=3)\n", - " img = tf.image.resize_images(img, (299, 299))\n", - " img = tf.keras.applications.inception_v3.preprocess_input(img)\n", - " return img, image_path" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "MDvIu4sXRPFV", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "## Initialize InceptionV3 and load the pretrained Imagenet weights\n", - "\n", - "To do so, we'll create a tf.keras model where the output layer is the last convolutional layer in the InceptionV3 architecture. \n", - "* Each image is forwarded through the network and the vector that we get at the end is stored in a dictionary (image_name --> feature_vector). \n", - "* We use the last convolutional layer because we are using attention in this example. The shape of the output of this layer is ```8x8x2048```. \n", - "* We avoid doing this during training so it does not become a bottleneck. \n", - "* After all the images are passed through the network, we pickle the dictionary and save it to disk." 
- ] - }, - { - "metadata": { - "id": "RD3vW4SsRPFW", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "image_model = tf.keras.applications.InceptionV3(include_top=False, \n", - " weights='imagenet')\n", - "new_input = image_model.input\n", - "hidden_layer = image_model.layers[-1].output\n", - "\n", - "image_features_extract_model = tf.keras.Model(new_input, hidden_layer)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "rERqlR3WRPGO", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "## Caching the features extracted from InceptionV3\n", - "\n", - "We will pre-process each image with InceptionV3 and cache the output to disk. Caching the output in RAM would be faster but memory intensive, requiring 8 \\* 8 \\* 2048 floats per image. At the time of writing, this would exceed the memory limitations of Colab (although these may change, an instance appears to have about 12GB of memory currently). \n", - "\n", - "Performance could be improved with a more sophisticated caching strategy (e.g., by sharding the images to reduce random access disk I/O) at the cost of more code.\n", - "\n", - "This will take about 10 minutes to run in Colab with a GPU. If you'd like to see a progress bar, you could: install [tqdm](https://github.com/tqdm/tqdm) (```!pip install tqdm```), then change this line: \n", - "\n", - "```for img, path in image_dataset:``` \n", - "\n", - "to:\n", - "\n", - "```for img, path in tqdm(image_dataset):```." 
- ] - }, - { - "metadata": { - "id": "Dx_fvbVgRPGQ", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# getting the unique images\n", - "encode_train = sorted(set(img_name_vector))\n", - "\n", - "# feel free to change the batch_size according to your system configuration\n", - "image_dataset = tf.data.Dataset.from_tensor_slices(\n", - " encode_train).map(load_image).batch(16)\n", - "\n", - "for img, path in image_dataset:\n", - " batch_features = image_features_extract_model(img)\n", - " batch_features = tf.reshape(batch_features, \n", - " (batch_features.shape[0], -1, batch_features.shape[3]))\n", - "\n", - " for bf, p in zip(batch_features, path):\n", - " path_of_feature = p.numpy().decode(\"utf-8\")\n", - " np.save(path_of_feature, bf.numpy())" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "nyqH3zFwRPFi", - "colab_type": "text" - }, "cell_type": "markdown", - "source": [ - "## Preprocess and tokenize the captions\n", - "\n", - "* First, we'll tokenize the captions (e.g., by splitting on spaces). This will give us a vocabulary of all the unique words in the data (e.g., \"surfing\", \"football\", etc).\n", - "* Next, we'll limit the vocabulary size to the top 5,000 words to save memory. We'll replace all other words with the token \"UNK\" (for unknown).\n", - "* Finally, we create a word --> index mapping and vice-versa.\n", - "* We will then pad all sequences to the be same length as the longest one. 
" - ] - }, - { - "metadata": { - "id": "HZfK8RhQRPFj", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# This will find the maximum length of any caption in our dataset\n", - "def calc_max_length(tensor):\n", - " return max(len(t) for t in tensor)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "oJGE34aiRPFo", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# The steps above is a general process of dealing with text processing\n", - "\n", - "# choosing the top 5000 words from the vocabulary\n", - "top_k = 5000\n", - "tokenizer = tf.keras.preprocessing.text.Tokenizer(num_words=top_k, \n", - " oov_token=\"\", \n", - " filters='!\"#$%&()*+.,-/:;=?@[\\]^_`{|}~ ')\n", - "tokenizer.fit_on_texts(train_captions)\n", - "train_seqs = tokenizer.texts_to_sequences(train_captions)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "8Q44tNQVRPFt", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "tokenizer.word_index = {key:value for key, value in tokenizer.word_index.items() if value <= top_k}\n", - "# putting token in the word2idx dictionary\n", - "tokenizer.word_index[tokenizer.oov_token] = top_k + 1\n", - "tokenizer.word_index[''] = 0" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "0fpJb5ojRPFv", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# creating the tokenized vectors\n", - "train_seqs = tokenizer.texts_to_sequences(train_captions)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "olQArbgbRPF1", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": 
false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# creating a reverse mapping (index -> word)\n", - "index_word = {value:key for key, value in tokenizer.word_index.items()}" - ], - "execution_count": 0, - "outputs": [] - }, - { "metadata": { - "id": "AidglIZVRPF4", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } + "colab_type": "text", + "id": "Cffg2i257iMS" }, - "cell_type": "code", "source": [ - "# padding each vector to the max_length of the captions\n", - "# if the max_length parameter is not provided, pad_sequences calculates that automatically\n", - "cap_vector = tf.keras.preprocessing.sequence.pad_sequences(train_seqs, padding='post')" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "gL0wkttkRPGA", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# calculating the max_length \n", - "# used to store the attention weights\n", - "max_length = calc_max_length(train_seqs)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "M3CD75nDpvTI", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "## Split the data into training and testing" - ] - }, - { - "metadata": { - "id": "iS7DDMszRPGF", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# Create training and validation sets using 80-20 split\n", - "img_name_train, img_name_val, cap_train, cap_val = train_test_split(img_name_vector, \n", - " cap_vector, \n", - " test_size=0.2, \n", - " random_state=0)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "XmViPkRFRPGH", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "len(img_name_train), 
len(cap_train), len(img_name_val), len(cap_val)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "uEWM9xrYcg45", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "## Our images and captions are ready! Next, let's create a tf.data dataset to use for training our model.\n", - "\n" - ] - }, - { - "metadata": { - "id": "Q3TnZ1ToRPGV", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# feel free to change these parameters according to your system's configuration\n", - "\n", - "BATCH_SIZE = 64\n", - "BUFFER_SIZE = 1000\n", - "embedding_dim = 256\n", - "units = 512\n", - "vocab_size = len(tokenizer.word_index)\n", - "# shape of the vector extracted from InceptionV3 is (64, 2048)\n", - "# these two variables represent that\n", - "features_shape = 2048\n", - "attention_features_shape = 64" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "SmZS2N0bXG3T", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# loading the numpy files \n", - "def map_func(img_name, cap):\n", - " img_tensor = np.load(img_name.decode('utf-8')+'.npy')\n", - " return img_tensor, cap" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "FDF_Nm3tRPGZ", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "dataset = tf.data.Dataset.from_tensor_slices((img_name_train, cap_train))\n", - "\n", - "# using map to load the numpy files in parallel\n", - "# NOTE: Be sure to set num_parallel_calls to the number of CPU cores you have\n", - "# https://www.tensorflow.org/api_docs/python/tf/py_func\n", - "dataset = dataset.map(lambda item1, item2: tf.py_func(\n", - " map_func, [item1, item2], [tf.float32, tf.int32]), 
num_parallel_calls=8)\n", - "\n", - "# shuffling and batching\n", - "dataset = dataset.shuffle(BUFFER_SIZE)\n", - "# https://www.tensorflow.org/api_docs/python/tf/contrib/data/batch_and_drop_remainder\n", - "dataset = dataset.batch(BATCH_SIZE)\n", - "dataset = dataset.prefetch(1)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "nrvoDphgRPGd", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "## Model\n", - "\n", - "Fun fact, the decoder below is identical to the one in the example for [Neural Machine Translation with Attention]( https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb).\n", - "\n", - "The model architecture is inspired by the [Show, Attend and Tell](https://arxiv.org/pdf/1502.03044.pdf) paper.\n", - "\n", - "* In this example, we extract the features from the lower convolutional layer of InceptionV3 giving us a vector of shape (8, 8, 2048). \n", - "* We squash that to a shape of (64, 2048).\n", - "* This vector is then passed through the CNN Encoder(which consists of a single Fully connected layer).\n", - "* The RNN(here GRU) attends over the image to predict the next word." 
- ] - }, - { - "metadata": { - "id": "AAppCGLKRPGd", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "def gru(units):\n", - " # If you have a GPU, we recommend using the CuDNNGRU layer (it provides a \n", - " # significant speedup).\n", - " if tf.test.is_gpu_available():\n", - " return tf.keras.layers.CuDNNGRU(units, \n", - " return_sequences=True, \n", - " return_state=True, \n", - " recurrent_initializer='glorot_uniform')\n", - " else:\n", - " return tf.keras.layers.GRU(units, \n", - " return_sequences=True, \n", - " return_state=True, \n", - " recurrent_activation='sigmoid', \n", - " recurrent_initializer='glorot_uniform')" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "ja2LFTMSdeV3", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "class BahdanauAttention(tf.keras.Model):\n", - " def __init__(self, units):\n", - " super(BahdanauAttention, self).__init__()\n", - " self.W1 = tf.keras.layers.Dense(units)\n", - " self.W2 = tf.keras.layers.Dense(units)\n", - " self.V = tf.keras.layers.Dense(1)\n", - " \n", - " def call(self, features, hidden):\n", - " # features(CNN_encoder output) shape == (batch_size, 64, embedding_dim)\n", - " \n", - " # hidden shape == (batch_size, hidden_size)\n", - " # hidden_with_time_axis shape == (batch_size, 1, hidden_size)\n", - " hidden_with_time_axis = tf.expand_dims(hidden, 1)\n", - " \n", - " # score shape == (batch_size, 64, hidden_size)\n", - " score = tf.nn.tanh(self.W1(features) + self.W2(hidden_with_time_axis))\n", - " \n", - " # attention_weights shape == (batch_size, 64, 1)\n", - " # we get 1 at the last axis because we are applying score to self.V\n", - " attention_weights = tf.nn.softmax(self.V(score), axis=1)\n", - " \n", - " # context_vector shape after sum == (batch_size, hidden_size)\n", - " 
context_vector = attention_weights * features\n", - " context_vector = tf.reduce_sum(context_vector, axis=1)\n", - " \n", - " return context_vector, attention_weights" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "AZ7R1RxHRPGf", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "class CNN_Encoder(tf.keras.Model):\n", - " # Since we have already extracted the features and dumped it using pickle\n", - " # This encoder passes those features through a Fully connected layer\n", - " def __init__(self, embedding_dim):\n", - " super(CNN_Encoder, self).__init__()\n", - " # shape after fc == (batch_size, 64, embedding_dim)\n", - " self.fc = tf.keras.layers.Dense(embedding_dim)\n", - " \n", - " def call(self, x):\n", - " x = self.fc(x)\n", - " x = tf.nn.relu(x)\n", - " return x" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "V9UbGQmERPGi", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "class RNN_Decoder(tf.keras.Model):\n", - " def __init__(self, embedding_dim, units, vocab_size):\n", - " super(RNN_Decoder, self).__init__()\n", - " self.units = units\n", - "\n", - " self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)\n", - " self.gru = gru(self.units)\n", - " self.fc1 = tf.keras.layers.Dense(self.units)\n", - " self.fc2 = tf.keras.layers.Dense(vocab_size)\n", - " \n", - " self.attention = BahdanauAttention(self.units)\n", - " \n", - " def call(self, x, features, hidden):\n", - " # defining attention as a separate model\n", - " context_vector, attention_weights = self.attention(features, hidden)\n", - " \n", - " # x shape after passing through embedding == (batch_size, 1, embedding_dim)\n", - " x = self.embedding(x)\n", - " \n", - " # x shape after concatenation == (batch_size, 1, embedding_dim + 
hidden_size)\n", - " x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)\n", - " \n", - " # passing the concatenated vector to the GRU\n", - " output, state = self.gru(x)\n", - " \n", - " # shape == (batch_size, max_length, hidden_size)\n", - " x = self.fc1(output)\n", - " \n", - " # x shape == (batch_size * max_length, hidden_size)\n", - " x = tf.reshape(x, (-1, x.shape[2]))\n", - " \n", - " # output shape == (batch_size * max_length, vocab)\n", - " x = self.fc2(x)\n", - "\n", - " return x, state, attention_weights\n", - "\n", - " def reset_state(self, batch_size):\n", - " return tf.zeros((batch_size, self.units))" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "Qs_Sr03wRPGk", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "encoder = CNN_Encoder(embedding_dim)\n", - "decoder = RNN_Decoder(embedding_dim, units, vocab_size)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "-bYN7xA0RPGl", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "optimizer = tf.train.AdamOptimizer()\n", + "# Image Captioning with Attention\n", "\n", - "# We are masking the loss calculated for padding\n", - "def loss_function(real, pred):\n", - " mask = 1 - np.equal(real, 0)\n", - " loss_ = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=real, logits=pred) * mask\n", - " return tf.reduce_mean(loss_)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "PHod7t72RPGn", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "## Training\n", + "This example has moved:\n", "\n", - "* We extract the features stored in the respective `.npy` files and then pass those features through the encoder.\n", - "* The encoder output, hidden state(initialized to 0) and the decoder input (which is the 
start token) is passed to the decoder.\n", - "* The decoder returns the predictions and the decoder hidden state.\n", - "* The decoder hidden state is then passed back into the model and the predictions are used to calculate the loss.\n", - "* Use teacher forcing to decide the next input to the decoder.\n", - "* Teacher forcing is the technique where the target word is passed as the next input to the decoder.\n", - "* The final step is to calculate the gradients and apply it to the optimizer and backpropagate.\n" + "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\u003ctd\u003e\n", + "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/r2/tutorials/generative/image_captioning.ipynb\"\u003e\n", + " \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e \n", + "\u003c/td\u003e\u003ctd\u003e\n", + "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/r2/tutorials/generative/image_captioning.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e" ] }, { - "metadata": { - "id": "Vt4WZ5mhJE-E", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# adding this in a separate cell because if you run the training cell \n", - "# many times, the loss_plot array will be reset\n", - "loss_plot = []" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "UlA4VIQpRPGo", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "EPOCHS = 20\n", - "\n", - "for epoch in range(EPOCHS):\n", - " start = time.time()\n", - " total_loss = 0\n", - " \n", - " for (batch, (img_tensor, target)) in 
enumerate(dataset):\n", - " loss = 0\n", - " \n", - " # initializing the hidden state for each batch\n", - " # because the captions are not related from image to image\n", - " hidden = decoder.reset_state(batch_size=target.shape[0])\n", - "\n", - " dec_input = tf.expand_dims([tokenizer.word_index['']] * BATCH_SIZE, 1)\n", - " \n", - " with tf.GradientTape() as tape:\n", - " features = encoder(img_tensor)\n", - " \n", - " for i in range(1, target.shape[1]):\n", - " # passing the features through the decoder\n", - " predictions, hidden, _ = decoder(dec_input, features, hidden)\n", - "\n", - " loss += loss_function(target[:, i], predictions)\n", - " \n", - " # using teacher forcing\n", - " dec_input = tf.expand_dims(target[:, i], 1)\n", - " \n", - " total_loss += (loss / int(target.shape[1]))\n", - " \n", - " variables = encoder.variables + decoder.variables\n", - " \n", - " gradients = tape.gradient(loss, variables) \n", - " \n", - " optimizer.apply_gradients(zip(gradients, variables), tf.train.get_or_create_global_step())\n", - " \n", - " if batch % 100 == 0:\n", - " print ('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1, \n", - " batch, \n", - " loss.numpy() / int(target.shape[1])))\n", - " # storing the epoch end loss value to plot later\n", - " loss_plot.append(total_loss / len(cap_vector))\n", - " \n", - " print ('Epoch {} Loss {:.6f}'.format(epoch + 1, \n", - " total_loss/len(cap_vector)))\n", - " print ('Time taken for 1 epoch {} sec\\n'.format(time.time() - start))" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "1Wm83G-ZBPcC", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "plt.plot(loss_plot)\n", - "plt.xlabel('Epochs')\n", - "plt.ylabel('Loss')\n", - "plt.title('Loss Plot')\n", - "plt.show()" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "xGvOcLQKghXN", - "colab_type": "text" - }, "cell_type": 
"markdown", - "source": [ - "## Caption!\n", - "\n", - "* The evaluate function is similar to the training loop, except we don't use teacher forcing here. The input to the decoder at each time step is its previous predictions along with the hidden state and the encoder output.\n", - "* Stop predicting when the model predicts the end token.\n", - "* And store the attention weights for every time step." - ] - }, - { "metadata": { - "id": "RCWpDtyNRPGs", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } + "colab_type": "text", + "id": "QASbY_HGo4Lq" }, - "cell_type": "code", "source": [ - "def evaluate(image):\n", - " attention_plot = np.zeros((max_length, attention_features_shape))\n", - "\n", - " hidden = decoder.reset_state(batch_size=1)\n", - "\n", - " temp_input = tf.expand_dims(load_image(image)[0], 0)\n", - " img_tensor_val = image_features_extract_model(temp_input)\n", - " img_tensor_val = tf.reshape(img_tensor_val, (img_tensor_val.shape[0], -1, img_tensor_val.shape[3]))\n", - "\n", - " features = encoder(img_tensor_val)\n", - "\n", - " dec_input = tf.expand_dims([tokenizer.word_index['']], 0)\n", - " result = []\n", - "\n", - " for i in range(max_length):\n", - " predictions, hidden, attention_weights = decoder(dec_input, features, hidden)\n", - "\n", - " attention_plot[i] = tf.reshape(attention_weights, (-1, )).numpy()\n", - "\n", - " predicted_id = tf.argmax(predictions[0]).numpy()\n", - " result.append(index_word[predicted_id])\n", - "\n", - " if index_word[predicted_id] == '':\n", - " return result, attention_plot\n", - "\n", - " dec_input = tf.expand_dims([predicted_id], 0)\n", - "\n", - " attention_plot = attention_plot[:len(result), :]\n", - " return result, attention_plot" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "fD_y7PD6RPGt", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - 
"source": [ - "def plot_attention(image, result, attention_plot):\n", - " temp_image = np.array(Image.open(image))\n", - "\n", - " fig = plt.figure(figsize=(10, 10))\n", - " \n", - " len_result = len(result)\n", - " for l in range(len_result):\n", - " temp_att = np.resize(attention_plot[l], (8, 8))\n", - " ax = fig.add_subplot(len_result//2, len_result//2, l+1)\n", - " ax.set_title(result[l])\n", - " img = ax.imshow(temp_image)\n", - " ax.imshow(temp_att, cmap='gray', alpha=0.6, extent=img.get_extent())\n", + "![Man Surfing](https://tensorflow.org/images/surf.jpg) \n", "\n", - " plt.tight_layout()\n", - " plt.show()" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "io7ws3ReRPGv", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# captions on the validation set\n", - "rid = np.random.randint(0, len(img_name_val))\n", - "image = img_name_val[rid]\n", - "real_caption = ' '.join([index_word[i] for i in cap_val[rid] if i not in [0]])\n", - "result, attention_plot = evaluate(image)\n", + "[Image Source](https://commons.wikimedia.org/wiki/Surfing#/media/File:Surfing_in_Hawaii.jpg), License: Public Domain\n", "\n", - "print ('Real Caption:', real_caption)\n", - "print ('Prediction Caption:', ' '.join(result))\n", - "plot_attention(image, result, attention_plot)\n", - "# opening the image\n", - "Image.open(img_name_val[rid])" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "Rprk3HEvZuxb", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "## Try it on your own images\n", - "For fun, below we've provided a method you can use to caption your own images with the model we've just trained. 
Keep in mind, it was trained on a relatively small amount of data, and your images may be different from the training data (so be prepared for weird results!)\n" + "![Prediction](https://tensorflow.org/images/imcap_prediction.png)\n" ] - }, - { - "metadata": { - "id": "9Psd1quzaAWg", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "name": "image_captioning_with_attention.ipynb", + "private_outputs": true, + "provenance": [ + { + "file_id": "1HI8OK2sMjcx9CTWVn0122QAHOuXaOaMg", + "timestamp": 1530222436922 } - }, - "cell_type": "code", - "source": [ - "image_url = 'https://tensorflow.org/images/surf.jpg'\n", - "image_extension = image_url[-4:]\n", - "image_path = tf.keras.utils.get_file('image'+image_extension, \n", - " origin=image_url)\n", - "\n", - "result, attention_plot = evaluate(image_path)\n", - "print ('Prediction Caption:', ' '.join(result))\n", - "plot_attention(image_path, result, attention_plot)\n", - "# opening the image\n", - "Image.open(image_path)" ], - "execution_count": 0, - "outputs": [] + "toc_visible": true, + "version": "0.3.2" }, - { - "metadata": { - "id": "VJZXyJco6uLO", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "# Next steps\n", - "\n", - "Congrats! You've just trained an image captioning model with attention. Next, we recommend taking a look at this example [Neural Machine Translation with Attention]( https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb). It uses a similar architecture to translate between Spanish and English sentences. You can also experiment with training the code in this notebook on a different dataset." 
- ] + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" } - ] + }, + "nbformat": 4, + "nbformat_minor": 0 } diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb index bda9e77085e45ae31a228142135425e22a1c6780..c945c753b3ba36d16aa6985d23a5849f8f552304 100644 --- a/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb +++ b/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb @@ -13,633 +13,13 @@ "\n", "# Text Generation using a RNN\n", "\n", + "This example has moved.\n", + "\n", "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\u003ctd\u003e\n", - "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb\"\u003e\n", + "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/sequences/text_generation.ipynb\"\u003e\n", " \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e \n", "\u003c/td\u003e\u003ctd\u003e\n", - "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on Github\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "BwpJ5IffzRG6" - }, - "source": [ - "This notebook demonstrates how to generate text using an RNN using [tf.keras](https://www.tensorflow.org/programmers_guide/keras) and [eager execution](https://www.tensorflow.org/programmers_guide/eager). 
If you like, you can write a similar [model](https://github.com/fchollet/deep-learning-with-python-notebooks/blob/master/8.1-text-generation-with-lstm.ipynb) using less code. Here, we show a lower-level impementation that's useful to understand as prework before diving in to deeper examples in a similar, like [Neural Machine Translation with Attention](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb).\n", - "\n", - "This notebook is an end-to-end example. When you run it, it will download a dataset of Shakespeare's writing. We'll use a collection of plays, borrowed from Andrej Karpathy's excellent [The Unreasonable Effectiveness of Recurrent Neural Networks](http://karpathy.github.io/2015/05/21/rnn-effectiveness/). The notebook will train a model, and use it to generate sample output.\n", - " \n", - "Here is the output(with start string='w') after training a single layer GRU for 30 epochs with the default settings below:\n", - "\n", - "```\n", - "were to the death of him\n", - "And nothing of the field in the view of hell,\n", - "When I said, banish him, I will not burn thee that would live.\n", - "\n", - "HENRY BOLINGBROKE:\n", - "My gracious uncle--\n", - "\n", - "DUKE OF YORK:\n", - "As much disgraced to the court, the gods them speak,\n", - "And now in peace himself excuse thee in the world.\n", - "\n", - "HORTENSIO:\n", - "Madam, 'tis not the cause of the counterfeit of the earth,\n", - "And leave me to the sun that set them on the earth\n", - "And leave the world and are revenged for thee.\n", - "\n", - "GLOUCESTER:\n", - "I would they were talking with the very name of means\n", - "To make a puppet of a guest, and therefore, good Grumio,\n", - "Nor arm'd to prison, o' the clouds, of the whole field,\n", - "With the admire\n", - "With the feeding of thy chair, and we have heard it so,\n", - "I thank you, sir, he is a visor friendship with your silly 
your bed.\n", - "\n", - "SAMPSON:\n", - "I do desire to live, I pray: some stand of the minds, make thee remedies\n", - "With the enemies of my soul.\n", - "\n", - "MENENIUS:\n", - "I'll keep the cause of my mistress.\n", - "\n", - "POLIXENES:\n", - "My brother Marcius!\n", - "\n", - "Second Servant:\n", - "Will't ple\n", - "```\n", - "\n", - "Of course, while some of the sentences are grammatical, most do not make sense. But, consider:\n", - "\n", - "* Our model is character based (when we began training, it did not yet know how to spell a valid English word, or that words were even a unit of text).\n", - "\n", - "* The structure of the output resembles a play (blocks begin with a speaker name, in all caps similar to the original text). Sentences generally end with a period. If you look at the text from a distance (or don't read the invididual words too closely, it appears as if it's an excerpt from a play).\n", - "\n", - "As a next step, you can experiment training the model on a different dataset - any large text file(ASCII) will do, and you can modify a single line of code below to make that change. Have fun!\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "R3p22DBDsaCA" - }, - "source": [ - "## Install unidecode library\n", - "A helpful library to convert unicode to ASCII." - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "wZ6LOM12wKGH" - }, - "outputs": [], - "source": [ - "!pip install unidecode" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "WGyKZj3bzf9p" - }, - "source": [ - "## Import tensorflow and enable eager execution." 
- ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "yG_n40gFzf9s" - }, - "outputs": [], - "source": [ - "# Import TensorFlow \u003e= 1.10 and enable eager execution\n", - "import tensorflow as tf\n", - "\n", - "# Note: Once you enable eager execution, it cannot be disabled. \n", - "tf.enable_eager_execution()\n", - "\n", - "import numpy as np\n", - "import os\n", - "import re\n", - "import random\n", - "import unidecode\n", - "import time" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "EHDoRoc5PKWz" - }, - "source": [ - "## Download the dataset\n", - "\n", - "In this example, we will use the [shakespeare dataset](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt). You can use any other dataset that you like.\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "pD_55cOxLkAb" - }, - "outputs": [], - "source": [ - "path_to_file = tf.keras.utils.get_file('shakespeare.txt', 'https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt')" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "UHjdCjDuSvX_" - }, - "source": [ - "## Read the dataset\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "-E5JvY3wzf94" - }, - "outputs": [], - "source": [ - "text = unidecode.unidecode(open(path_to_file).read())\n", - "# length of text is the number of characters in it\n", - "print (len(text))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "Il9ww98izf-D" - }, - "source": [ - "Creating dictionaries to map from characters to their indices and vice-versa, which will be used to vectorize the inputs" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - 
"colab": {}, - "colab_type": "code", - "id": "IalZLbvOzf-F" - }, - "outputs": [], - "source": [ - "# unique contains all the unique characters in the file\n", - "unique = sorted(set(text))\n", - "\n", - "# creating a mapping from unique characters to indices\n", - "char2idx = {u:i for i, u in enumerate(unique)}\n", - "idx2char = {i:u for i, u in enumerate(unique)}" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "1v_qUYfAzf-I" - }, - "outputs": [], - "source": [ - "# setting the maximum length sentence we want for a single input in characters\n", - "max_length = 100\n", - "\n", - "# length of the vocabulary in chars\n", - "vocab_size = len(unique)\n", - "\n", - "# the embedding dimension \n", - "embedding_dim = 256\n", - "\n", - "# number of RNN (here GRU) units\n", - "units = 1024\n", - "\n", - "# batch size \n", - "BATCH_SIZE = 64\n", - "\n", - "# buffer size to shuffle our dataset\n", - "BUFFER_SIZE = 10000" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "LFjSVAlWzf-N" - }, - "source": [ - "## Creating the input and output tensors\n", - "\n", - "Vectorizing the input and the target text because our model cannot understand strings only numbers.\n", - "\n", - "But first, we need to create the input and output vectors.\n", - "Remember the max_length we set above, we will use it here. We are creating **max_length** chunks of input, where each input vector is all the characters in that chunk except the last and the target vector is all the characters in that chunk except the first.\n", - "\n", - "For example, consider that the string = 'tensorflow' and the max_length is 9\n", - "\n", - "So, the `input = 'tensorflo'` and `output = 'ensorflow'`\n", - "\n", - "After creating the vectors, we convert each character into numbers using the **char2idx** dictionary we created above." 
- ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "0UHJDA39zf-O" - }, - "outputs": [], - "source": [ - "input_text = []\n", - "target_text = []\n", - "\n", - "for f in range(0, len(text)-max_length, max_length):\n", - " inps = text[f:f+max_length]\n", - " targ = text[f+1:f+1+max_length]\n", - "\n", - " input_text.append([char2idx[i] for i in inps])\n", - " target_text.append([char2idx[t] for t in targ])\n", - " \n", - "print (np.array(input_text).shape)\n", - "print (np.array(target_text).shape)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "MJdfPmdqzf-R" - }, - "source": [ - "## Creating batches and shuffling them using tf.data" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "p2pGotuNzf-S" - }, - "outputs": [], - "source": [ - "dataset = tf.data.Dataset.from_tensor_slices((input_text, target_text)).shuffle(BUFFER_SIZE)\n", - "dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "m8gPwEjRzf-Z" - }, - "source": [ - "## Creating the model\n", - "\n", - "We use the Model Subclassing API which gives us full flexibility to create the model and change it however we like. 
We use 3 layers to define our model.\n", - "\n", - "* Embedding layer\n", - "* GRU layer (you can use an LSTM layer here)\n", - "* Fully connected layer" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "P3KTiiInzf-a" - }, - "outputs": [], - "source": [ - "class Model(tf.keras.Model):\n", - " def __init__(self, vocab_size, embedding_dim, units, batch_size):\n", - " super(Model, self).__init__()\n", - " self.units = units\n", - " self.batch_sz = batch_size\n", - "\n", - " self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)\n", - "\n", - " if tf.test.is_gpu_available():\n", - " self.gru = tf.keras.layers.CuDNNGRU(self.units, \n", - " return_sequences=True, \n", - " return_state=True, \n", - " recurrent_initializer='glorot_uniform')\n", - " else:\n", - " self.gru = tf.keras.layers.GRU(self.units, \n", - " return_sequences=True, \n", - " return_state=True, \n", - " recurrent_activation='sigmoid', \n", - " recurrent_initializer='glorot_uniform')\n", - "\n", - " self.fc = tf.keras.layers.Dense(vocab_size)\n", - " \n", - " def call(self, x, hidden):\n", - " x = self.embedding(x)\n", - "\n", - " # output shape == (batch_size, max_length, hidden_size) \n", - " # states shape == (batch_size, hidden_size)\n", - "\n", - " # states variable to preserve the state of the model\n", - " # this will be used to pass at every step to the model while training\n", - " output, states = self.gru(x, initial_state=hidden)\n", - "\n", - "\n", - " # reshaping the output so that we can pass it to the Dense layer\n", - " # after reshaping the shape is (batch_size * max_length, hidden_size)\n", - " output = tf.reshape(output, (-1, output.shape[2]))\n", - "\n", - " # The dense layer will output predictions for every time_steps(max_length)\n", - " # output shape after the dense layer == (max_length * batch_size, vocab_size)\n", - " x = self.fc(output)\n", - "\n", - " return x, states" - ] - }, - { - 
"cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "trpqTWyvk0nr" - }, - "source": [ - "## Call the model and set the optimizer and the loss function" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "7t2XrzEOzf-e" - }, - "outputs": [], - "source": [ - "model = Model(vocab_size, embedding_dim, units, BATCH_SIZE)" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "dkjWIATszf-h" - }, - "outputs": [], - "source": [ - "optimizer = tf.train.AdamOptimizer()\n", - "\n", - "# using sparse_softmax_cross_entropy so that we don't have to create one-hot vectors\n", - "def loss_function(real, preds):\n", - " return tf.losses.sparse_softmax_cross_entropy(labels=real, logits=preds)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "3K6s6F79P7za" - }, - "source": [ - "## Checkpoints (Object-based saving)" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "oAGisDdfP9rL" - }, - "outputs": [], - "source": [ - "checkpoint_dir = './training_checkpoints'\n", - "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt\")\n", - "checkpoint = tf.train.Checkpoint(optimizer=optimizer,\n", - " model=model)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "lPrP0XMUzf-p" - }, - "source": [ - "## Train the model\n", - "\n", - "Here we will use a custom training loop with the help of GradientTape()\n", - "\n", - "* We initialize the hidden state of the model with zeros and shape == (batch_size, number of rnn units). 
We do this by calling the function defined while creating the model.\n", - "\n", - "* Next, we iterate over the dataset(batch by batch) and calculate the **predictions and the hidden states** associated with that input.\n", - "\n", - "* There are a lot of interesting things happening here.\n", - " * The model gets hidden state(initialized with 0), lets call that **H0** and the first batch of input, lets call that **I0**.\n", - " * The model then returns the predictions **P1** and **H1**.\n", - " * For the next batch of input, the model receives **I1** and **H1**.\n", - " * The interesting thing here is that we pass **H1** to the model with **I1** which is how the model learns. The context learned from batch to batch is contained in the **hidden state**.\n", - " * We continue doing this until the dataset is exhausted and then we start a new epoch and repeat this.\n", - "\n", - "* After calculating the predictions, we calculate the **loss** using the loss function defined above. Then we calculate the gradients of the loss with respect to the model variables(input)\n", - "\n", - "* Finally, we take a step in that direction with the help of the optimizer using the apply_gradients function.\n", - "\n", - "Note:- If you are running this notebook in Colab which has a **Tesla K80 GPU** it takes about 23 seconds per epoch.\n" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "d4tSNwymzf-q" - }, - "outputs": [], - "source": [ - "# Training step\n", - "\n", - "EPOCHS = 20\n", - "\n", - "for epoch in range(EPOCHS):\n", - " start = time.time()\n", - " \n", - " # initializing the hidden state at the start of every epoch\n", - " hidden = model.reset_states()\n", - " \n", - " for (batch, (inp, target)) in enumerate(dataset):\n", - " with tf.GradientTape() as tape:\n", - " # feeding the hidden state back into the model\n", - " # This is the interesting step\n", - " predictions, hidden = model(inp, hidden)\n", 
- " \n", - " # reshaping the target because that's how the \n", - " # loss function expects it\n", - " target = tf.reshape(target, (-1,))\n", - " loss = loss_function(target, predictions)\n", - " \n", - " grads = tape.gradient(loss, model.variables)\n", - " optimizer.apply_gradients(zip(grads, model.variables))\n", - "\n", - " if batch % 100 == 0:\n", - " print ('Epoch {} Batch {} Loss {:.4f}'.format(epoch+1,\n", - " batch,\n", - " loss))\n", - " # saving (checkpoint) the model every 5 epochs\n", - " if (epoch + 1) % 5 == 0:\n", - " checkpoint.save(file_prefix = checkpoint_prefix)\n", - "\n", - " print ('Epoch {} Loss {:.4f}'.format(epoch+1, loss))\n", - " print('Time taken for 1 epoch {} sec\\n'.format(time.time() - start))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "01AR9vpNQMFF" - }, - "source": [ - "## Restore the latest checkpoint" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "tyvpYomYQQkF" - }, - "outputs": [], - "source": [ - "# restoring the latest checkpoint in checkpoint_dir\n", - "checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "DjGz1tDkzf-u" - }, - "source": [ - "## Predicting using our trained model\n", - "\n", - "The below code block is used to generated the text\n", - "\n", - "* We start by choosing a start string and initializing the hidden state and setting the number of characters we want to generate.\n", - "\n", - "* We get predictions using the start_string and the hidden state\n", - "\n", - "* Then we use argmax to calculate the index of the predicted word. 
**We use this predicted word as our next input to the model**\n", - "\n", - "* **The hidden state returned by the model is fed back into the model so that it now has more context rather than just one word.** After we predict the next word, the modified hidden states are again fed back into the model, which is how it learns as it gets more context from the previously predicted words.\n", - "\n", - "* If you see the predictions, the model knows when to capitalize, make paragraphs and the text follows a shakespeare style of writing which is pretty awesome!" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "WvuwZBX5Ogfd" - }, - "outputs": [], - "source": [ - "# Evaluation step(generating text using the model learned)\n", - "\n", - "# number of characters to generate\n", - "num_generate = 1000\n", - "\n", - "# You can change the start string to experiment\n", - "start_string = 'Q'\n", - "# converting our start string to numbers(vectorizing!) 
\n", - "input_eval = [char2idx[s] for s in start_string]\n", - "input_eval = tf.expand_dims(input_eval, 0)\n", - "\n", - "# empty string to store our results\n", - "text_generated = ''\n", - "\n", - "# hidden state shape == (batch_size, number of rnn units); here batch size == 1\n", - "hidden = [tf.zeros((1, units))]\n", - "for i in range(num_generate):\n", - " predictions, hidden = model(input_eval, hidden)\n", - "\n", - " # using argmax to predict the word returned by the model\n", - " predicted_id = tf.argmax(predictions[-1]).numpy()\n", - " \n", - " # We pass the predicted word as the next input to the model\n", - " # along with the previous hidden state\n", - " input_eval = tf.expand_dims([predicted_id], 0)\n", - " \n", - " text_generated += idx2char[predicted_id]\n", - "\n", - "print (start_string + text_generated)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "AM2Uma_-yVIq" - }, - "source": [ - "## Next steps\n", - "\n", - "* Change the start string to a different character, or the start of a sentence.\n", - "* Experiment with training on a different, or with different parameters. 
[Project Gutenberg](http://www.gutenberg.org/ebooks/100), for example, contains a large collection of books.\n", - "* Add another RNN layer.\n" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "gtEd86sX5cB2" - }, - "outputs": [], - "source": [ - "" + "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/tutorials/sequences/text_generation.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e" ] } ], diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/BUILD b/tensorflow/contrib/eager/python/examples/l2hmc/BUILD index 7bdf9053de749af9d09b12ba7b848e21c1fdb8f0..35d509904211d98f124d2555fc48166e75cb0dd9 100644 --- a/tensorflow/contrib/eager/python/examples/l2hmc/BUILD +++ b/tensorflow/contrib/eager/python/examples/l2hmc/BUILD @@ -28,7 +28,7 @@ py_library( cuda_py_test( name = "l2hmc_test", - size = "large", + size = "medium", srcs = ["l2hmc_test.py"], additional_deps = [ ":l2hmc", @@ -36,4 +36,8 @@ cuda_py_test( "//tensorflow/contrib/eager/python:tfe", "//third_party/py/numpy", ], + shard_count = 4, + tags = [ + "oss_serial", + ], ) diff --git a/tensorflow/contrib/eager/python/examples/linear_regression/BUILD b/tensorflow/contrib/eager/python/examples/linear_regression/BUILD index 74ce9e84f013d79b3a33ffa79993980b561e366d..30afef83bc5c6c164c8456ed472f4d6064068a25 100644 --- a/tensorflow/contrib/eager/python/examples/linear_regression/BUILD +++ b/tensorflow/contrib/eager/python/examples/linear_regression/BUILD @@ -9,6 +9,13 @@ py_binary( name = "linear_regression", srcs = ["linear_regression.py"], srcs_version = "PY2AND3", + deps = [":linear_regression_lib"], +) + +py_library( + name = "linear_regression_lib", + srcs = ["linear_regression.py"], + srcs_version = "PY2AND3", deps = [ "//tensorflow:tensorflow_py", 
"//tensorflow/contrib/eager/python:tfe", @@ -20,10 +27,13 @@ cuda_py_test( size = "small", srcs = ["linear_regression_test.py"], additional_deps = [ - ":linear_regression", + ":linear_regression_lib", "//tensorflow:tensorflow_py", ], - tags = ["no_windows"], # TODO: needs investigation on Windows + tags = [ + "no_windows", # TODO: needs investigation on Windows + "oss_serial", + ], ) cuda_py_test( @@ -31,7 +41,7 @@ cuda_py_test( size = "small", srcs = ["linear_regression_graph_test.py"], additional_deps = [ - ":linear_regression", + ":linear_regression_lib", "//tensorflow:tensorflow_py", ], ) diff --git a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py index 099b712fc06d1d3eb9ab4095f8db7283690bda76..206ef9409df7b1dc21de42ba919d2ba97f334a8c 100644 --- a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py +++ b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py @@ -56,7 +56,7 @@ class LinearModel(tf.keras.Model): def mean_square_loss(model, xs, ys): - return tf.reduce_mean(tf.square(tf.subtract(model(xs), ys))) + return tf.reduce_mean(tf.squared_difference(model(xs), ys)) def fit(model, dataset, optimizer, verbose=False, logdir=None): diff --git a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_graph_test.py b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_graph_test.py index 557ad42752144243ae3da61b955b31398cba846e..d412b25b368260b81256fd58034330b884261b2b 100644 --- a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_graph_test.py +++ b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_graph_test.py @@ -36,7 +36,7 @@ class GraphLinearRegressionBenchmark(tf.test.Benchmark): noise_level=0.01, batch_size=batch_size, num_batches=num_batches) - iterator = 
dataset.make_initializable_iterator() + iterator = tf.compat.v1.data.make_initializable_iterator(dataset) x, y = iterator.get_next() model = linear_regression.LinearModel() diff --git a/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb index 66d52a74943d0d81fde05ce51b019558b327978d..436e887736158ec1ba8e46eac8de4ac7b8e6be01 100644 --- a/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb +++ b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb @@ -1,11 +1,28 @@ { + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "nmt_with_attention.ipynb", + "version": "0.3.2", + "provenance": [], + "private_outputs": true, + "collapsed_sections": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "accelerator": "GPU" + }, "cells": [ { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "AOpGoE2T-YXS" }, + "cell_type": "markdown", "source": [ "##### Copyright 2018 The TensorFlow Authors.\n", "\n", @@ -13,19 +30,19 @@ "\n", "# Neural Machine Translation with Attention\n", "\n", - "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\u003ctd\u003e\n", - "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb\"\u003e\n", - " \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e \n", - "\u003c/td\u003e\u003ctd\u003e\n", - "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source 
on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e" + "
\n", + "\n", + " Run in Google Colab \n", + "\n", + "View source on GitHub
" ] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "CiwtNgENbx2g" }, + "cell_type": "markdown", "source": [ "This notebook trains a sequence to sequence (seq2seq) model for Spanish to English translation using [tf.keras](https://www.tensorflow.org/programmers_guide/keras) and [eager execution](https://www.tensorflow.org/programmers_guide/eager). This is an advanced example that assumes some knowledge of sequence to sequence models.\n", "\n", @@ -33,24 +50,22 @@ "\n", "The translation quality is reasonable for a toy example, but the generated attention plot is perhaps more interesting. This shows which parts of the input sentence has the model's attention while translating:\n", "\n", - "\u003cimg src=\"https://tensorflow.org/images/spanish-english.png\" alt=\"spanish-english attention plot\"\u003e\n", + "\"spanish-english\n", "\n", "Note: This example takes approximately 10 mintues to run on a single P100 GPU." ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "tnxXKDjq3jEL" + "id": "tnxXKDjq3jEL", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "from __future__ import absolute_import, division, print_function\n", "\n", - "# Import TensorFlow \u003e= 1.10 and enable eager execution\n", + "# Import TensorFlow >= 1.10 and enable eager execution\n", "import tensorflow as tf\n", "\n", "tf.enable_eager_execution()\n", @@ -65,14 +80,16 @@ "import time\n", "\n", "print(tf.__version__)" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "wfodePkj3jEa" }, + "cell_type": "markdown", "source": [ "## Download and prepare the dataset\n", "\n", @@ -91,14 +108,12 @@ ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "kRVATYOgJs1b" + "id": "kRVATYOgJs1b", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "# Download the 
file\n", "path_to_zip = tf.keras.utils.get_file(\n", @@ -106,17 +121,17 @@ " extract=True)\n", "\n", "path_to_file = os.path.dirname(path_to_zip)+\"/spa-eng/spa.txt\"" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "rd0jw-eC3jEh" + "id": "rd0jw-eC3jEh", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "# Converts the unicode file to ascii\n", "def unicode_to_ascii(s):\n", @@ -128,7 +143,7 @@ " w = unicode_to_ascii(w.lower().strip())\n", " \n", " # creating a space between a word and the punctuation following it\n", - " # eg: \"he is a boy.\" =\u003e \"he is a boy .\" \n", + " # eg: \"he is a boy.\" => \"he is a boy .\" \n", " # Reference:- https://stackoverflow.com/questions/3645931/python-padding-punctuation-with-white-spaces-keeping-punctuation\n", " w = re.sub(r\"([?.!,¿])\", r\" \\1 \", w)\n", " w = re.sub(r'[\" \"]+', \" \", w)\n", @@ -140,19 +155,19 @@ " \n", " # adding a start and an end token to the sentence\n", " # so that the model know when to start and stop predicting.\n", - " w = '\u003cstart\u003e ' + w + ' \u003cend\u003e'\n", + " w = ' ' + w + ' '\n", " return w" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "OHn4Dct23jEm" + "id": "OHn4Dct23jEm", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "# 1. Remove the accents\n", "# 2. 
Clean the sentences\n", @@ -163,20 +178,20 @@ " word_pairs = [[preprocess_sentence(w) for w in l.split('\\t')] for l in lines[:num_examples]]\n", " \n", " return word_pairs" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "9xbqO7Iie9bb" + "id": "9xbqO7Iie9bb", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ - "# This class creates a word -\u003e index mapping (e.g,. \"dad\" -\u003e 5) and vice-versa \n", - "# (e.g., 5 -\u003e \"dad\") for each language,\n", + "# This class creates a word -> index mapping (e.g,. \"dad\" -> 5) and vice-versa \n", + "# (e.g., 5 -> \"dad\") for each language,\n", "class LanguageIndex():\n", " def __init__(self, lang):\n", " self.lang = lang\n", @@ -192,23 +207,23 @@ " \n", " self.vocab = sorted(self.vocab)\n", " \n", - " self.word2idx['\u003cpad\u003e'] = 0\n", + " self.word2idx[''] = 0\n", " for index, word in enumerate(self.vocab):\n", " self.word2idx[word] = index + 1\n", " \n", " for word, index in self.word2idx.items():\n", " self.idx2word[index] = word" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "eAY9k49G3jE_" + "id": "eAY9k49G3jE_", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "def max_length(tensor):\n", " return max(len(t) for t in tensor)\n", @@ -244,71 +259,71 @@ " padding='post')\n", " \n", " return input_tensor, target_tensor, inp_lang, targ_lang, max_length_inp, max_length_tar" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "GOi42V79Ydlr" }, + "cell_type": "markdown", "source": [ "### Limit the size of the dataset to experiment faster (optional)\n", "\n", - "Training on the complete dataset of \u003e100,000 sentences will take a long time. 
To train faster, we can limit the size of the dataset to 30,000 sentences (of course, translation quality degrades with less data):" + "Training on the complete dataset of >100,000 sentences will take a long time. To train faster, we can limit the size of the dataset to 30,000 sentences (of course, translation quality degrades with less data):" ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "cnxC7q-j3jFD" + "id": "cnxC7q-j3jFD", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "# Try experimenting with the size of that dataset\n", "num_examples = 30000\n", "input_tensor, target_tensor, inp_lang, targ_lang, max_length_inp, max_length_targ = load_dataset(path_to_file, num_examples)" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "4QILQkOs3jFG" + "id": "4QILQkOs3jFG", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "# Creating training and validation sets using an 80-20 split\n", "input_tensor_train, input_tensor_val, target_tensor_train, target_tensor_val = train_test_split(input_tensor, target_tensor, test_size=0.2)\n", "\n", "# Show length\n", "len(input_tensor_train), len(target_tensor_train), len(input_tensor_val), len(target_tensor_val)" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "rgCLkfv5uO3d" }, + "cell_type": "markdown", "source": [ "### Create a tf.data dataset" ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "TqHsArVZ3jFS" + "id": "TqHsArVZ3jFS", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "BUFFER_SIZE = len(input_tensor_train)\n", "BATCH_SIZE = 64\n", @@ -320,27 +335,29 @@ "\n", "dataset = tf.data.Dataset.from_tensor_slices((input_tensor_train, 
target_tensor_train)).shuffle(BUFFER_SIZE)\n", "dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "TNfHIF71ulLu" }, + "cell_type": "markdown", "source": [ "## Write the encoder and decoder model\n", "\n", - "Here, we'll implement an encoder-decoder model with attention which you can read about in the TensorFlow [Neural Machine Translation (seq2seq) tutorial](https://www.tensorflow.org/tutorials/seq2seq). This example uses a more recent set of APIs. This notebook implements the [attention equations](https://www.tensorflow.org/tutorials/seq2seq#background_on_the_attention_mechanism) from the seq2seq tutorial. The following diagram shows that each input words is assigned a weight by the attention mechanism which is then used by the decoder to predict the next word in the sentence.\n", + "Here, we'll implement an encoder-decoder model with attention which you can read about in the TensorFlow [Neural Machine Translation (seq2seq) tutorial](https://github.com/tensorflow/nmt). This example uses a more recent set of APIs. This notebook implements the [attention equations](https://github.com/tensorflow/nmt#background-on-the-attention-mechanism) from the seq2seq tutorial. The following diagram shows that each input word is assigned a weight by the attention mechanism which is then used by the decoder to predict the next word in the sentence.\n", "\n", - "\u003cimg src=\"https://www.tensorflow.org/images/seq2seq/attention_mechanism.jpg\" width=\"500\" alt=\"attention mechanism\"\u003e\n", + "\"attention\n", "\n", "The input is put through an encoder model which gives us the encoder output of shape *(batch_size, max_length, hidden_size)* and the encoder hidden state of shape *(batch_size, hidden_size)*. 
\n", "\n", "Here are the equations that are implemented:\n", "\n", - "\u003cimg src=\"https://www.tensorflow.org/images/seq2seq/attention_equation_0.jpg\" alt=\"attention equation 0\" width=\"800\"\u003e\n", - "\u003cimg src=\"https://www.tensorflow.org/images/seq2seq/attention_equation_1.jpg\" alt=\"attention equation 1\" width=\"800\"\u003e\n", + "\"attention\n", + "\"attention\n", "\n", "We're using *Bahdanau attention*. Lets decide on notation before writing the simplified form:\n", "\n", @@ -362,14 +379,12 @@ ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "avyJ_4VIUoHb" + "id": "avyJ_4VIUoHb", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "def gru(units):\n", " # If you have a GPU, we recommend using CuDNNGRU(provides a 3x speedup than GRU)\n", @@ -385,17 +400,17 @@ " return_state=True, \n", " recurrent_activation='sigmoid', \n", " recurrent_initializer='glorot_uniform')" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "nZ2rI24i3jFg" + "id": "nZ2rI24i3jFg", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "class Encoder(tf.keras.Model):\n", " def __init__(self, vocab_size, embedding_dim, enc_units, batch_sz):\n", @@ -412,17 +427,17 @@ " \n", " def initialize_hidden_state(self):\n", " return tf.zeros((self.batch_sz, self.enc_units))" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "yJ_B3mhW3jFk" + "id": "yJ_B3mhW3jFk", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "class Decoder(tf.keras.Model):\n", " def __init__(self, vocab_size, embedding_dim, dec_units, batch_sz):\n", @@ -476,41 +491,41 @@ " \n", " def initialize_hidden_state(self):\n", " return tf.zeros((self.batch_sz, self.dec_units))" - ] + ], + 
"execution_count": 0, + "outputs": [] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "P5UY8wko3jFp" + "id": "P5UY8wko3jFp", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "encoder = Encoder(vocab_inp_size, embedding_dim, units, BATCH_SIZE)\n", "decoder = Decoder(vocab_tar_size, embedding_dim, units, BATCH_SIZE)" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "_ch_71VbIRfK" }, + "cell_type": "markdown", "source": [ "## Define the optimizer and the loss function" ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "WmTHr5iV3jFr" + "id": "WmTHr5iV3jFr", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "optimizer = tf.train.AdamOptimizer()\n", "\n", @@ -519,41 +534,43 @@ " mask = 1 - np.equal(real, 0)\n", " loss_ = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=real, logits=pred) * mask\n", " return tf.reduce_mean(loss_)" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "DMVWzzsfNl4e" }, + "cell_type": "markdown", "source": [ "## Checkpoints (Object-based saving)" ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "Zj8bXQTgNwrF" + "id": "Zj8bXQTgNwrF", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "checkpoint_dir = './training_checkpoints'\n", "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt\")\n", "checkpoint = tf.train.Checkpoint(optimizer=optimizer,\n", " encoder=encoder,\n", " decoder=decoder)" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "hpObfY22IddU" }, + "cell_type": "markdown", "source": [ "## Training\n", "\n", @@ -567,14 +584,12 @@ ] }, { - "cell_type": "code", - 
"execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "ddefjBMa3jF0" + "id": "ddefjBMa3jF0", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "EPOCHS = 10\n", "\n", @@ -592,7 +607,7 @@ " \n", " dec_hidden = enc_hidden\n", " \n", - " dec_input = tf.expand_dims([targ_lang.word2idx['\u003cstart\u003e']] * BATCH_SIZE, 1) \n", + " dec_input = tf.expand_dims([targ_lang.word2idx['']] * BATCH_SIZE, 1) \n", " \n", " # Teacher forcing - feeding the target as the next input\n", " for t in range(1, targ.shape[1]):\n", @@ -625,14 +640,16 @@ " print('Epoch {} Loss {:.4f}'.format(epoch + 1,\n", " total_loss / N_BATCH))\n", " print('Time taken for 1 epoch {} sec\\n'.format(time.time() - start))" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "mU3Ce8M6I3rz" }, + "cell_type": "markdown", "source": [ "## Translate\n", "\n", @@ -644,14 +661,12 @@ ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "EbQpyYs13jF_" + "id": "EbQpyYs13jF_", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "def evaluate(sentence, encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ):\n", " attention_plot = np.zeros((max_length_targ, max_length_inp))\n", @@ -668,7 +683,7 @@ " enc_out, enc_hidden = encoder(inputs, hidden)\n", "\n", " dec_hidden = enc_hidden\n", - " dec_input = tf.expand_dims([targ_lang.word2idx['\u003cstart\u003e']], 0)\n", + " dec_input = tf.expand_dims([targ_lang.word2idx['']], 0)\n", "\n", " for t in range(max_length_targ):\n", " predictions, dec_hidden, attention_weights = decoder(dec_input, dec_hidden, enc_out)\n", @@ -681,24 +696,24 @@ "\n", " result += targ_lang.idx2word[predicted_id] + ' '\n", "\n", - " if targ_lang.idx2word[predicted_id] == '\u003cend\u003e':\n", + " if targ_lang.idx2word[predicted_id] == '':\n", " return result, sentence, attention_plot\n", " \n", 
" # the predicted ID is fed back into the model\n", " dec_input = tf.expand_dims([predicted_id], 0)\n", "\n", " return result, sentence, attention_plot" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "s5hQWlbN3jGF" + "id": "s5hQWlbN3jGF", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "# function for plotting the attention weights\n", "def plot_attention(attention, sentence, predicted_sentence):\n", @@ -712,17 +727,17 @@ " ax.set_yticklabels([''] + predicted_sentence, fontdict=fontdict)\n", "\n", " plt.show()" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "sl9zUHzg3jGI" + "id": "sl9zUHzg3jGI", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "def translate(sentence, encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ):\n", " result, sentence, attention_plot = evaluate(sentence, encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)\n", @@ -732,91 +747,93 @@ " \n", " attention_plot = attention_plot[:len(result.split(' ')), :len(sentence.split(' '))]\n", " plot_attention(attention_plot, sentence.split(' '), result.split(' '))" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "n250XbnjOaqP" }, + "cell_type": "markdown", "source": [ "## Restore the latest checkpoint and test" ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "UJpT9D5_OgP6" + "id": "UJpT9D5_OgP6", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "# restoring the latest checkpoint in checkpoint_dir\n", "checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "code", - 
"execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "WrAM0FDomq3E" + "id": "WrAM0FDomq3E", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "translate(u'hace mucho frio aqui.', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "zSx2iM36EZQZ" + "id": "zSx2iM36EZQZ", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "translate(u'esta es mi vida.', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "A3LLCx3ZE0Ls" + "id": "A3LLCx3ZE0Ls", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "translate(u'todavia estan en casa?', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "DUQVLVqUE1YW" + "id": "DUQVLVqUE1YW", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "# wrong translation\n", "translate(u'trata de averiguarlo.', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "RTe5P5ioMJwN" }, + "cell_type": "markdown", "source": [ "## Next steps\n", "\n", @@ -824,31 +841,5 @@ "* Experiment with training on a larger dataset, or using more epochs\n" ] } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "collapsed_sections": [], - "name": "nmt_with_attention.ipynb", - "private_outputs": true, - "provenance": [ - { - "file_id": "1C4fpM7_7IL8ZzF7Gc5abywqQjeQNS2-U", - "timestamp": 
1527858391290 - }, - { - "file_id": "1pExo6aUuw0S6MISFWoinfJv0Ftm9V4qv", - "timestamp": 1527776041613 - } - ], - "toc_visible": true, - "version": "0.3.2" - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} + ] +} \ No newline at end of file diff --git a/tensorflow/contrib/eager/python/examples/resnet50/BUILD b/tensorflow/contrib/eager/python/examples/resnet50/BUILD index f3135a9668fc0dc7faa93a5f119b53f3efd34c6e..f2851d97223e483da11120f1fe3f0a2f641dfb81 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/BUILD +++ b/tensorflow/contrib/eager/python/examples/resnet50/BUILD @@ -27,7 +27,7 @@ py_library( cuda_py_test( name = "resnet50_test", - size = "large", + size = "medium", srcs = ["resnet50_test.py"], additional_deps = [ ":resnet50", @@ -35,17 +35,19 @@ cuda_py_test( "//tensorflow/contrib/eager/python:tfe", "//tensorflow:tensorflow_py", ], + shard_count = 4, tags = [ "noasan", # Fix b/118130911 "nomsan", # Fix b/118130911 "notsan", # Fix b/118130911 "optonly", + "oss_serial", ], ) cuda_py_test( name = "resnet50_graph_test", - size = "large", + size = "medium", srcs = ["resnet50_graph_test.py"], additional_deps = [ ":resnet50", @@ -53,10 +55,12 @@ cuda_py_test( "//third_party/py/numpy", "//tensorflow:tensorflow_py", ], + shard_count = 4, tags = [ "noasan", "nomsan", "notsan", "optonly", + "oss_serial", ], ) diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py index f3bb978875e226f58d6a00e09154191673a97415..fb7975d8fe867711cff31d627788a2d62a520aa9 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py +++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py @@ -142,7 +142,8 @@ class ResNet50Benchmarks(tf.test.Benchmark): with tf.Graph().as_default(): np_images, np_labels = random_batch(batch_size) dataset = 
tf.data.Dataset.from_tensors((np_images, np_labels)).repeat() - (images, labels) = dataset.make_one_shot_iterator().get_next() + images, labels = tf.compat.v1.data.make_one_shot_iterator( + dataset).get_next() model = resnet50.ResNet50(data_format()) logits = model(images, training=True) diff --git a/tensorflow/contrib/eager/python/examples/revnet/BUILD b/tensorflow/contrib/eager/python/examples/revnet/BUILD index 4f0d46b1bae3760a63b2abe871034bdedf258f07..cb207b8ddf3641a68a114386f6a95a26ce2b74d6 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/BUILD +++ b/tensorflow/contrib/eager/python/examples/revnet/BUILD @@ -67,30 +67,36 @@ py_library( # Tests cuda_py_test( name = "ops_test", - size = "large", + size = "medium", srcs = ["ops_test.py"], additional_deps = [ ":ops", "//tensorflow:tensorflow_py", ], + shard_count = 4, + tags = [ + "oss_serial", + ], ) cuda_py_test( name = "blocks_test", - size = "large", + size = "medium", srcs = ["blocks_test.py"], additional_deps = [ ":blocks", "//tensorflow:tensorflow_py", ], + shard_count = 4, tags = [ + "no_oss", # b/123045964 "optonly", ], ) cuda_py_test( name = "revnet_test", - size = "large", + size = "medium", srcs = ["revnet_test.py"], additional_deps = [ ":blocks_test", @@ -98,9 +104,11 @@ cuda_py_test( ":revnet", "//tensorflow:tensorflow_py", ], + shard_count = 4, tags = [ "no_pip", # depends on blocks_test, which is not available in pip package "optonly", + "oss_serial", ], ) @@ -127,6 +135,13 @@ py_binary( name = "main", srcs = ["main.py"], srcs_version = "PY2AND3", + deps = [":main_lib"], +) + +py_library( + name = "main_lib", + srcs = ["main.py"], + srcs_version = "PY2AND3", deps = [ ":cifar_input", ":config", @@ -141,7 +156,7 @@ py_binary( srcs_version = "PY2AND3", deps = [ ":cifar_input", - ":main", + ":main_lib", ":revnet", "//tensorflow:tensorflow_py", ], @@ -153,7 +168,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":cifar_input", - ":main", + ":main_lib", ":revnet", 
"//tensorflow:tensorflow_py", ], @@ -165,7 +180,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":cifar_input", - ":main", + ":main_lib", ":revnet", "//tensorflow:tensorflow_py", ], diff --git a/tensorflow/contrib/eager/python/examples/revnet/main.py b/tensorflow/contrib/eager/python/examples/revnet/main.py index b702e91f92220c2a9003a1b82411131332012a9e..9585f3565f83af724b6336e466d3671443ba2361 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/main.py +++ b/tensorflow/contrib/eager/python/examples/revnet/main.py @@ -72,14 +72,11 @@ def main(_): train_one_iter(model, x, y, optimizer, global_step=global_step) if global_step.numpy() % config.log_every == 0: - it_test = ds_test.make_one_shot_iterator() - acc_test, loss_test = evaluate(model, it_test) + acc_test, loss_test = evaluate(model, ds_test) if FLAGS.validate: - it_train = ds_train_one_shot.make_one_shot_iterator() - it_validation = ds_validation.make_one_shot_iterator() - acc_train, loss_train = evaluate(model, it_train) - acc_validation, loss_validation = evaluate(model, it_validation) + acc_train, loss_train = evaluate(model, ds_train_one_shot) + acc_validation, loss_validation = evaluate(model, ds_validation) print("Iter {}, " "training set accuracy {:.4f}, loss {:.4f}; " "validation set accuracy {:.4f}, loss {:.4f}; " @@ -218,11 +215,11 @@ def train_one_iter(model, inputs, labels, optimizer, global_step=None): return logits, loss -def evaluate(model, iterator): +def evaluate(model, dataset): """Compute accuracy with the given dataset iterator.""" mean_loss = tfe.metrics.Mean() accuracy = tfe.metrics.Accuracy() - for x, y in iterator: + for x, y in dataset: logits, _ = model(x, training=False) loss = model.compute_loss(logits=logits, labels=y) accuracy( diff --git a/tensorflow/contrib/eager/python/examples/revnet/revnet.py b/tensorflow/contrib/eager/python/examples/revnet/revnet.py index 1f2cb14972f0b92d29489adff8f94e790e1ec4ed..7406787ba438345dc485c50e347e40597b2037f5 100644 --- 
a/tensorflow/contrib/eager/python/examples/revnet/revnet.py +++ b/tensorflow/contrib/eager/python/examples/revnet/revnet.py @@ -96,6 +96,7 @@ class RevNet(tf.keras.Model): def call(self, inputs, training=True): """Forward pass.""" + saved_hidden = None if training: saved_hidden = [inputs] diff --git a/tensorflow/contrib/eager/python/examples/rnn_colorbot/BUILD b/tensorflow/contrib/eager/python/examples/rnn_colorbot/BUILD index d500b632ebb97fd12ded3a215b0f1a686194874f..f4dbe7ac16f734f7bee045bc71e9559b630adf81 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_colorbot/BUILD +++ b/tensorflow/contrib/eager/python/examples/rnn_colorbot/BUILD @@ -9,6 +9,13 @@ py_binary( name = "rnn_colorbot", srcs = ["rnn_colorbot.py"], srcs_version = "PY2AND3", + deps = [":rnn_colorbot_lib"], +) + +py_library( + name = "rnn_colorbot_lib", + srcs = ["rnn_colorbot.py"], + srcs_version = "PY2AND3", deps = [ "//tensorflow:tensorflow_py", "//tensorflow/contrib/eager/python:tfe", @@ -21,8 +28,11 @@ cuda_py_test( name = "rnn_colorbot_test", srcs = ["rnn_colorbot_test.py"], additional_deps = [ - ":rnn_colorbot", + ":rnn_colorbot_lib", "//tensorflow/contrib/eager/python:tfe", "//tensorflow:tensorflow_py", ], + tags = [ + "oss_serial", + ], ) diff --git a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py index 74ebb1ec77131a560b1ebfd062c690920c35e261..1c718a5ce3d8e1541656d92fd5e8dad6c6683c4c 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py +++ b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py @@ -207,7 +207,7 @@ class RNNColorbot(tf.keras.Model): def loss(labels, predictions): """Computes mean squared loss.""" - return tf.reduce_mean(tf.square(predictions - labels)) + return tf.reduce_mean(tf.squared_difference(predictions, labels)) def test(model, eval_data): diff --git a/tensorflow/contrib/eager/python/examples/rnn_ptb/BUILD 
b/tensorflow/contrib/eager/python/examples/rnn_ptb/BUILD index 2cc2fcbfeb21ee6218d7912d9a93ea2f7b2ea226..43a6ca526d3a0aecda2c8df865a0487ac28758ab 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_ptb/BUILD +++ b/tensorflow/contrib/eager/python/examples/rnn_ptb/BUILD @@ -9,6 +9,13 @@ py_binary( name = "rnn_ptb", srcs = ["rnn_ptb.py"], srcs_version = "PY2AND3", + deps = [":rnn_ptb_lib"], +) + +py_library( + name = "rnn_ptb_lib", + srcs = ["rnn_ptb.py"], + srcs_version = "PY2AND3", deps = [ "//tensorflow:tensorflow_py", "//tensorflow/contrib/cudnn_rnn:cudnn_rnn_py", @@ -21,18 +28,22 @@ cuda_py_test( name = "rnn_ptb_test", srcs = ["rnn_ptb_test.py"], additional_deps = [ - ":rnn_ptb", + ":rnn_ptb_lib", "//tensorflow/contrib/eager/python:tfe", "//tensorflow:tensorflow_py", ], + tags = ["no_oss"], # b/123045964 ) cuda_py_test( name = "rnn_ptb_graph_test", srcs = ["rnn_ptb_graph_test.py"], additional_deps = [ - ":rnn_ptb", + ":rnn_ptb_lib", "//third_party/py/numpy", "//tensorflow:tensorflow_py", ], + tags = [ + "oss_serial", + ], ) diff --git a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py index 15776c694e92825895437a4c1547699f6d9269fb..9b5a2c947b153308c83f1a922d06c034ec5f9ddf 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py +++ b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py @@ -128,7 +128,7 @@ class PTBModel(tf.keras.Model): self.linear = layers.Dense( vocab_size, kernel_initializer=tf.random_uniform_initializer(-0.1, 0.1)) - self._output_shape = [-1, embedding_dim] + self._output_shape = [-1, hidden_dim] def call(self, input_seq, training): """Run the forward pass of PTBModel. 
diff --git a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb_graph_test.py b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb_graph_test.py index 63b5c4c54d13e9c2448ec1f572ca1389f2443bef..770484abed96e540cf75cc5368a1410c31a8d2d0 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb_graph_test.py +++ b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb_graph_test.py @@ -82,7 +82,7 @@ class PTBBenchmark(tf.test.Benchmark): tf.ones( [PTBBenchmark.SEQ_LEN, PTBBenchmark.BATCH_SIZE], dtype=tf.int64)).repeat(num_iters + num_warmup) - inputs = dataset.make_one_shot_iterator().get_next() + inputs = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next() with tf.device(tf.test.gpu_device_name()): outputs = model(inputs, training=True) @@ -124,7 +124,8 @@ class PTBBenchmark(tf.test.Benchmark): dtype=tf.int64)).repeat(num_iters + num_warmup) # inputs and labels have the same shape dataset = tf.data.Dataset.zip((dataset, dataset)) - (inputs, labels) = dataset.make_one_shot_iterator().get_next() + (inputs, labels) = tf.compat.v1.data.make_one_shot_iterator( + dataset).get_next() with tf.device(tf.test.gpu_device_name()): optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0) diff --git a/tensorflow/contrib/eager/python/examples/spinn/BUILD b/tensorflow/contrib/eager/python/examples/spinn/BUILD index 5966f1d4873e8e77b3ad5914da7bfc7e69d4e341..9b0fbaa6793e28d327745767e6ccd3085211ff7d 100644 --- a/tensorflow/contrib/eager/python/examples/spinn/BUILD +++ b/tensorflow/contrib/eager/python/examples/spinn/BUILD @@ -42,5 +42,6 @@ cuda_py_test( "no-internal-py3", # flaky "no_cuda_on_cpu_tap", "no_pip", # because spinn.py is under third_party/. 
+ "oss_serial", ], ) diff --git a/tensorflow/contrib/eager/python/metrics_impl.py b/tensorflow/contrib/eager/python/metrics_impl.py index 566246de4957c1dc5919c10e22146706f9e50be8..c8d9266672a8b87d32338ea7c4f74fb40d41c767 100644 --- a/tensorflow/contrib/eager/python/metrics_impl.py +++ b/tensorflow/contrib/eager/python/metrics_impl.py @@ -37,7 +37,7 @@ from tensorflow.python.training.checkpointable import base as checkpointable _to_replace = re.compile("[^A-Za-z0-9.]") -class Metric(checkpointable.CheckpointableBase): +class Metric(checkpointable.Checkpointable): """A metric holds state for aggregating statistics over an evaluation run. Example use with eager execution: diff --git a/tensorflow/contrib/eager/python/remote_test.py b/tensorflow/contrib/eager/python/remote_test.py index 3926de15e71c9917f88fc3f58740b8c75354ab26..f540d9b37b69c7be3b0662b07bd6e9cb8220fadc 100644 --- a/tensorflow/contrib/eager/python/remote_test.py +++ b/tensorflow/contrib/eager/python/remote_test.py @@ -24,12 +24,12 @@ import os import numpy as np from tensorflow.contrib.eager.python import parameter_server -from tensorflow.contrib.eager.python import remote from tensorflow.core.protobuf import cluster_pb2 from tensorflow.core.protobuf import tensorflow_server_pb2 from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.eager import function +from tensorflow.python.eager import remote from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops diff --git a/tensorflow/contrib/eager/python/saver.py b/tensorflow/contrib/eager/python/saver.py index f9c716360c5755ee1902b576545d776725f9966f..1d0d6c6c14ce4a8e454206e0be9fea4724f09192 100644 --- a/tensorflow/contrib/eager/python/saver.py +++ b/tensorflow/contrib/eager/python/saver.py @@ -115,6 +115,11 @@ def restore_variables_on_create(save_path, map_func=None): class Saver(object): """A tf.train.Saver adapter for use when 
eager execution is enabled. + + `Saver`'s name-based checkpointing strategy is fragile. Please switch to + `tf.train.Checkpoint` or `tf.keras.Model.save_weights`, which perform a more + robust object-based saving. These APIs will load checkpoints written by + `Saver`. """ def __init__(self, var_list): diff --git a/tensorflow/contrib/eager/python/tfe.py b/tensorflow/contrib/eager/python/tfe.py index 33c988fd9065e7fbe7b9aeb85cad82eb3c119f76..b82e1bb71bce9a28d7bbbf961cc6d5e25dd18acf 100644 --- a/tensorflow/contrib/eager/python/tfe.py +++ b/tensorflow/contrib/eager/python/tfe.py @@ -41,6 +41,8 @@ To use, at program startup, call `tf.enable_eager_execution()`. @@add_execution_callback @@clear_execution_callbacks +@@errstate +@@ExecutionCallback @@inf_callback @@inf_nan_callback @@nan_callback @@ -97,7 +99,6 @@ from tensorflow.contrib.eager.python.network import Network from tensorflow.contrib.eager.python.network import Sequential from tensorflow.contrib.eager.python.network import save_network_checkpoint from tensorflow.contrib.eager.python.network import restore_network_checkpoint -from tensorflow.contrib.eager.python.remote import connect_to_remote_host from tensorflow.contrib.eager.python.saver import get_optimizer_variables from tensorflow.contrib.eager.python.saver import restore_variables_on_create from tensorflow.contrib.eager.python.saver import Saver @@ -119,10 +120,13 @@ from tensorflow.python.eager.context import set_server_def from tensorflow.python.eager.def_function import function from tensorflow.python.eager.execution_callbacks import add_execution_callback from tensorflow.python.eager.execution_callbacks import clear_execution_callbacks +from tensorflow.python.eager.execution_callbacks import errstate +from tensorflow.python.eager.execution_callbacks import ExecutionCallback from tensorflow.python.eager.execution_callbacks import inf_callback from tensorflow.python.eager.execution_callbacks import inf_nan_callback from 
tensorflow.python.eager.execution_callbacks import nan_callback from tensorflow.python.eager.execution_callbacks import seterr +from tensorflow.python.eager.remote import connect_to_remote_host from tensorflow.python.framework.tensor_spec import TensorSpec from tensorflow.python.framework.ops import enable_eager_execution from tensorflow.python.framework.ops import enable_eager_execution_internal as enable_remote_eager_execution @@ -134,7 +138,7 @@ from tensorflow.python.ops.resource_variable_ops import ResourceVariable as Vari from tensorflow.python.ops.variable_scope import EagerVariableStore from tensorflow.python.ops import script_ops from tensorflow.python.ops import template -from tensorflow.python.training.checkpointable.tracking import Checkpointable +from tensorflow.python.training.checkpointable.tracking import AutoCheckpointable as Checkpointable from tensorflow.python.training.checkpointable.util import CheckpointableSaver from tensorflow.python.training.checkpointable.util import Checkpoint from tensorflow.python.util.all_util import remove_undocumented diff --git a/tensorflow/contrib/eager/python/tfe_test.py b/tensorflow/contrib/eager/python/tfe_test.py index 8c35dddb5a515aa09cc70c173a9f0605e8567e82..6881fabdc09e3275c29f3013283999c96e283770 100644 --- a/tensorflow/contrib/eager/python/tfe_test.py +++ b/tensorflow/contrib/eager/python/tfe_test.py @@ -20,6 +20,7 @@ from __future__ import print_function import tempfile from tensorflow.contrib.eager.python import tfe +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import errors from tensorflow.python.framework import ops @@ -40,6 +41,9 @@ class TFETest(test_util.TensorFlowTestCase): self.assertAllEqual([[4.]], y.numpy()) def testInstantError(self): + if context.num_gpus(): + # TODO(nareshmodi): make this test better + self.skipTest("Gather doesn't do index checking on GPUs") with 
self.assertRaisesRegexp(errors.InvalidArgumentError, r'indices = 7 is not in \[0, 3\)'): array_ops.gather([0, 1, 2], 7) diff --git a/tensorflow/contrib/factorization/BUILD b/tensorflow/contrib/factorization/BUILD index e344d7a23b55134612aab430b50cf065bd1095e4..48a6ef4dca0ca7682f7b99b66177679f29ad9ec9 100644 --- a/tensorflow/contrib/factorization/BUILD +++ b/tensorflow/contrib/factorization/BUILD @@ -28,7 +28,6 @@ tf_custom_op_py_library( "python/ops/wals.py", ], dso = [ - ":python/ops/_clustering_ops.so", ":python/ops/_factorization_ops.so", ], kernels = [ @@ -38,12 +37,12 @@ tf_custom_op_py_library( srcs_version = "PY2AND3", deps = [ ":factorization_ops_test_utils_py", - ":gen_clustering_ops", ":gen_factorization_ops", "//tensorflow/contrib/framework:framework_py", "//tensorflow/contrib/util:util_py", "//tensorflow/python:array_ops", "//tensorflow/python:check_ops", + "//tensorflow/python:clustering_ops_gen", "//tensorflow/python:control_flow_ops", "//tensorflow/python:data_flow_ops", "//tensorflow/python:embedding_ops", @@ -77,17 +76,6 @@ py_library( ], ) -# Ops -tf_custom_op_library( - name = "python/ops/_clustering_ops.so", - srcs = [ - "ops/clustering_ops.cc", - ], - deps = [ - "//tensorflow/contrib/factorization/kernels:clustering_ops", - ], -) - tf_custom_op_library( name = "python/ops/_factorization_ops.so", srcs = [ @@ -100,26 +88,16 @@ tf_custom_op_library( ) tf_gen_op_libs([ - "clustering_ops", "factorization_ops", ]) cc_library( name = "all_ops", deps = [ - ":clustering_ops_op_lib", ":factorization_ops_op_lib", ], ) -tf_gen_op_wrapper_py( - name = "gen_clustering_ops", - out = "python/ops/gen_clustering_ops.py", - deps = [ - ":clustering_ops_op_lib", - ], -) - tf_gen_op_wrapper_py( name = "gen_factorization_ops", out = "python/ops/gen_factorization_ops.py", @@ -131,7 +109,7 @@ tf_gen_op_wrapper_py( # Ops tests tf_py_test( name = "gmm_test", - size = "large", + size = "medium", srcs = [ "python/ops/gmm_test.py", ], @@ -152,6 +130,7 @@ tf_py_test( 
"//tensorflow/python:random_seed", "//tensorflow/python:training", ], + shard_count = 4, tags = [ "no_pip", # b/38283730 "notsan", # Flaky: b/30756419 @@ -249,7 +228,7 @@ py_test( tf_py_test( name = "wals_test", - size = "large", + size = "medium", srcs = ["python/ops/wals_test.py"], additional_deps = [ ":factorization_py", @@ -272,8 +251,8 @@ tf_py_test( "//tensorflow/python:training", "//tensorflow/python:variables", ], + shard_count = 4, tags = [ - "manual", "noasan", # times out b/63678675 "nomsan", ], diff --git a/tensorflow/contrib/factorization/kernels/BUILD b/tensorflow/contrib/factorization/kernels/BUILD index ea8b9a17a27093cb57564861815edd6ecb18a014..23d7e088d067effa446e4bcdc9609db612066568 100644 --- a/tensorflow/contrib/factorization/kernels/BUILD +++ b/tensorflow/contrib/factorization/kernels/BUILD @@ -11,7 +11,6 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_test") cc_library( name = "all_kernels", deps = [ - ":clustering_ops", ":masked_matmul_ops", ":wals_solver_ops", "@protobuf_archive//:protobuf_headers", @@ -29,17 +28,6 @@ cc_library( alwayslink = 1, ) -cc_library( - name = "clustering_ops", - srcs = ["clustering_ops.cc"], - deps = [ - "//tensorflow/core:framework_headers_lib", - "//third_party/eigen3", - "@protobuf_archive//:protobuf_headers", - ], - alwayslink = 1, -) - cc_library( name = "masked_matmul_ops", srcs = ["masked_matmul_ops.cc"], @@ -51,19 +39,3 @@ cc_library( ], alwayslink = 1, ) - -tf_cc_test( - name = "clustering_ops_test", - srcs = ["clustering_ops_test.cc"], - deps = [ - ":clustering_ops", - "//tensorflow/contrib/factorization:clustering_ops_op_lib", - "//tensorflow/core:core_cpu", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core:testlib", - ], -) diff --git a/tensorflow/contrib/factorization/kernels/masked_matmul_ops.cc b/tensorflow/contrib/factorization/kernels/masked_matmul_ops.cc index 
a8c5d0763c28ba2b54f217405f0da65533f26b91..68078ba8bbb07b4344c19d554012d214229f9c4f 100644 --- a/tensorflow/contrib/factorization/kernels/masked_matmul_ops.cc +++ b/tensorflow/contrib/factorization/kernels/masked_matmul_ops.cc @@ -19,12 +19,12 @@ #include #include +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/threadpool.h" diff --git a/tensorflow/contrib/factorization/ops/clustering_ops.cc b/tensorflow/contrib/factorization/ops/clustering_ops.cc deleted file mode 100644 index 2686702c1d5768f661dac610c96089eb02e360d7..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/factorization/ops/clustering_ops.cc +++ /dev/null @@ -1,91 +0,0 @@ -// Copyright 2016 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not -// use this file except in compliance with the License. You may obtain a copy -// of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -// License for the specific language governing permissions and limitations under -// the License. 
-// ============================================================================== - -#include "tensorflow/core/framework/common_shape_fns.h" -#include "tensorflow/core/framework/op.h" - -namespace tensorflow { - -REGISTER_OP("KmeansPlusPlusInitialization") - .Input("points: float32") - .Input("num_to_sample: int64") - .Input("seed: int64") - .Input("num_retries_per_sample: int64") - .Output("samples: float32") - .SetShapeFn(shape_inference::UnknownShape) - .Doc(R"( -Selects num_to_sample rows of input using the KMeans++ criterion. - -Rows of points are assumed to be input points. One row is selected at random. -Subsequent rows are sampled with probability proportional to the squared L2 -distance from the nearest row selected thus far till num_to_sample rows have -been sampled. - -points: Matrix of shape (n, d). Rows are assumed to be input points. -num_to_sample: Scalar. The number of rows to sample. This value must not be - larger than n. -seed: Scalar. Seed for initializing the random number generator. -num_retries_per_sample: Scalar. For each row that is sampled, this parameter - specifies the number of additional points to draw from the current - distribution before selecting the best. If a negative value is specified, a - heuristic is used to sample O(log(num_to_sample)) additional points. -samples: Matrix of shape (num_to_sample, d). The sampled rows. -)"); - -REGISTER_OP("KMC2ChainInitialization") - .Input("distances: float32") - .Input("seed: int64") - .Output("index: int64") - .SetShapeFn(shape_inference::ScalarShape) - .Doc(R"( -Returns the index of a data point that should be added to the seed set. - -Entries in distances are assumed to be squared distances of candidate points to -the already sampled centers in the seed set. The op constructs one Markov chain -of the k-MC^2 algorithm and returns the index of one candidate point to be added -as an additional cluster center. 
- -distances: Vector with squared distances to the closest previously sampled - cluster center for each candidate point. -seed: Scalar. Seed for initializing the random number generator. -index: Scalar with the index of the sampled point. -)"); - -REGISTER_OP("NearestNeighbors") - .Input("points: float32") - .Input("centers: float32") - .Input("k: int64") - .Output("nearest_center_indices: int64") - .Output("nearest_center_distances: float32") - .SetShapeFn(shape_inference::UnknownShape) - .Doc(R"( -Selects the k nearest centers for each point. - -Rows of points are assumed to be input points. Rows of centers are assumed to be -the list of candidate centers. For each point, the k centers that have least L2 -distance to it are computed. - -points: Matrix of shape (n, d). Rows are assumed to be input points. -centers: Matrix of shape (m, d). Rows are assumed to be centers. -k: Scalar. Number of nearest centers to return for each point. If k is larger - than m, then only m centers are returned. -nearest_center_indices: Matrix of shape (n, min(m, k)). Each row contains the - indices of the centers closest to the corresponding point, ordered by - increasing distance. -nearest_center_distances: Matrix of shape (n, min(m, k)). Each row contains the - squared L2 distance to the corresponding center in nearest_center_indices. 
-)"); - -} // namespace tensorflow diff --git a/tensorflow/contrib/factorization/python/ops/clustering_ops.py b/tensorflow/contrib/factorization/python/ops/clustering_ops.py index 84e80791f4991ad2b67d0a00ee1e00cf0d0daadc..d48b89cbacce34781819010addbcbd0ba66f9873 100644 --- a/tensorflow/contrib/factorization/python/ops/clustering_ops.py +++ b/tensorflow/contrib/factorization/python/ops/clustering_ops.py @@ -18,28 +18,23 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.factorization.python.ops import gen_clustering_ops -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.factorization.python.ops.gen_clustering_ops import * -# pylint: enable=wildcard-import -from tensorflow.contrib.util import loader from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import gen_clustering_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_impl from tensorflow.python.ops import random_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops.embedding_ops import embedding_lookup -from tensorflow.python.platform import resource_loader - -_clustering_ops = loader.load_op_library( - resource_loader.get_path_to_datafile('_clustering_ops.so')) +# go/tf-wildcard-import +# pylint: disable=wildcard-import +from tensorflow.python.ops.gen_clustering_ops import * +# pylint: enable=wildcard-import # Euclidean distance between vectors U and V is defined as \\(||U - V||_F\\) # which is the square root of the sum of the absolute squares of the elements diff --git a/tensorflow/contrib/factorization/python/ops/gmm_ops.py 
b/tensorflow/contrib/factorization/python/ops/gmm_ops.py index d365ad111760247fc18b730657390f07ba6b865e..9f0664dfe5ba7a098b6976388d1cf737dafb4842 100644 --- a/tensorflow/contrib/factorization/python/ops/gmm_ops.py +++ b/tensorflow/contrib/factorization/python/ops/gmm_ops.py @@ -314,8 +314,7 @@ class GmmAlgorithm(object): # reparametrization of variance parameters. det_expanded = math_ops.reduce_sum( math_ops.log(self._covs + 1e-3), 1, keepdims=True) - diff = shard - self._means - x2 = math_ops.square(diff) + x2 = math_ops.squared_difference(shard, self._means) cov_expanded = array_ops.expand_dims(1.0 / (self._covs + 1e-3), 2) # num_classes X num_examples x2_cov = math_ops.matmul(x2, cov_expanded) diff --git a/tensorflow/contrib/feature_column/BUILD b/tensorflow/contrib/feature_column/BUILD index 1cd83bdb5de7c2f6dc91c980750b49aca1a7790b..8fc5f1cfe7800653ef1e43c6d40d1a66e34f2106 100644 --- a/tensorflow/contrib/feature_column/BUILD +++ b/tensorflow/contrib/feature_column/BUILD @@ -6,7 +6,7 @@ package( licenses(["notice"]) # Apache 2.0 -load("//tensorflow:tensorflow.bzl", "py_test") +load("//tensorflow:tensorflow.bzl", "tf_py_test") py_library( name = "feature_column_py", @@ -37,13 +37,13 @@ py_library( ], ) -py_test( +tf_py_test( name = "sequence_feature_column_test", srcs = ["python/feature_column/sequence_feature_column_test.py"], - srcs_version = "PY2AND3", - tags = ["no_pip"], - deps = [ + additional_deps = [ ":sequence_feature_column", + "@absl_py//absl/testing:parameterized", + "//third_party/py/numpy", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", "//tensorflow/python:errors", @@ -53,17 +53,14 @@ py_test( "//tensorflow/python:sparse_tensor", "//tensorflow/python:training", "//tensorflow/python/feature_column:feature_column_py", - "//third_party/py/numpy", - "@absl_py//absl/testing:parameterized", ], + tags = ["no_pip"], ) -py_test( +tf_py_test( name = "sequence_feature_column_integration_test", srcs = 
["python/feature_column/sequence_feature_column_integration_test.py"], - srcs_version = "PY2AND3", - tags = ["no_pip"], - deps = [ + additional_deps = [ ":sequence_feature_column", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_ops", @@ -73,6 +70,7 @@ py_test( "//tensorflow/python/feature_column:feature_column_py", "//tensorflow/python/keras:layers", ], + tags = ["no_pip"], ) py_library( @@ -94,14 +92,13 @@ py_library( ], ) -py_test( +tf_py_test( name = "sequence_feature_column_v2_test", srcs = ["python/feature_column/sequence_feature_column_v2_test.py"], - srcs_version = "PY2AND3", - tags = ["no_pip"], - deps = [ - ":sequence_feature_column", + additional_deps = [ ":sequence_feature_column_v2", + "@absl_py//absl/testing:parameterized", + "//third_party/py/numpy", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", "//tensorflow/python:errors", @@ -110,9 +107,25 @@ py_test( "//tensorflow/python:parsing_ops", "//tensorflow/python:sparse_tensor", "//tensorflow/python:training", - "//tensorflow/python/feature_column", "//tensorflow/python/feature_column:feature_column_py", - "//third_party/py/numpy", - "@absl_py//absl/testing:parameterized", + "//tensorflow/python/feature_column:feature_column_v2_test", + ], + tags = ["no_pip"], +) + +py_test( + name = "sequence_feature_column_v2_integration_test", + srcs = ["python/feature_column/sequence_feature_column_v2_integration_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":sequence_feature_column_v2", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_ops", + "//tensorflow/python:parsing_ops", + "//tensorflow/python:training", + "//tensorflow/python:util", + "//tensorflow/python/feature_column:feature_column_py", + "//tensorflow/python/keras:layers", ], ) diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2.py 
index 0d34ad161855476b6a4cd9a258521dbe122b4140..2f4bda194a41242167e0abfcaeac5044f6026f85 100644 --- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2.py +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2.py @@ -27,6 +27,7 @@ import collections from tensorflow.python.feature_column import feature_column as fc_old from tensorflow.python.feature_column import feature_column_lib as fc +from tensorflow.python.feature_column import feature_column_v2 as fc_v2 from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape @@ -34,107 +35,115 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import sparse_ops -from tensorflow.python.ops import variable_scope # pylint: disable=protected-access -def sequence_input_layer( - features, - feature_columns, - weight_collections=None, - trainable=True): - """"Builds input layer for sequence input. +class SequenceFeatures(fc_v2._BaseFeaturesLayer): + """A layer for sequence input. - All `feature_columns` must be sequence dense columns with the same - `sequence_length`. The output of this method can be fed into sequence - networks, such as RNN. + All `feature_columns` must be sequence dense columns with the same + `sequence_length`. The output of this method can be fed into sequence + networks, such as RNN. - The output of this method is a 3D `Tensor` of shape `[batch_size, T, D]`. - `T` is the maximum sequence length for this batch, which could differ from - batch to batch. + The output of this method is a 3D `Tensor` of shape `[batch_size, T, D]`. + `T` is the maximum sequence length for this batch, which could differ from + batch to batch. - If multiple `feature_columns` are given with `Di` `num_elements` each, their - outputs are concatenated. 
So, the final `Tensor` has shape - `[batch_size, T, D0 + D1 + ... + Dn]`. + If multiple `feature_columns` are given with `Di` `num_elements` each, their + outputs are concatenated. So, the final `Tensor` has shape + `[batch_size, T, D0 + D1 + ... + Dn]`. - Example: + Example: - ```python - rating = sequence_numeric_column('rating') - watches = sequence_categorical_column_with_identity( - 'watches', num_buckets=1000) - watches_embedding = embedding_column(watches, dimension=10) - columns = [rating, watches] + ```python + rating = sequence_numeric_column('rating') + watches = sequence_categorical_column_with_identity( + 'watches', num_buckets=1000) + watches_embedding = embedding_column(watches, dimension=10) + columns = [rating, watches] - features = tf.parse_example(..., features=make_parse_example_spec(columns)) - input_layer, sequence_length = sequence_input_layer(features, columns) + features = tf.parse_example(..., features=make_parse_example_spec(columns)) + sequence_input_layer = SequenceFeatures(columns) + sequence_input, sequence_length = sequence_input_layer(features) + sequence_length_mask = tf.sequence_mask(sequence_length) - rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) - outputs, state = tf.nn.dynamic_rnn( - rnn_cell, inputs=input_layer, sequence_length=sequence_length) - ``` + rnn_cell = tf.keras.layers.SimpleRNNCell(hidden_size) + rnn_layer = tf.keras.layers.RNN(rnn_cell) + outputs, state = rnn_layer(sequence_input, mask=sequence_length_mask) + ``` + """ - Args: - features: A dict mapping keys to tensors. - feature_columns: An iterable of dense sequence columns. Valid columns are - - `embedding_column` that wraps a `sequence_categorical_column_with_*` - - `sequence_numeric_column`. - weight_collections: A list of collection names to which the Variable will be - added. Note that variables will also be added to collections - `tf.GraphKeys.GLOBAL_VARIABLES` and `ops.GraphKeys.MODEL_VARIABLES`. 
- trainable: If `True` also add the variable to the graph collection - `GraphKeys.TRAINABLE_VARIABLES`. + def __init__( + self, + feature_columns, + trainable=True, + name=None, + **kwargs): + """"Constructs a SequenceFeatures layer. - Returns: - An `(input_layer, sequence_length)` tuple where: - - input_layer: A float `Tensor` of shape `[batch_size, T, D]`. - `T` is the maximum sequence length for this batch, which could differ - from batch to batch. `D` is the sum of `num_elements` for all - `feature_columns`. - - sequence_length: An int `Tensor` of shape `[batch_size]`. The sequence - length for each example. + Args: + feature_columns: An iterable of dense sequence columns. Valid columns are + - `embedding_column` that wraps a `sequence_categorical_column_with_*` + - `sequence_numeric_column`. + trainable: Boolean, whether the layer's variables will be updated via + gradient descent during training. + name: Name to give to the SequenceFeatures. + **kwargs: Keyword arguments to construct a layer. + + Raises: + ValueError: If any of the `feature_columns` is not a + `SequenceDenseColumn`. + """ + super(SequenceFeatures, self).__init__( + feature_columns=feature_columns, + trainable=trainable, + name=name, + expected_column_type=fc_v2.SequenceDenseColumn, + **kwargs) - Raises: - ValueError: If any of the `feature_columns` is the wrong type. - """ - feature_columns = fc_old._normalize_feature_columns(feature_columns) - for c in feature_columns: - if not isinstance(c, fc_old._SequenceDenseColumn): - raise ValueError( - 'All feature_columns must be of type _SequenceDenseColumn. ' - 'You can wrap a sequence_categorical_column with an embedding_column ' - 'or indicator_column. 
' - 'Given (type {}): {}'.format(type(c), c)) - - with variable_scope.variable_scope( - None, default_name='sequence_input_layer', values=features.values()): - builder = fc_old._LazyBuilder(features) + def _target_shape(self, input_shape, total_elements): + return (input_shape[0], input_shape[1], total_elements) + + def call(self, features): + """Returns sequence input corresponding to the `feature_columns`. + + Args: + features: A dict mapping keys to tensors. + + Returns: + An `(input_layer, sequence_length)` tuple where: + - input_layer: A float `Tensor` of shape `[batch_size, T, D]`. + `T` is the maximum sequence length for this batch, which could differ + from batch to batch. `D` is the sum of `num_elements` for all + `feature_columns`. + - sequence_length: An int `Tensor` of shape `[batch_size]`. The sequence + length for each example. + + Raises: + ValueError: If features are not a dictionary. + """ + if not isinstance(features, dict): + raise ValueError('We expected a dictionary here. Instead we got: ', + features) + transformation_cache = fc.FeatureTransformationCache(features) output_tensors = [] sequence_lengths = [] - ordered_columns = [] - - for column in sorted(feature_columns, key=lambda x: x.name): - ordered_columns.append(column) - with variable_scope.variable_scope( - None, default_name=column._var_scope_name): - dense_tensor, sequence_length = column._get_sequence_dense_tensor( - builder, - weight_collections=weight_collections, - trainable=trainable) + + for column in self._feature_columns: + with ops.name_scope(column.name): + dense_tensor, sequence_length = column.get_sequence_dense_tensor( + transformation_cache, self._state_manager) # Flattens the final dimension to produce a 3D Tensor. 
- num_elements = column._variable_shape.num_elements() - shape = array_ops.shape(dense_tensor) - target_shape = [shape[0], shape[1], num_elements] - output_tensors.append( - array_ops.reshape(dense_tensor, shape=target_shape)) + output_tensors.append(self._process_dense_tensor(column, dense_tensor)) sequence_lengths.append(sequence_length) - fc_old._verify_static_batch_size_equality(output_tensors, ordered_columns) - fc_old._verify_static_batch_size_equality(sequence_lengths, ordered_columns) + # Check and process sequence lengths. + fc_v2._verify_static_batch_size_equality(sequence_lengths, + self._feature_columns) sequence_length = _assert_all_equal_and_return(sequence_lengths) - return array_ops.concat(output_tensors, -1), sequence_length + return self._verify_and_concat_tensors(output_tensors), sequence_length def concatenate_context_input(context_input, sequence_input): @@ -203,11 +212,13 @@ def sequence_categorical_column_with_identity( columns = [watches_embedding] features = tf.parse_example(..., features=make_parse_example_spec(columns)) - input_layer, sequence_length = sequence_input_layer(features, columns) + sequence_feature_layer = SequenceFeatures(columns) + sequence_input, sequence_length = sequence_feature_layer(features) + sequence_length_mask = tf.sequence_mask(sequence_length) - rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) - outputs, state = tf.nn.dynamic_rnn( - rnn_cell, inputs=input_layer, sequence_length=sequence_length) + rnn_cell = tf.keras.layers.SimpleRNNCell(hidden_size) + rnn_layer = tf.keras.layers.RNN(rnn_cell) + outputs, state = rnn_layer(sequence_input, mask=sequence_length_mask) ``` Args: @@ -219,15 +230,17 @@ def sequence_categorical_column_with_identity( `[0, num_buckets)`, and will replace out-of-range inputs. Returns: - A `_SequenceCategoricalColumn`. + A `SequenceCategoricalColumn`. Raises: ValueError: if `num_buckets` is less than one. ValueError: if `default_value` is not in range `[0, num_buckets)`. 
""" - return fc_old._SequenceCategoricalColumn( - fc_old._categorical_column_with_identity( - key=key, num_buckets=num_buckets, default_value=default_value)) + return fc.SequenceCategoricalColumn( + fc.categorical_column_with_identity( + key=key, + num_buckets=num_buckets, + default_value=default_value)) def sequence_categorical_column_with_hash_bucket( @@ -247,11 +260,13 @@ def sequence_categorical_column_with_hash_bucket( columns = [tokens_embedding] features = tf.parse_example(..., features=make_parse_example_spec(columns)) - input_layer, sequence_length = sequence_input_layer(features, columns) + sequence_feature_layer = SequenceFeatures(columns) + sequence_input, sequence_length = sequence_feature_layer(features) + sequence_length_mask = tf.sequence_mask(sequence_length) - rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) - outputs, state = tf.nn.dynamic_rnn( - rnn_cell, inputs=input_layer, sequence_length=sequence_length) + rnn_cell = tf.keras.layers.SimpleRNNCell(hidden_size) + rnn_layer = tf.keras.layers.RNN(rnn_cell) + outputs, state = rnn_layer(sequence_input, mask=sequence_length_mask) ``` Args: @@ -260,15 +275,17 @@ def sequence_categorical_column_with_hash_bucket( dtype: The type of features. Only string and integer types are supported. Returns: - A `_SequenceCategoricalColumn`. + A `SequenceCategoricalColumn`. Raises: ValueError: `hash_bucket_size` is not greater than 1. ValueError: `dtype` is neither string nor integer. 
""" - return fc_old._SequenceCategoricalColumn( - fc_old._categorical_column_with_hash_bucket( - key=key, hash_bucket_size=hash_bucket_size, dtype=dtype)) + return fc.SequenceCategoricalColumn( + fc.categorical_column_with_hash_bucket( + key=key, + hash_bucket_size=hash_bucket_size, + dtype=dtype)) def sequence_categorical_column_with_vocabulary_file( @@ -290,11 +307,13 @@ def sequence_categorical_column_with_vocabulary_file( columns = [states_embedding] features = tf.parse_example(..., features=make_parse_example_spec(columns)) - input_layer, sequence_length = sequence_input_layer(features, columns) + sequence_feature_layer = SequenceFeatures(columns) + sequence_input, sequence_length = sequence_feature_layer(features) + sequence_length_mask = tf.sequence_mask(sequence_length) - rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) - outputs, state = tf.nn.dynamic_rnn( - rnn_cell, inputs=input_layer, sequence_length=sequence_length) + rnn_cell = tf.keras.layers.SimpleRNNCell(hidden_size) + rnn_layer = tf.keras.layers.RNN(rnn_cell) + outputs, state = rnn_layer(sequence_input, mask=sequence_length_mask) ``` Args: @@ -314,7 +333,7 @@ def sequence_categorical_column_with_vocabulary_file( dtype: The type of features. Only string and integer types are supported. Returns: - A `_SequenceCategoricalColumn`. + A `SequenceCategoricalColumn`. Raises: ValueError: `vocabulary_file` is missing or cannot be opened. @@ -323,8 +342,8 @@ def sequence_categorical_column_with_vocabulary_file( ValueError: `num_oov_buckets` and `default_value` are both specified. ValueError: `dtype` is neither string nor integer. 
""" - return fc_old._SequenceCategoricalColumn( - fc_old._categorical_column_with_vocabulary_file( + return fc.SequenceCategoricalColumn( + fc.categorical_column_with_vocabulary_file( key=key, vocabulary_file=vocabulary_file, vocabulary_size=vocabulary_size, @@ -351,11 +370,13 @@ def sequence_categorical_column_with_vocabulary_list( columns = [colors_embedding] features = tf.parse_example(..., features=make_parse_example_spec(columns)) - input_layer, sequence_length = sequence_input_layer(features, columns) + sequence_feature_layer = SequenceFeatures(columns) + sequence_input, sequence_length = sequence_feature_layer(features) + sequence_length_mask = tf.sequence_mask(sequence_length) - rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) - outputs, state = tf.nn.dynamic_rnn( - rnn_cell, inputs=input_layer, sequence_length=sequence_length) + rnn_cell = tf.keras.layers.SimpleRNNCell(hidden_size) + rnn_layer = tf.keras.layers.RNN(rnn_cell) + outputs, state = rnn_layer(sequence_input, mask=sequence_length_mask) ``` Args: @@ -375,7 +396,7 @@ def sequence_categorical_column_with_vocabulary_list( with `default_value`. Returns: - A `_SequenceCategoricalColumn`. + A `SequenceCategoricalColumn`. Raises: ValueError: if `vocabulary_list` is empty, or contains duplicate keys. @@ -383,8 +404,8 @@ def sequence_categorical_column_with_vocabulary_list( ValueError: `num_oov_buckets` and `default_value` are both specified. ValueError: if `dtype` is not integer or string. 
""" - return fc_old._SequenceCategoricalColumn( - fc_old._categorical_column_with_vocabulary_list( + return fc.SequenceCategoricalColumn( + fc.categorical_column_with_vocabulary_list( key=key, vocabulary_list=vocabulary_list, dtype=dtype, @@ -407,12 +428,13 @@ def sequence_numeric_column( columns = [temperature] features = tf.parse_example(..., features=make_parse_example_spec(columns)) - sequence_feature_layer = SequenceFeatureLayer(columns) - input_layer, sequence_length = sequence_feature_layer(features) + sequence_feature_layer = SequenceFeatures(columns) + sequence_input, sequence_length = sequence_feature_layer(features) + sequence_length_mask = tf.sequence_mask(sequence_length) - rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) - outputs, state = tf.nn.dynamic_rnn( - rnn_cell, inputs=input_layer, sequence_length=sequence_length) + rnn_cell = tf.keras.layers.SimpleRNNCell(hidden_size) + rnn_layer = tf.keras.layers.RNN(rnn_cell) + outputs, state = rnn_layer(sequence_input, mask=sequence_length_mask) ``` Args: @@ -437,7 +459,7 @@ def sequence_numeric_column( ValueError: if any dimension in shape is not a positive integer. ValueError: if `dtype` is not convertible to `tf.float32`. """ - shape = fc_old._check_shape(shape=shape, key=key) + shape = fc_v2._check_shape(shape=shape, key=key) if not (dtype.is_integer or dtype.is_floating): raise ValueError('dtype must be convertible to float. ' 'dtype: {}, key: {}'.format(dtype, key)) @@ -532,8 +554,10 @@ class SequenceNumericColumn( # For the 2D case, the raw values are grouped according to num_elements; # for the 3D case, the grouping happens in the third dimension, and # sequence length is not affected. 
- num_elements = (self.variable_shape.num_elements() - if sp_tensor.shape.ndims == 2 else 1) + if sp_tensor.shape.ndims == 2: + num_elements = self.variable_shape.num_elements() + else: + num_elements = 1 seq_length = fc_old._sequence_length_from_sparse_tensor( sp_tensor, num_elements=num_elements) diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2_integration_test.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2_integration_test.py new file mode 100644 index 0000000000000000000000000000000000000000..1b165a620ae67e855400eb297ec17db80eac7937 --- /dev/null +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2_integration_test.py @@ -0,0 +1,283 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Integration test for sequence feature columns with SequenceExamples.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import string +import tempfile + +from google.protobuf import text_format + +from tensorflow.contrib.feature_column.python.feature_column import sequence_feature_column_v2 as sfc +from tensorflow.core.example import example_pb2 +from tensorflow.core.example import feature_pb2 +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.feature_column import feature_column_v2 as fc +from tensorflow.python.keras.layers import recurrent +from tensorflow.python.ops import parsing_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.util import compat + + +class SequenceFeatureColumnIntegrationTest(test.TestCase): + + def _make_sequence_example(self): + example = example_pb2.SequenceExample() + example.context.feature['int_ctx'].int64_list.value.extend([5]) + example.context.feature['float_ctx'].float_list.value.extend([123.6]) + for val in range(0, 10, 2): + feat = feature_pb2.Feature() + feat.int64_list.value.extend([val] * val) + example.feature_lists.feature_list['int_list'].feature.extend([feat]) + for val in range(1, 11, 2): + feat = feature_pb2.Feature() + feat.bytes_list.value.extend([compat.as_bytes(str(val))] * val) + example.feature_lists.feature_list['str_list'].feature.extend([feat]) + + return example + + def _build_feature_columns(self): + col = fc.categorical_column_with_identity('int_ctx', num_buckets=100) + ctx_cols = [ + fc.embedding_column(col, dimension=10), + fc.numeric_column('float_ctx') + ] + + identity_col = sfc.sequence_categorical_column_with_identity( + 'int_list', num_buckets=10) + bucket_col = sfc.sequence_categorical_column_with_hash_bucket( + 'bytes_list', hash_bucket_size=100) + seq_cols 
= [ + fc.embedding_column(identity_col, dimension=10), + fc.embedding_column(bucket_col, dimension=20) + ] + + return ctx_cols, seq_cols + + def test_sequence_example_into_input_layer(self): + examples = [_make_sequence_example().SerializeToString()] * 100 + ctx_cols, seq_cols = self._build_feature_columns() + + def _parse_example(example): + ctx, seq = parsing_ops.parse_single_sequence_example( + example, + context_features=fc.make_parse_example_spec_v2(ctx_cols), + sequence_features=fc.make_parse_example_spec_v2(seq_cols)) + ctx.update(seq) + return ctx + + ds = dataset_ops.Dataset.from_tensor_slices(examples) + ds = ds.map(_parse_example) + ds = ds.batch(20) + + # Test on a single batch + features = ds.make_one_shot_iterator().get_next() + + # Tile the context features across the sequence features + sequence_input_layer = sfc.SequenceFeatures(seq_cols) + seq_layer, _ = sequence_input_layer(features) + input_layer = fc.DenseFeatures(ctx_cols) + ctx_layer = input_layer(features) + input_layer = sfc.concatenate_context_input(ctx_layer, seq_layer) + + rnn_layer = recurrent.RNN(recurrent.SimpleRNNCell(10)) + output = rnn_layer(input_layer) + + with self.cached_session() as sess: + sess.run(variables.global_variables_initializer()) + features_r = sess.run(features) + self.assertAllEqual(features_r['int_list'].dense_shape, [20, 3, 6]) + + output_r = sess.run(output) + self.assertAllEqual(output_r.shape, [20, 10]) + + +class SequenceExampleParsingTest(test.TestCase): + + def test_seq_ex_in_sequence_categorical_column_with_identity(self): + self._test_parsed_sequence_example( + 'int_list', sfc.sequence_categorical_column_with_identity, + 10, [3, 6], [2, 4, 6]) + + def test_seq_ex_in_sequence_categorical_column_with_hash_bucket(self): + self._test_parsed_sequence_example( + 'bytes_list', sfc.sequence_categorical_column_with_hash_bucket, + 10, [3, 4], [compat.as_bytes(x) for x in 'acg']) + + def test_seq_ex_in_sequence_categorical_column_with_vocabulary_list(self): + 
self._test_parsed_sequence_example( + 'bytes_list', sfc.sequence_categorical_column_with_vocabulary_list, + list(string.ascii_lowercase), [3, 4], + [compat.as_bytes(x) for x in 'acg']) + + def test_seq_ex_in_sequence_categorical_column_with_vocabulary_file(self): + _, fname = tempfile.mkstemp() + with open(fname, 'w') as f: + f.write(string.ascii_lowercase) + self._test_parsed_sequence_example( + 'bytes_list', sfc.sequence_categorical_column_with_vocabulary_file, + fname, [3, 4], [compat.as_bytes(x) for x in 'acg']) + + def _test_parsed_sequence_example( + self, col_name, col_fn, col_arg, shape, values): + """Helper function to check that each FeatureColumn parses correctly. + + Args: + col_name: string, name to give to the feature column. Should match + the name that the column will parse out of the features dict. + col_fn: function used to create the feature column. For example, + sequence_numeric_column. + col_arg: second arg that the target feature column is expecting. + shape: the expected dense_shape of the feature after parsing into + a SparseTensor. + values: the expected values at index [0, 2, 6] of the feature + after parsing into a SparseTensor. 
+ """ + example = _make_sequence_example() + columns = [ + fc.categorical_column_with_identity('int_ctx', num_buckets=100), + fc.numeric_column('float_ctx'), + col_fn(col_name, col_arg) + ] + context, seq_features = parsing_ops.parse_single_sequence_example( + example.SerializeToString(), + context_features=fc.make_parse_example_spec_v2(columns[:2]), + sequence_features=fc.make_parse_example_spec_v2(columns[2:])) + + with self.cached_session() as sess: + ctx_result, seq_result = sess.run([context, seq_features]) + self.assertEqual(list(seq_result[col_name].dense_shape), shape) + self.assertEqual( + list(seq_result[col_name].values[[0, 2, 6]]), values) + self.assertEqual(list(ctx_result['int_ctx'].dense_shape), [1]) + self.assertEqual(ctx_result['int_ctx'].values[0], 5) + self.assertEqual(list(ctx_result['float_ctx'].shape), [1]) + self.assertAlmostEqual(ctx_result['float_ctx'][0], 123.6, places=1) + + +_SEQ_EX_PROTO = """ +context { + feature { + key: "float_ctx" + value { + float_list { + value: 123.6 + } + } + } + feature { + key: "int_ctx" + value { + int64_list { + value: 5 + } + } + } +} +feature_lists { + feature_list { + key: "bytes_list" + value { + feature { + bytes_list { + value: "a" + } + } + feature { + bytes_list { + value: "b" + value: "c" + } + } + feature { + bytes_list { + value: "d" + value: "e" + value: "f" + value: "g" + } + } + } + } + feature_list { + key: "float_list" + value { + feature { + float_list { + value: 1.0 + } + } + feature { + float_list { + value: 3.0 + value: 3.0 + value: 3.0 + } + } + feature { + float_list { + value: 5.0 + value: 5.0 + value: 5.0 + value: 5.0 + value: 5.0 + } + } + } + } + feature_list { + key: "int_list" + value { + feature { + int64_list { + value: 2 + value: 2 + } + } + feature { + int64_list { + value: 4 + value: 4 + value: 4 + value: 4 + } + } + feature { + int64_list { + value: 6 + value: 6 + value: 6 + value: 6 + value: 6 + value: 6 + } + } + } + } +} +""" + + +def _make_sequence_example(): + example = 
example_pb2.SequenceExample() + return text_format.Parse(_SEQ_EX_PROTO, example) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2_test.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2_test.py index ca4398a142065de0be7bee57cd7e54670bbae12e..a1feaddcc00d5fac86dca3138dfa1c6314bb6a8b 100644 --- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2_test.py +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2_test.py @@ -22,23 +22,23 @@ import os from absl.testing import parameterized import numpy as np -from tensorflow.contrib.feature_column.python.feature_column import sequence_feature_column as sfc_old from tensorflow.contrib.feature_column.python.feature_column import sequence_feature_column_v2 as sfc -from tensorflow.python.feature_column import feature_column as fc_old from tensorflow.python.feature_column import feature_column_lib as fc -from tensorflow.python.feature_column.feature_column import _LazyBuilder +from tensorflow.python.feature_column.feature_column_v2_test import _TestStateManager from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops +from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import sparse_ops +from tensorflow.python.ops import variables as variables_lib from tensorflow.python.platform import test from tensorflow.python.training import monitored_session -class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): +class SequenceFeaturesTest(test.TestCase, parameterized.TestCase): @parameterized.named_parameters( {'testcase_name': '2D', @@ -111,29 +111,27 @@ class 
SequenceInputLayerTest(test.TestCase, parameterized.TestCase): categorical_column_a = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - embedding_column_a = fc_old._embedding_column( + embedding_column_a = fc.embedding_column( categorical_column_a, dimension=embedding_dimension_a, initializer=_get_initializer(embedding_dimension_a, embedding_values_a)) categorical_column_b = sfc.sequence_categorical_column_with_identity( key='bbb', num_buckets=vocabulary_size) - embedding_column_b = fc_old._embedding_column( + embedding_column_b = fc.embedding_column( categorical_column_b, dimension=embedding_dimension_b, initializer=_get_initializer(embedding_dimension_b, embedding_values_b)) - input_layer, sequence_length = sfc.sequence_input_layer( - features={ - 'aaa': sparse_input_a, - 'bbb': sparse_input_b, - }, - # Test that columns are reordered alphabetically. - feature_columns=[embedding_column_b, embedding_column_a]) + # Test that columns are reordered alphabetically. 
+ sequence_input_layer = sfc.SequenceFeatures( + [embedding_column_b, embedding_column_a]) + input_layer, sequence_length = sequence_input_layer({ + 'aaa': sparse_input_a, 'bbb': sparse_input_b,}) global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) - self.assertItemsEqual( - ('sequence_input_layer/aaa_embedding/embedding_weights:0', - 'sequence_input_layer/bbb_embedding/embedding_weights:0'), + self.assertCountEqual( + ('sequence_features/aaa_embedding/embedding_weights:0', + 'sequence_features/bbb_embedding/embedding_weights:0'), tuple([v.name for v in global_vars])) with monitored_session.MonitoredSession() as sess: self.assertAllEqual(embedding_values_a, global_vars[0].eval(session=sess)) @@ -152,18 +150,17 @@ class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): values=(2, 0, 1), dense_shape=(2, 2)) - categorical_column_a = fc_old._categorical_column_with_identity( + categorical_column_a = fc.categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - embedding_column_a = fc_old._embedding_column( + embedding_column_a = fc.embedding_column( categorical_column_a, dimension=2) with self.assertRaisesRegexp( ValueError, r'In embedding_column: aaa_embedding\. categorical_column must be of ' - r'type _SequenceCategoricalColumn to use sequence_input_layer\.'): - _, _ = sfc.sequence_input_layer( - features={'aaa': sparse_input}, - feature_columns=[embedding_column_a]) + r'type SequenceCategoricalColumn to use SequenceFeatures\.'): + sequence_input_layer = sfc.SequenceFeatures([embedding_column_a]) + _, _ = sequence_input_layer({'aaa': sparse_input}) def test_shared_embedding_column(self): vocabulary_size = 3 @@ -210,21 +207,18 @@ class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): categorical_column_b = sfc.sequence_categorical_column_with_identity( key='bbb', num_buckets=vocabulary_size) # Test that columns are reordered alphabetically. 
- shared_embedding_columns = fc.shared_embedding_columns( + shared_embedding_columns = fc.shared_embedding_columns_v2( [categorical_column_b, categorical_column_a], dimension=embedding_dimension, initializer=_get_initializer(embedding_dimension, embedding_values)) - input_layer, sequence_length = sfc.sequence_input_layer( - features={ - 'aaa': sparse_input_a, - 'bbb': sparse_input_b, - }, - feature_columns=shared_embedding_columns) + sequence_input_layer = sfc.SequenceFeatures(shared_embedding_columns) + input_layer, sequence_length = sequence_input_layer({ + 'aaa': sparse_input_a, 'bbb': sparse_input_b}) global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) - self.assertItemsEqual( - ('sequence_input_layer/aaa_bbb_shared_embedding/embedding_weights:0',), + self.assertCountEqual( + ('aaa_bbb_shared_embedding:0',), tuple([v.name for v in global_vars])) with monitored_session.MonitoredSession() as sess: self.assertAllEqual(embedding_values, global_vars[0].eval(session=sess)) @@ -248,23 +242,20 @@ class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): values=(2, 0, 1), dense_shape=(2, 2)) - categorical_column_a = fc_old._categorical_column_with_identity( + categorical_column_a = fc.categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - categorical_column_b = fc_old._categorical_column_with_identity( + categorical_column_b = fc.categorical_column_with_identity( key='bbb', num_buckets=vocabulary_size) - shared_embedding_columns = fc.shared_embedding_columns( + shared_embedding_columns = fc.shared_embedding_columns_v2( [categorical_column_a, categorical_column_b], dimension=2) with self.assertRaisesRegexp( ValueError, r'In embedding_column: aaa_shared_embedding\. 
categorical_column must ' - r'be of type _SequenceCategoricalColumn to use sequence_input_layer\.'): - _, _ = sfc.sequence_input_layer( - features={ - 'aaa': sparse_input_a, - 'bbb': sparse_input_b - }, - feature_columns=shared_embedding_columns) + r'be of type SequenceCategoricalColumn to use SequenceFeatures\.'): + sequence_input_layer = sfc.SequenceFeatures(shared_embedding_columns) + _, _ = sequence_input_layer({'aaa': sparse_input_a, + 'bbb': sparse_input_b}) @parameterized.named_parameters( {'testcase_name': '2D', @@ -319,17 +310,15 @@ class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): categorical_column_a = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size_a) - indicator_column_a = fc_old._indicator_column(categorical_column_a) + indicator_column_a = fc.indicator_column(categorical_column_a) categorical_column_b = sfc.sequence_categorical_column_with_identity( key='bbb', num_buckets=vocabulary_size_b) - indicator_column_b = fc_old._indicator_column(categorical_column_b) - input_layer, sequence_length = sfc.sequence_input_layer( - features={ - 'aaa': sparse_input_a, - 'bbb': sparse_input_b, - }, - # Test that columns are reordered alphabetically. - feature_columns=[indicator_column_b, indicator_column_a]) + indicator_column_b = fc.indicator_column(categorical_column_b) + # Test that columns are reordered alphabetically. 
+ sequence_input_layer = sfc.SequenceFeatures( + [indicator_column_b, indicator_column_a]) + input_layer, sequence_length = sequence_input_layer({ + 'aaa': sparse_input_a, 'bbb': sparse_input_b}) with monitored_session.MonitoredSession() as sess: self.assertAllEqual(expected_input_layer, input_layer.eval(session=sess)) @@ -346,17 +335,16 @@ class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): values=(2, 0, 1), dense_shape=(2, 2)) - categorical_column_a = fc_old._categorical_column_with_identity( + categorical_column_a = fc.categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - indicator_column_a = fc_old._indicator_column(categorical_column_a) + indicator_column_a = fc.indicator_column(categorical_column_a) with self.assertRaisesRegexp( ValueError, r'In indicator_column: aaa_indicator\. categorical_column must be of ' - r'type _SequenceCategoricalColumn to use sequence_input_layer\.'): - _, _ = sfc.sequence_input_layer( - features={'aaa': sparse_input}, - feature_columns=[indicator_column_a]) + r'type SequenceCategoricalColumn to use SequenceFeatures\.'): + sequence_input_layer = sfc.SequenceFeatures([indicator_column_a]) + _, _ = sequence_input_layer({'aaa': sparse_input}) @parameterized.named_parameters( {'testcase_name': '2D', @@ -375,7 +363,7 @@ class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): # feature 0, ids [[20, 3], [5]] # feature 1, ids [[3], [8]] 'indices': ((0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0), (1, 1, 0)), - 'values': (20, 3, 5., 3., 8.), + 'values': (20., 3., 5., 3., 8.), 'dense_shape': (2, 2, 2)}, 'expected_input_layer': [ [[20.], [3.], [5.], [0.]], @@ -386,11 +374,10 @@ class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): self, sparse_input_args, expected_input_layer, expected_sequence_length): sparse_input = sparse_tensor.SparseTensorValue(**sparse_input_args) - numeric_column = sfc_old.sequence_numeric_column('aaa') + numeric_column = sfc.sequence_numeric_column('aaa') - 
input_layer, sequence_length = sfc.sequence_input_layer( - features={'aaa': sparse_input}, - feature_columns=[numeric_column]) + sequence_input_layer = sfc.SequenceFeatures([numeric_column]) + input_layer, sequence_length = sequence_input_layer({'aaa': sparse_input}) with monitored_session.MonitoredSession() as sess: self.assertAllEqual(expected_input_layer, input_layer.eval(session=sess)) @@ -428,14 +415,13 @@ class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): ) def test_numeric_column_multi_dim( self, sparse_input_args, expected_input_layer, expected_sequence_length): - """Tests sequence_input_layer for multi-dimensional numeric_column.""" + """Tests SequenceFeatures for multi-dimensional numeric_column.""" sparse_input = sparse_tensor.SparseTensorValue(**sparse_input_args) - numeric_column = sfc_old.sequence_numeric_column('aaa', shape=(2, 2)) + numeric_column = sfc.sequence_numeric_column('aaa', shape=(2, 2)) - input_layer, sequence_length = sfc.sequence_input_layer( - features={'aaa': sparse_input}, - feature_columns=[numeric_column]) + sequence_input_layer = sfc.SequenceFeatures([numeric_column]) + input_layer, sequence_length = sequence_input_layer({'aaa': sparse_input}) with monitored_session.MonitoredSession() as sess: self.assertAllEqual(expected_input_layer, input_layer.eval(session=sess)) @@ -454,22 +440,20 @@ class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): indices=((0, 0), (1, 0)), values=(1., 10.), dense_shape=(2, 2)) - numeric_column_a = sfc_old.sequence_numeric_column('aaa') - numeric_column_b = sfc_old.sequence_numeric_column('bbb') + numeric_column_a = sfc.sequence_numeric_column('aaa') + numeric_column_b = sfc.sequence_numeric_column('bbb') - _, sequence_length = sfc.sequence_input_layer( - features={ - 'aaa': sparse_input_a, - 'bbb': sparse_input_b, - }, - feature_columns=[numeric_column_a, numeric_column_b]) + sequence_input_layer = sfc.SequenceFeatures( + [numeric_column_a, numeric_column_b]) + _, 
sequence_length = sequence_input_layer({ + 'aaa': sparse_input_a, 'bbb': sparse_input_b}) with monitored_session.MonitoredSession() as sess: with self.assertRaisesRegexp( errors.InvalidArgumentError, r'\[Condition x == y did not hold element-wise:\] ' - r'\[x \(sequence_input_layer/aaa/sequence_length:0\) = \] \[2 1\] ' - r'\[y \(sequence_input_layer/bbb/sequence_length:0\) = \] \[1 1\]'): + r'\[x \(sequence_features/aaa/sequence_length:0\) = \] \[2 1\] ' + r'\[y \(sequence_features/bbb/sequence_length:0\) = \] \[1 1\]'): sess.run(sequence_length) @parameterized.named_parameters( @@ -497,11 +481,10 @@ class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): self, sparse_input_args, expected_shape): """Tests that we return a known static shape when we have one.""" sparse_input = sparse_tensor.SparseTensorValue(**sparse_input_args) - numeric_column = sfc_old.sequence_numeric_column('aaa', shape=(2, 2)) + numeric_column = sfc.sequence_numeric_column('aaa', shape=(2, 2)) - input_layer, _ = sfc.sequence_input_layer( - features={'aaa': sparse_input}, - feature_columns=[numeric_column]) + sequence_input_layer = sfc.SequenceFeatures([numeric_column]) + input_layer, _ = sequence_input_layer({'aaa': sparse_input}) shape = input_layer.get_shape() self.assertEqual(shape, expected_shape) @@ -534,13 +517,49 @@ class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): sparse_input = sparse_tensor.SparseTensorValue(**sparse_input_args) categorical_column = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=3) - indicator_column = fc_old._indicator_column(categorical_column) + indicator_column = fc.indicator_column(categorical_column) - input_layer, _ = sfc.sequence_input_layer( - features={'aaa': sparse_input}, feature_columns=[indicator_column]) + sequence_input_layer = sfc.SequenceFeatures([indicator_column]) + input_layer, _ = sequence_input_layer({'aaa': sparse_input}) shape = input_layer.get_shape() self.assertEqual(shape, 
expected_shape) + def test_compute_output_shape(self): + price1 = sfc.sequence_numeric_column('price1', shape=2) + price2 = sfc.sequence_numeric_column('price2') + with ops.Graph().as_default(): + features = { + 'price1': sparse_tensor.SparseTensor( + indices=[[0, 0, 0], [0, 0, 1], + [0, 1, 0], [0, 1, 1], + [1, 0, 0], [1, 0, 1], + [2, 0, 0], [2, 0, 1], + [3, 0, 0], [3, 0, 1]], + values=[0., 1., 10., 11., 100., 101., 200., 201., 300., 301.], + dense_shape=(4, 3, 2)), + 'price2': sparse_tensor.SparseTensor( + indices=[[0, 0], + [0, 1], + [1, 0], + [2, 0], + [3, 0]], + values=[10., 11., 20., 30., 40.], + dense_shape=(4, 3))} + sequence_features = sfc.SequenceFeatures([price1, price2]) + seq_input, seq_len = sequence_features(features) + self.assertEqual( + sequence_features.compute_output_shape((None, None)), + (None, None, 3)) + self.evaluate(variables_lib.global_variables_initializer()) + self.evaluate(lookup_ops.tables_initializer()) + + self.assertAllClose([[[0., 1., 10.], [10., 11., 11.], [0., 0., 0.]], + [[100., 101., 20.], [0., 0., 0.], [0., 0., 0.]], + [[200., 201., 30.], [0., 0., 0.], [0., 0., 0.]], + [[300., 301., 40.], [0., 0., 0.], [0., 0., 0.]]], + self.evaluate(seq_input)) + self.assertAllClose([2, 1, 1, 1], self.evaluate(seq_len)) + class ConcatenateContextInputTest(test.TestCase, parameterized.TestCase): """Tests the utility fn concatenate_context_input.""" @@ -605,8 +624,8 @@ class ConcatenateContextInputTest(test.TestCase, parameterized.TestCase): sfc.concatenate_context_input(context_input, seq_input) -class InputLayerTest(test.TestCase): - """Tests input_layer with sequence feature columns.""" +class DenseFeaturesTest(test.TestCase): + """Tests DenseFeatures with sequence feature columns.""" def test_embedding_column(self): """Tests that error is raised for sequence embedding column.""" @@ -620,16 +639,15 @@ class InputLayerTest(test.TestCase): categorical_column_a = sfc.sequence_categorical_column_with_identity( key='aaa', 
num_buckets=vocabulary_size) - embedding_column_a = fc_old._embedding_column( + embedding_column_a = fc.embedding_column( categorical_column_a, dimension=2) with self.assertRaisesRegexp( ValueError, r'In embedding_column: aaa_embedding\. categorical_column must not be ' - r'of type _SequenceCategoricalColumn\.'): - _ = fc_old.input_layer( - features={'aaa': sparse_input}, - feature_columns=[embedding_column_a]) + r'of type SequenceCategoricalColumn\.'): + input_layer = fc.DenseFeatures([embedding_column_a]) + _ = input_layer({'aaa': sparse_input}) def test_indicator_column(self): """Tests that error is raised for sequence indicator column.""" @@ -643,15 +661,14 @@ class InputLayerTest(test.TestCase): categorical_column_a = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - indicator_column_a = fc_old._indicator_column(categorical_column_a) + indicator_column_a = fc.indicator_column(categorical_column_a) with self.assertRaisesRegexp( ValueError, r'In indicator_column: aaa_indicator\. 
categorical_column must not be ' - r'of type _SequenceCategoricalColumn\.'): - _ = fc_old.input_layer( - features={'aaa': sparse_input}, - feature_columns=[indicator_column_a]) + r'of type SequenceCategoricalColumn\.'): + input_layer = fc.DenseFeatures([indicator_column_a]) + _ = input_layer({'aaa': sparse_input}) def _assert_sparse_tensor_value(test_case, expected, actual): @@ -670,6 +687,23 @@ def _assert_sparse_tensor_indices_shape(test_case, expected, actual): test_case.assertAllEqual(expected.dense_shape, actual.dense_shape) +def _get_sequence_dense_tensor(column, features): + return column.get_sequence_dense_tensor( + fc.FeatureTransformationCache(features), None) + + +def _get_sequence_dense_tensor_state(column, features): + state_manager = _TestStateManager() + column.create_state(state_manager) + return column.get_sequence_dense_tensor( + fc.FeatureTransformationCache(features), state_manager) + + +def _get_sparse_tensors(column, features): + return column.get_sparse_tensors( + fc.FeatureTransformationCache(features), None) + + class SequenceCategoricalColumnWithIdentityTest( test.TestCase, parameterized.TestCase): @@ -698,7 +732,7 @@ class SequenceCategoricalColumnWithIdentityTest( expected = sparse_tensor.SparseTensorValue(**expected_args) column = sfc.sequence_categorical_column_with_identity('aaa', num_buckets=9) - id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs})) + id_weight_pair = _get_sparse_tensors(column, {'aaa': inputs}) self.assertIsNone(id_weight_pair.weight_tensor) with monitored_session.MonitoredSession() as sess: @@ -737,7 +771,7 @@ class SequenceCategoricalColumnWithHashBucketTest( column = sfc.sequence_categorical_column_with_hash_bucket( 'aaa', hash_bucket_size=10) - id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs})) + id_weight_pair = _get_sparse_tensors(column, {'aaa': inputs}) self.assertIsNone(id_weight_pair.weight_tensor) with monitored_session.MonitoredSession() as sess: @@ -790,7 
+824,7 @@ class SequenceCategoricalColumnWithVocabularyFileTest( vocabulary_file=self._wire_vocabulary_file_name, vocabulary_size=self._wire_vocabulary_size) - id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs})) + id_weight_pair = _get_sparse_tensors(column, {'aaa': inputs}) self.assertIsNone(id_weight_pair.weight_tensor) with monitored_session.MonitoredSession() as sess: @@ -814,8 +848,7 @@ class SequenceCategoricalColumnWithVocabularyFileTest( input_placeholder_shape[1] = None input_placeholder = array_ops.sparse_placeholder( dtypes.string, shape=input_placeholder_shape) - id_weight_pair = column._get_sparse_tensors( - _LazyBuilder({'aaa': input_placeholder})) + id_weight_pair = _get_sparse_tensors(column, {'aaa': input_placeholder}) self.assertIsNone(id_weight_pair.weight_tensor) with monitored_session.MonitoredSession() as sess: @@ -855,7 +888,7 @@ class SequenceCategoricalColumnWithVocabularyListTest( key='aaa', vocabulary_list=('omar', 'stringer', 'marlo')) - id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs})) + id_weight_pair = _get_sparse_tensors(column, {'aaa': inputs}) self.assertIsNone(id_weight_pair.weight_tensor) with monitored_session.MonitoredSession() as sess: @@ -922,16 +955,15 @@ class SequenceEmbeddingColumnTest( categorical_column = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - embedding_column = fc_old._embedding_column( - categorical_column, - dimension=embedding_dimension, + embedding_column = fc.embedding_column( + categorical_column, dimension=embedding_dimension, initializer=_initializer) - embedding_lookup, _ = embedding_column._get_sequence_dense_tensor( - _LazyBuilder({'aaa': inputs})) + embedding_lookup, _ = _get_sequence_dense_tensor_state( + embedding_column, {'aaa': inputs}) global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) - self.assertItemsEqual( + self.assertCountEqual( ('embedding_weights:0',), tuple([v.name for v in 
global_vars])) with monitored_session.MonitoredSession() as sess: self.assertAllEqual(embedding_values, global_vars[0].eval(session=sess)) @@ -961,10 +993,11 @@ class SequenceEmbeddingColumnTest( categorical_column = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - embedding_column = fc_old._embedding_column(categorical_column, dimension=2) + embedding_column = fc.embedding_column( + categorical_column, dimension=2) - _, sequence_length = embedding_column._get_sequence_dense_tensor( - _LazyBuilder({'aaa': inputs})) + _, sequence_length = _get_sequence_dense_tensor_state( + embedding_column, {'aaa': inputs}) with monitored_session.MonitoredSession() as sess: sequence_length = sess.run(sequence_length) @@ -988,10 +1021,11 @@ class SequenceEmbeddingColumnTest( categorical_column = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - embedding_column = fc_old._embedding_column(categorical_column, dimension=2) + embedding_column = fc.embedding_column( + categorical_column, dimension=2) - _, sequence_length = embedding_column._get_sequence_dense_tensor( - _LazyBuilder({'aaa': sparse_input})) + _, sequence_length = _get_sequence_dense_tensor_state( + embedding_column, {'aaa': sparse_input}) with monitored_session.MonitoredSession() as sess: self.assertAllEqual( @@ -1058,22 +1092,18 @@ class SequenceSharedEmbeddingColumnTest(test.TestCase): key='aaa', num_buckets=vocabulary_size) categorical_column_b = sfc.sequence_categorical_column_with_identity( key='bbb', num_buckets=vocabulary_size) - shared_embedding_columns = fc.shared_embedding_columns( + shared_embedding_columns = fc.shared_embedding_columns_v2( [categorical_column_a, categorical_column_b], dimension=embedding_dimension, initializer=_initializer) - embedding_lookup_a = shared_embedding_columns[0]._get_sequence_dense_tensor( - _LazyBuilder({ - 'aaa': sparse_input_a - }))[0] - embedding_lookup_b = 
shared_embedding_columns[1]._get_sequence_dense_tensor( - _LazyBuilder({ - 'bbb': sparse_input_b - }))[0] + embedding_lookup_a = _get_sequence_dense_tensor( + shared_embedding_columns[0], {'aaa': sparse_input_a})[0] + embedding_lookup_b = _get_sequence_dense_tensor( + shared_embedding_columns[1], {'bbb': sparse_input_b})[0] global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) - self.assertItemsEqual(('embedding_weights:0',), + self.assertItemsEqual(('aaa_bbb_shared_embedding:0',), tuple([v.name for v in global_vars])) with monitored_session.MonitoredSession() as sess: self.assertAllEqual(embedding_values, global_vars[0].eval(session=sess)) @@ -1104,17 +1134,13 @@ class SequenceSharedEmbeddingColumnTest(test.TestCase): expected_sequence_length_b = [2, 1] categorical_column_b = sfc.sequence_categorical_column_with_identity( key='bbb', num_buckets=vocabulary_size) - shared_embedding_columns = fc.shared_embedding_columns( + shared_embedding_columns = fc.shared_embedding_columns_v2( [categorical_column_a, categorical_column_b], dimension=2) - sequence_length_a = shared_embedding_columns[0]._get_sequence_dense_tensor( - _LazyBuilder({ - 'aaa': sparse_input_a - }))[1] - sequence_length_b = shared_embedding_columns[1]._get_sequence_dense_tensor( - _LazyBuilder({ - 'bbb': sparse_input_b - }))[1] + sequence_length_a = _get_sequence_dense_tensor( + shared_embedding_columns[0], {'aaa': sparse_input_a})[1] + sequence_length_b = _get_sequence_dense_tensor( + shared_embedding_columns[1], {'bbb': sparse_input_b})[1] with monitored_session.MonitoredSession() as sess: sequence_length_a = sess.run(sequence_length_a) @@ -1155,17 +1181,13 @@ class SequenceSharedEmbeddingColumnTest(test.TestCase): categorical_column_b = sfc.sequence_categorical_column_with_identity( key='bbb', num_buckets=vocabulary_size) - shared_embedding_columns = fc.shared_embedding_columns( + shared_embedding_columns = fc.shared_embedding_columns_v2( [categorical_column_a, categorical_column_b], 
dimension=2) - sequence_length_a = shared_embedding_columns[0]._get_sequence_dense_tensor( - _LazyBuilder({ - 'aaa': sparse_input_a - }))[1] - sequence_length_b = shared_embedding_columns[1]._get_sequence_dense_tensor( - _LazyBuilder({ - 'bbb': sparse_input_b - }))[1] + sequence_length_a = _get_sequence_dense_tensor( + shared_embedding_columns[0], {'aaa': sparse_input_a})[1] + sequence_length_b = _get_sequence_dense_tensor( + shared_embedding_columns[1], {'bbb': sparse_input_b})[1] with monitored_session.MonitoredSession() as sess: self.assertAllEqual( @@ -1221,10 +1243,10 @@ class SequenceIndicatorColumnTest(test.TestCase, parameterized.TestCase): categorical_column = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - indicator_column = fc_old._indicator_column(categorical_column) + indicator_column = fc.indicator_column(categorical_column) - indicator_tensor, _ = indicator_column._get_sequence_dense_tensor( - _LazyBuilder({'aaa': inputs})) + indicator_tensor, _ = _get_sequence_dense_tensor( + indicator_column, {'aaa': inputs}) with monitored_session.MonitoredSession() as sess: self.assertAllEqual(expected, indicator_tensor.eval(session=sess)) @@ -1253,10 +1275,10 @@ class SequenceIndicatorColumnTest(test.TestCase, parameterized.TestCase): categorical_column = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - indicator_column = fc_old._indicator_column(categorical_column) + indicator_column = fc.indicator_column(categorical_column) - _, sequence_length = indicator_column._get_sequence_dense_tensor( - _LazyBuilder({'aaa': inputs})) + _, sequence_length = _get_sequence_dense_tensor( + indicator_column, {'aaa': inputs}) with monitored_session.MonitoredSession() as sess: sequence_length = sess.run(sequence_length) @@ -1282,19 +1304,14 @@ class SequenceIndicatorColumnTest(test.TestCase, parameterized.TestCase): key='aaa', num_buckets=vocabulary_size) indicator_column = 
fc.indicator_column(categorical_column) - _, sequence_length = indicator_column._get_sequence_dense_tensor( - _LazyBuilder({'aaa': sparse_input})) + _, sequence_length = _get_sequence_dense_tensor( + indicator_column, {'aaa': sparse_input}) with monitored_session.MonitoredSession() as sess: self.assertAllEqual( expected_sequence_length, sequence_length.eval(session=sess)) -def _get_sequence_dense_tensor(column, features): - return column.get_sequence_dense_tensor( - fc.FeatureTransformationCache(features), None) - - class SequenceNumericColumnTest(test.TestCase, parameterized.TestCase): def test_defaults(self): diff --git a/tensorflow/contrib/framework/BUILD b/tensorflow/contrib/framework/BUILD index dad50a3a73085526f65bd87c3d8549ceb75b3af4..3f6dbe0cbdeeae5e2107755f80bcfe5f7fc310e4 100644 --- a/tensorflow/contrib/framework/BUILD +++ b/tensorflow/contrib/framework/BUILD @@ -50,6 +50,8 @@ tf_custom_op_py_library( visibility = [ "//learning/brain:__subpackages__", "//tensorflow:__subpackages__", + "//tensorflow_estimator:__subpackages__", + "//tensorflow_model_optimization:__subpackages__", "//video/youtube/personalization:__subpackages__", ], deps = [ diff --git a/tensorflow/contrib/framework/__init__.py b/tensorflow/contrib/framework/__init__.py index e72e50585a3861d4527b66f89e1659d76c85960a..3784631dcbfbeb215b6c695e4b6f1bbd02fa708c 100644 --- a/tensorflow/contrib/framework/__init__.py +++ b/tensorflow/contrib/framework/__init__.py @@ -130,17 +130,21 @@ _allowed_symbols = ['nest'] _nest_allowed_symbols = [ 'assert_same_structure', 'is_sequence', + 'is_sequence_or_composite', 'flatten', 'flatten_dict_items', 'pack_sequence_as', 'map_structure', 'map_structure_with_paths', + 'map_structure_with_tuple_paths', 'assert_shallow_structure', 'flatten_up_to', 'map_structure_up_to', + 'map_structure_with_tuple_paths_up_to', 'get_traverse_shallow_structure', 'yield_flat_paths', 'flatten_with_joined_string_paths', + 'flatten_with_tuple_paths', ] 
remove_undocumented(nest.__name__, allowed_exception_list=_nest_allowed_symbols) diff --git a/tensorflow/contrib/fused_conv/BUILD b/tensorflow/contrib/fused_conv/BUILD index 57a5bfbf43c915775c6b0ef05baac19581213a09..f65f450eba49163c319af54ec2bd7f6b61e34c1e 100644 --- a/tensorflow/contrib/fused_conv/BUILD +++ b/tensorflow/contrib/fused_conv/BUILD @@ -171,6 +171,7 @@ cuda_py_test( main = "python/ops/fused_conv2d_bias_activation_benchmark.py", tags = [ "manual", # TODO(b/117128481): re-enable after fixing OSS build + "nogpu", "requires-gpu-sm70", ], ) diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc index 93b1aaa85e88e00c1b12a388321a4d6fb10f1611..b6b75ffa248d66cc4cb49339f193d486f05a6a4a 100644 --- a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc +++ b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc @@ -19,13 +19,13 @@ limitations under the License. 
#include "tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_slice.h" -#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/kernels/conv_2d.h" #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/core/errors.h" @@ -522,7 +522,7 @@ void LaunchFusedConv2DBiasActivationOp:: auto bias_ptr = AsDeviceMemory(bias.template flat().data(), bias.template flat().size()); - static int64 ConvolveScratchSize = GetCudnnWorkspaceLimit( + static int64 ConvolveScratchSize = GetDnnWorkspaceLimit( // default value is in bytes despite the name of the environment variable "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32 // 4GB ); @@ -570,7 +570,7 @@ void LaunchFusedConv2DBiasActivationOp:: for (auto profile_algorithm : algorithms) { // TODO(zhengxq): profile each algorithm multiple times to better // accuracy. - CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx); + DnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx); dnn::ProfileResult profile_result; bool cudnn_launch_status = stream @@ -609,7 +609,7 @@ void LaunchFusedConv2DBiasActivationOp:: algorithm_config); } - CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx); + DnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx); bool cudnn_launch_status = stream ->ThenFusedConvolveWithAlgorithm( diff --git a/tensorflow/contrib/gan/BUILD b/tensorflow/contrib/gan/BUILD index f89d7ed0f45f919b17398de5d9449d12c08dd2f2..db0868fb2c43464a811b3d6dfcd96480ba2463ee 100644 --- a/tensorflow/contrib/gan/BUILD +++ b/tensorflow/contrib/gan/BUILD @@ -1,12 +1,14 @@ -# Files for using TFGAN framework. 
-package(default_visibility = ["//tensorflow:__subpackages__"]) +# Files for using TF-GAN framework. +load("//tensorflow:tensorflow.bzl", "py_test") + +package(default_visibility = [ + "//tensorflow:__subpackages__", +]) licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) -load("//tensorflow:tensorflow.bzl", "py_test") - py_library( name = "gan", srcs = [ @@ -104,7 +106,9 @@ py_library( deps = [ ":gan_estimator", ":head", + ":latent_gan_estimator", ":stargan_estimator", + ":tpu_gan_estimator", "//tensorflow/python:util", ], ) @@ -128,6 +132,7 @@ py_library( ":clip_weights", ":conditioning_utils", ":random_tensor_pool", + ":spectral_normalization", ":virtual_batchnorm", "//tensorflow/python:util", ], @@ -141,16 +146,15 @@ py_library( "//tensorflow/contrib/framework:framework_py", "//tensorflow/python:array_ops", "//tensorflow/python:clip_ops", + "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", - "//tensorflow/python:gradients", + "//tensorflow/python:gradients_impl", "//tensorflow/python:math_ops", "//tensorflow/python:random_ops", "//tensorflow/python:summary", "//tensorflow/python:tensor_util", "//tensorflow/python:variable_scope", - "//tensorflow/python/ops/distributions", "//tensorflow/python/ops/losses", - "//third_party/py/numpy", ], ) @@ -518,15 +522,19 @@ py_test( "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", + "//tensorflow/python:errors", "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", "//tensorflow/python:metrics", "//tensorflow/python:parsing_ops", "//tensorflow/python:summary", + "//tensorflow/python:tensor_shape", "//tensorflow/python:training", "//tensorflow/python:training_util", "//tensorflow/python:variable_scope", - "//tensorflow/python/estimator:estimator_py", + "//tensorflow/python/estimator", + "//tensorflow/python/estimator:model_fn", + "//tensorflow/python/estimator:numpy_io", "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", 
"@six_archive//:six", @@ -562,28 +570,114 @@ py_test( deps = [ ":namedtuples", ":stargan_estimator", - ":tuple_losses", "//tensorflow/contrib/layers:layers_py", - "//tensorflow/contrib/learn", - "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", "//tensorflow/python:metrics", - "//tensorflow/python:parsing_ops", "//tensorflow/python:summary", "//tensorflow/python:training", "//tensorflow/python:training_util", "//tensorflow/python:variable_scope", - "//tensorflow/python/estimator:estimator_py", + "//tensorflow/python/estimator:model_fn", + "//tensorflow/python/estimator:numpy_io", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + "@six_archive//:six", + ], +) + +py_library( + name = "tpu_gan_estimator", + srcs = [ + "python/estimator/python/tpu_gan_estimator.py", + "python/estimator/python/tpu_gan_estimator_impl.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":gan_estimator", + ":namedtuples", + ":train", + "//tensorflow/contrib/tpu:tpu_estimator", + "//tensorflow/contrib/tpu:tpu_lib", + "//tensorflow/contrib/training:training_py", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:metrics", + "//tensorflow/python:util", + "//tensorflow/python/estimator:model_fn", + "//tensorflow/python/ops/losses", + ], +) + +py_test( + name = "tpu_gan_estimator_test", + srcs = ["python/estimator/python/tpu_gan_estimator_test.py"], + shard_count = 11, + srcs_version = "PY2AND3", + tags = ["notsan"], + deps = [ + ":namedtuples", + ":tpu_gan_estimator", + ":tuple_losses", + "//tensorflow/contrib/layers:layers_py", + "//tensorflow/contrib/tpu:tpu_estimator", + "//tensorflow/contrib/tpu:tpu_lib", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + 
"//tensorflow/python:metrics", + "//tensorflow/python:summary", + "//tensorflow/python:tensor_shape", + "//tensorflow/python:training", + "//tensorflow/python:training_util", + "//tensorflow/python:variable_scope", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/estimator", + "//tensorflow/python/estimator:model_fn", "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", "@six_archive//:six", ], ) +py_library( + name = "latent_gan_estimator", + srcs = [ + "python/estimator/python/latent_gan_estimator.py", + "python/estimator/python/latent_gan_estimator_impl.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":train", + "//tensorflow/python:clip_ops", + "//tensorflow/python:gradients", + "//tensorflow/python:random_ops", + "//tensorflow/python:summary", + "//tensorflow/python:training_util", + "//tensorflow/python:variable_scope", + "//tensorflow/python/estimator:estimator_py", + ], +) + +py_test( + name = "latent_gan_estimator_test", + srcs = [ + "python/estimator/python/latent_gan_estimator_test.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":latent_gan_estimator", + "//tensorflow/python:array_ops", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + "//tensorflow/python/estimator:run_config", + "//tensorflow/python/ops/losses", + ], +) + py_library( name = "sliced_wasserstein", srcs = [ @@ -618,3 +712,45 @@ py_test( "//third_party/py/numpy", ], ) + +py_library( + name = "spectral_normalization", + srcs = [ + "python/features/python/spectral_normalization.py", + "python/features/python/spectral_normalization_impl.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:init_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:standard_ops", + "//tensorflow/python:variable_scope", + "//tensorflow/python/keras:engine", + ], +) + +py_test( + name = "spectral_normalization_test", 
+ srcs = ["python/features/python/spectral_normalization_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":spectral_normalization", + "//tensorflow/contrib/layers:layers_py", + "//tensorflow/contrib/slim", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:init_ops", + "//tensorflow/python:layers", + "//tensorflow/python:linalg_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + "//tensorflow/python/keras:layers", + "//third_party/py/numpy", + ], +) diff --git a/tensorflow/contrib/gan/README.md b/tensorflow/contrib/gan/README.md index 9ab86329eaf0e6fd426aef1f552f4e27c2ad65de..4eac4e80cdacd779fdbedef19e4a654196f0caf1 100644 --- a/tensorflow/contrib/gan/README.md +++ b/tensorflow/contrib/gan/README.md @@ -1,14 +1,15 @@ -# TensorFlow-GAN (TFGAN) + +# TensorFlow-GAN (TF-GAN) -TFGAN is a lightweight library for training and evaluating Generative +TF-GAN is a lightweight library for training and evaluating Generative Adversarial Networks (GANs). This technique allows you to train a network (called the 'generator') to sample from a distribution, without having to explicitly model the distribution and without writing an explicit loss. For example, the generator could learn to draw samples from the distribution of natural images. For more details on this technique, see ['Generative Adversarial Networks'](https://arxiv.org/abs/1406.2661) by -Goodfellow et al. See [tensorflow/models](https://github.com/tensorflow/models/tree/master/research/gan/) for examples, and [this tutorial](https://github.com/tensorflow/models/tree/master/research/gan/tutorial.ipynb) for an +Goodfellow et al. 
See [tensorflow/models](https://github.com/tensorflow/models/tree/master/research/gan/) for examples, and [this tutorial](https://github.com/tensorflow/models/tree/master/research/gan/tutorial.ipynb) for an introduction. #### Usage @@ -17,27 +18,27 @@ import tensorflow as tf tfgan = tf.contrib.gan ``` -## Why TFGAN? +## Why TF-GAN? * Easily train generator and discriminator networks with well-tested, flexible [library calls](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/train.py). You can -mix TFGAN, native TF, and other custom frameworks +mix TF-GAN, native TF, and other custom frameworks * Use already implemented [GAN losses and penalties](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/losses/python/losses_impl.py) (ex Wasserstein loss, gradient penalty, mutual information penalty, etc) * [Monitor and visualize](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/eval/python/summaries_impl.py) GAN progress during training, and [evaluate](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py) them * Use already-implemented [tricks](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/features/python/) to stabilize and improve training * Develop based on examples of [common GAN setups](https://github.com/tensorflow/models/tree/master/research/gan/) -* Use the TFGAN-backed [GANEstimator](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py) to easily train a GAN model -* Improvements in TFGAN infrastructure will automatically benefit your TFGAN project +* Use the TF-GAN-backed [GANEstimator](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py) to easily train a GAN model +* Improvements in TF-GAN infrastructure will automatically benefit your TF-GAN project * Stay up-to-date with research as we add more algorithms -## What are the TFGAN components?
+## What are the TF-GAN components? -TFGAN is composed of several parts which were design to exist independently. +TF-GAN is composed of several parts which were designed to exist independently. These include the following main pieces (explained in detail below). * [core](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/train.py): provides the main infrastructure needed to train a GAN. Training occurs in four phases, and each phase can be completed by custom-code or by using a - TFGAN library call. + TF-GAN library call. * [features](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/features/python/): Many common GAN operations and normalization techniques are implemented for @@ -56,14 +57,14 @@ These include the following main pieces (explained in detail below). generative models. * [examples](https://github.com/tensorflow/models/tree/master/research/gan/) - and [tutorial](https://github.com/tensorflow/models/tree/master/research/gan/tutorial.ipynb): See examples of how to use TFGAN to make - GAN training easier, or use the more complicated examples to jumpstart your + and [tutorial](https://github.com/tensorflow/models/tree/master/research/gan/tutorial.ipynb): See examples of how to use TF-GAN to make + GAN training easier, or use the more complicated examples to jump-start your own project. These include unconditional and conditional GANs, InfoGANs, adversarial losses on existing networks, and image-to-image translation. ## Training a GAN model -Training in TFGAN typically consists of the following steps: +Training in TF-GAN typically consists of the following steps: 1. Specify the input to your networks. 1. Set up your generator and discriminator using a `GANModel`. @@ -71,12 +72,12 @@ Training in TFGAN typically consists of the following steps: 1. Create your train ops using a `GANTrainOps`. 1. Run your train ops.
-At each stage, you can either use TFGAN's convenience functions, or you can +At each stage, you can either use TF-GAN's convenience functions, or you can perform the step manually for fine-grained control. We provide examples below. There are various types of GAN setups. For instance, you can train a generator to sample unconditionally from a learned distribution, or you can condition on -extra information such as a class label. TFGAN is compatible with many setups, +extra information such as a class label. TF-GAN is compatible with many setups, and we demonstrate a few below: ### Examples @@ -254,9 +255,9 @@ with variable_scope.variable_scope(dis_scope, reuse=True): discriminator_real_outputs = discriminator_fn(images) generator_variables = variables_lib.get_trainable_variables(gen_scope) discriminator_variables = variables_lib.get_trainable_variables(dis_scope) -# Depending on what TFGAN features you use, you don't always need to supply +# Depending on what TF-GAN features you use, you don't always need to supply # every `GANModel` field. At a minimum, you need to include the discriminator -# outputs and variables if you want to use TFGAN to construct losses. +# outputs and variables if you want to use TF-GAN to construct losses. gan_model = tfgan.GANModel( generator_inputs, generated_data, diff --git a/tensorflow/contrib/gan/__init__.py b/tensorflow/contrib/gan/__init__.py index f1946c7f925660eae3aaa650c437e03da1f33d6c..1e6000898f7b8a53ad3f6fa12deebd54bf3a57ff 100644 --- a/tensorflow/contrib/gan/__init__.py +++ b/tensorflow/contrib/gan/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""TFGAN is a lightweight library for training and evaluating GANs. +"""TF-GAN is a lightweight library for training and evaluating GANs. 
In addition to providing the infrastructure for easily training and evaluating GANS, this library contains modules for a TFGAN-backed Estimator, @@ -24,7 +24,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# Collapse TFGAN into a tiered namespace. +# Collapse TF-GAN into a tiered namespace. from tensorflow.contrib.gan.python import estimator from tensorflow.contrib.gan.python import eval # pylint:disable=redefined-builtin from tensorflow.contrib.gan.python import features diff --git a/tensorflow/contrib/gan/python/estimator/__init__.py b/tensorflow/contrib/gan/python/estimator/__init__.py index 99d38011ba677f03e198a431634fbb2ce349f912..430266555b723e6ca39dccffc1442dbef5d4a385 100644 --- a/tensorflow/contrib/gan/python/estimator/__init__.py +++ b/tensorflow/contrib/gan/python/estimator/__init__.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""TFGAN estimator module. +"""TF-GAN estimator module. GANEstimator provides all the infrastructure support of a TensorFlow Estimator -with the feature support of TFGAN. +with the feature support of TF-GAN. 
""" from __future__ import absolute_import @@ -26,18 +26,25 @@ from __future__ import print_function # pylint: disable=unused-import,wildcard-import from tensorflow.contrib.gan.python.estimator.python import gan_estimator from tensorflow.contrib.gan.python.estimator.python import head +from tensorflow.contrib.gan.python.estimator.python import latent_gan_estimator from tensorflow.contrib.gan.python.estimator.python import stargan_estimator +from tensorflow.contrib.gan.python.estimator.python import tpu_gan_estimator from tensorflow.contrib.gan.python.estimator.python.gan_estimator import * from tensorflow.contrib.gan.python.estimator.python.head import * +from tensorflow.contrib.gan.python.estimator.python.latent_gan_estimator import * from tensorflow.contrib.gan.python.estimator.python.stargan_estimator import * +from tensorflow.contrib.gan.python.estimator.python.tpu_gan_estimator import * # pylint: enable=unused-import,wildcard-import from tensorflow.python.util.all_util import remove_undocumented -_allowed_symbols = [ +_allowed_symbols = ([ 'gan_estimator', 'stargan_estimator', + 'tpu_gan_estimator', + 'latent_gan_estimator', 'head', -] + gan_estimator.__all__ + stargan_estimator.__all__ + head.__all__ +] + gan_estimator.__all__ + stargan_estimator.__all__ + head.__all__ + + tpu_gan_estimator.__all__ + latent_gan_estimator.__all__) remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py index 3593b501bb738b8f58dce4e40cffbdf410f136b3..dd904611d1a3bb78de8316d5ed29ab0f800f29a9 100644 --- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -"""A TFGAN-backed GAN Estimator.""" +"""A TF-GAN-backed GAN Estimator.""" from __future__ import absolute_import from __future__ import division @@ -56,10 +56,10 @@ _summary_type_map = { class GANEstimator(estimator.Estimator): """An estimator for Generative Adversarial Networks (GANs). - This Estimator is backed by TFGAN. The network functions follow the TFGAN API - except for one exception: if either `generator_fn` or `discriminator_fn` have - an argument called `mode`, then the tf.Estimator mode is passed in for that - argument. This helps with operations like batch normalization, which have + This Estimator is backed by TF-GAN. The network functions follow the TF-GAN + API except for one exception: if either `generator_fn` or `discriminator_fn` + have an argument called `mode`, then the tf.Estimator mode is passed in for + that argument. This helps with operations like batch normalization, which have different train and evaluation behavior. Example: @@ -68,7 +68,7 @@ class GANEstimator(estimator.Estimator): import tensorflow as tf tfgan = tf.contrib.gan - # See TFGAN's `train.py` for a description of the generator and + # See TF-GAN's `train.py` for a description of the generator and # discriminator API. def generator_fn(generator_inputs): ... @@ -123,13 +123,13 @@ class GANEstimator(estimator.Estimator): to continue training a previously saved model. generator_fn: A python function that takes a Tensor, Tensor list, or Tensor dictionary as inputs and returns the outputs of the GAN - generator. See `TFGAN` for more details and examples. Additionally, if + generator. See `TF-GAN` for more details and examples. Additionally, if it has an argument called `mode`, the Estimator's `mode` will be passed in (ex TRAIN, EVAL, PREDICT). This is useful for things like batch normalization. 
discriminator_fn: A python function that takes the output of `generator_fn` or real data in the GAN setup, and `generator_inputs`. - Outputs a Tensor in the range [-inf, inf]. See `TFGAN` for more details + Outputs a Tensor in the range [-inf, inf]. See `TF-GAN` for more details and examples. generator_loss_fn: The loss function on the generator. Takes a `GANModel` tuple. @@ -233,13 +233,14 @@ def _get_estimator_spec( estimator_spec = _get_eval_estimator_spec( gan_model, gan_loss, get_eval_metric_ops_fn) else: # model_fn_lib.ModeKeys.TRAIN: - gopt = (generator_optimizer() if callable(generator_optimizer) else - generator_optimizer) - dopt = (discriminator_optimizer() if callable(discriminator_optimizer) - else discriminator_optimizer) + if callable(generator_optimizer): + generator_optimizer = generator_optimizer() + if callable(discriminator_optimizer): + discriminator_optimizer = discriminator_optimizer() get_hooks_fn = get_hooks_fn or tfgan_train.get_sequential_train_hooks() estimator_spec = _get_train_estimator_spec( - gan_model, gan_loss, gopt, dopt, get_hooks_fn, is_chief=is_chief) + gan_model, gan_loss, generator_optimizer, discriminator_optimizer, + get_hooks_fn, is_chief=is_chief) return estimator_spec diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py index bc9021050bc010ce75c3091fef868549686c0e90..5b9c54e43a16adf457d5ed0e7e73dcd168ab0d67 100644 --- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -"""Tests for TFGAN's estimator.py.""" +"""Tests for TF-GAN's estimator.py.""" from __future__ import absolute_import from __future__ import division @@ -75,8 +75,8 @@ class GetGANModelTest(test.TestCase, parameterized.TestCase): def test_get_gan_model(self, mode): with ops.Graph().as_default(): generator_inputs = {'x': array_ops.ones([3, 4])} - real_data = (array_ops.zeros([3, 4]) if - mode != model_fn_lib.ModeKeys.PREDICT else None) + is_predict = mode == model_fn_lib.ModeKeys.PREDICT + real_data = array_ops.zeros([3, 4]) if not is_predict else None gan_model = estimator._get_gan_model( mode, generator_fn, discriminator_fn, real_data, generator_inputs, add_summaries=False) @@ -139,6 +139,7 @@ class GetEstimatorSpecTest(test.TestCase, parameterized.TestCase): @classmethod def setUpClass(cls): + super(GetEstimatorSpecTest, cls).setUpClass() cls._generator_optimizer = training.GradientDescentOptimizer(1.0) cls._discriminator_optimizer = training.GradientDescentOptimizer(1.0) @@ -200,7 +201,6 @@ class GetEstimatorSpecTest(test.TestCase, parameterized.TestCase): self.assertSetEqual(frozenset(sync_opts), frozenset((g_opt, d_opt))) -# TODO(joelshor): Add pandas test. class GANEstimatorIntegrationTest(test.TestCase): def setUp(self): @@ -231,11 +231,11 @@ class GANEstimatorIntegrationTest(test.TestCase): get_eval_metric_ops_fn=get_metrics, model_dir=self._model_dir) - # TRAIN + # Train. num_steps = 10 est.train(train_input_fn, steps=num_steps) - # EVALUTE + # Evaluate. scores = est.evaluate(eval_input_fn) self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP]) self.assertIn('loss', six.iterkeys(scores)) @@ -243,7 +243,7 @@ class GANEstimatorIntegrationTest(test.TestCase): scores['loss']) self.assertIn('mse_custom_metric', six.iterkeys(scores)) - # PREDICT + # Predict. 
predictions = np.array([x for x in est.predict(predict_input_fn)]) self.assertAllEqual(prediction_size, predictions.shape) diff --git a/tensorflow/contrib/gan/python/estimator/python/head_impl.py b/tensorflow/contrib/gan/python/estimator/python/head_impl.py index 1a0ee6dfc498eb6dc8c97411589d9e35bc352062..cbe990b476c3b17ce61e0826b17d10976fea43c7 100644 --- a/tensorflow/contrib/gan/python/estimator/python/head_impl.py +++ b/tensorflow/contrib/gan/python/estimator/python/head_impl.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""A TFGAN-backed GAN Estimator.""" +"""A TF-GAN-backed GAN Estimator.""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/gan/python/estimator/python/head_test.py b/tensorflow/contrib/gan/python/estimator/python/head_test.py index 8205bc889dc01c8680e2139393d65723280cfbd0..5b50234a0e33cd297b176f142b358338966b6758 100644 --- a/tensorflow/contrib/gan/python/estimator/python/head_test.py +++ b/tensorflow/contrib/gan/python/estimator/python/head_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for TFGAN's head.py.""" +"""Tests for TF-GAN's head.py.""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/gan/python/estimator/python/latent_gan_estimator.py b/tensorflow/contrib/gan/python/estimator/python/latent_gan_estimator.py new file mode 100644 index 0000000000000000000000000000000000000000..4e164e24168bb0cc5e9a7cc772081781ea088bb1 --- /dev/null +++ b/tensorflow/contrib/gan/python/estimator/python/latent_gan_estimator.py @@ -0,0 +1,28 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""`tf.Learn` components for `Train Input Estimator`.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.gan.python.estimator.python import latent_gan_estimator_impl +# pylint: disable=wildcard-import +from tensorflow.contrib.gan.python.estimator.python.latent_gan_estimator_impl import * +# pylint: enable=wildcard-import +from tensorflow.python.util.all_util import remove_undocumented + +__all__ = latent_gan_estimator_impl.__all__ +remove_undocumented(__name__, __all__) diff --git a/tensorflow/contrib/gan/python/estimator/python/latent_gan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/latent_gan_estimator_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..f5afc7731937ed1a82c8ebb5969b2687ffdd583b --- /dev/null +++ b/tensorflow/contrib/gan/python/estimator/python/latent_gan_estimator_impl.py @@ -0,0 +1,205 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Implements an estimator wrapper that allows training the input latent space. + +This file implements a latent gan estimator that wraps around a previously +trained GAN. The latent gan estimator trains a single variable z, representing +the hidden latent distribution that is the 'noise' input to the GAN. By training +z, the inpainting estimator can move around the latent z space towards +minimizing a specific loss function. + +The latent gan estimator has a few key differences from a normal estimator. + +First: the variables in the estimator should not be saved, as we are not +updating the original GAN and are only adding a new z variable that is meant +to be different for each run. In order to do distributed training using +train_and_evaluate, the Tensorflow RunConfig is expected to save checkpoints +by having either save_checkpoints_steps or save_checkpoints_secs saved. +To avoid this conflict, we purposely set the save_checkpoints_steps value in +the RunConfig to be one step more than the total number of steps that the +inpainter estimator will run. + +Second: we need to specify warm start settings, as we are reloading the +GAN model into a different graph (specifically, one with a new z variable). +The warm start settings defined below reload all GAN variables and ignore the +new z variable (and the optimizer). + +Usage: + + def _generator(net, mode): + ... + + def _discriminator(net, condition, mode): + ... 
+ + def _loss(gan_model, features, labels, add_summaries): + ... + + def optimizer(): + ... + + params = {} + config = tf.estimator.RunConfig() + tmp_dir = path/to/output/storage + + estimator = latent_gan_estimator.get_latent_gan_estimator( + _generator, _discriminator, _loss, optimizer, params, config, tmp_dir) + + def input_fn(): + ... + + estimator.train(input_fn=input_fn) + +See latent_gan_estimator_test.py or tensorflow_models/gan/face_inpainting for +further examples. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools +from tensorflow.contrib.gan.python import train as tfgan_train +from tensorflow.python.estimator import estimator +from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.ops import clip_ops +from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.summary import summary +from tensorflow.python.training import training_util + + +INPUT_NAME = 'new_var_z_input' # The name for the new z space input variable. +OPTIMIZER_NAME = 'latent_gan_optimizer' # The name for the new optimizer vars. 
+ +__all__ = [ + 'get_latent_gan_estimator', +] + + +def _get_latent_gan_model_fn(generator_fn, discriminator_fn, loss_fn, + optimizer): + """Sets up a model function that wraps around a given GAN.""" + def model_fn(features, labels, mode, params): + """Model function defining an inpainting estimator.""" + batch_size = params['batch_size'] + z_shape = [batch_size] + params['z_shape'] + add_summaries = params['add_summaries'] + input_clip = params['input_clip'] + + z = variable_scope.get_variable( + name=INPUT_NAME, initializer=random_ops.truncated_normal(z_shape), + constraint=lambda x: clip_ops.clip_by_value(x, -input_clip, input_clip)) + + generator = functools.partial(generator_fn, mode=mode) + discriminator = functools.partial(discriminator_fn, mode=mode) + gan_model = tfgan_train.gan_model(generator_fn=generator, + discriminator_fn=discriminator, + real_data=labels, + generator_inputs=z, + check_shapes=False) + + loss = loss_fn(gan_model, features, labels, add_summaries) + + # Use a variable scope to make sure that estimator variables dont cause + # save/load problems when restoring from ckpts. + with variable_scope.variable_scope(OPTIMIZER_NAME): + opt = optimizer(learning_rate=params['learning_rate'], + **params['opt_kwargs']) + train_op = opt.minimize( + loss=loss, global_step=training_util.get_or_create_global_step(), + var_list=[z]) + + if add_summaries: + z_grads = gradients_impl.gradients(loss, z) + summary.scalar('z_loss/z_grads', clip_ops.global_norm(z_grads)) + summary.scalar('z_loss/loss', loss) + + return model_fn_lib.EstimatorSpec(mode=mode, + predictions=gan_model.generated_data, + loss=loss, + train_op=train_op) + return model_fn + + +def get_latent_gan_estimator(generator_fn, discriminator_fn, loss_fn, + optimizer, params, config, ckpt_dir, + warmstart_options=True): + """Gets an estimator that passes gradients to the input. + + This function takes in a generator and adds a trainable z variable that is + used as input to this generator_fn. 
The generator itself is treated as a black + box through which gradients can pass without updating any weights. The + result is a trainable way to traverse the GAN latent space. The loss_fn is + used to actually train the z variable. The generator_fn and discriminator_fn + should be previously trained by the tfgan library (on reload, the variables + are expected to follow the tfgan format. It may be possible to use the + latent gan estimator with entirely custom GANs that do not use the tfgan + library as long as the appropriate variables are wired properly). + + Args: + generator_fn: a function defining a Tensorflow graph for a GAN generator. + The weights defined in this graph should already be defined in the given + checkpoint location. Should have 'mode' as an argument. + discriminator_fn: a function defining a Tensorflow graph for a GAN + discriminator. Should have 'mode' as an argument. + loss_fn: a function defining a Tensorflow graph for a GAN loss. Takes in a + GANModel tuple, features, labels, and add_summaries as inputs. + optimizer: a tf.Optimizer class or a function that returns a tf.Optimizer. + It is called as optimizer(learning_rate=..., **opt_kwargs). + params: An object containing the following parameters: + - batch_size: an int indicating the size of the training batch. + - z_shape: the desired shape of the input z values (not counting batch). + - learning_rate: a scalar or function defining a learning rate applied to + optimizer. + - input_clip: the amount to clip the z training variable by. + - add_summaries: whether or not to add summaries. + - opt_kwargs: (optional) extra keyword arguments passed to optimizer. + config: tf.RunConfig. Should point model to output dir and should indicate + whether to save checkpoints (to avoid saving checkpoints, set + save_checkpoints_steps to a number larger than the number of train steps). + The model_dir field in the RunConfig should point to a directory WITHOUT + any saved checkpoints. + ckpt_dir: the directory where the model checkpoints live. 
The checkpoint is + used to warm start the underlying GAN. This should NOT be the same as + config.model_dir. + warmstart_options: boolean, None, or a WarmStartSettings object. If set to + True, uses a default WarmStartSettings object. If set to False or None, + does not use warm start. If using a custom WarmStartSettings object, make + sure that new variables are properly accounted for when reloading the + underlying GAN. Defaults to True. + Returns: + An Estimator that trains the latent z input of the underlying GAN. + """ + model_fn = _get_latent_gan_model_fn(generator_fn, discriminator_fn, + loss_fn, optimizer) + + if isinstance(warmstart_options, estimator.WarmStartSettings): + ws = warmstart_options + elif warmstart_options: + # Default WarmStart loads all variable names except INPUT_NAME and + # OPTIMIZER_NAME. + var_regex = '^(?!.*(%s|%s).*)' % (INPUT_NAME, OPTIMIZER_NAME) + ws = estimator.WarmStartSettings(ckpt_to_initialize_from=ckpt_dir, + vars_to_warm_start=var_regex) + else: + ws = None + + if 'opt_kwargs' not in params: + params['opt_kwargs'] = {} + + return estimator.Estimator(model_fn=model_fn, config=config, params=params, + warm_start_from=ws) diff --git a/tensorflow/contrib/gan/python/estimator/python/latent_gan_estimator_test.py b/tensorflow/contrib/gan/python/estimator/python/latent_gan_estimator_test.py new file mode 100644 index 0000000000000000000000000000000000000000..ac139e532e35f7aae6da0655103a7249fe3382d4 --- /dev/null +++ b/tensorflow/contrib/gan/python/estimator/python/latent_gan_estimator_test.py @@ -0,0 +1,119 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for latent_gan_estimator. + +See g3.tp.tensorflow.contrib.gan.python.estimator.python.latent_gan_estimator. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tempfile +import numpy as np +from tensorflow.contrib.gan.python.estimator.python import latent_gan_estimator +from tensorflow.python.estimator import run_config as run_config +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops.losses import losses +from tensorflow.python.platform import test +from tensorflow.python.training import training + + +class TrainInputEstimatorTest(test.TestCase): + + def test_get_input_training_estimator(self): + """Integration test to make sure the input_training_estimator works.""" + + # Create dummy test input tensors. + true_features = np.reshape(np.random.uniform(size=100), (10, 10)) + true_labels = np.reshape(np.random.uniform(size=100), (5, 20)) + expected_z_output = [[1, -1], [-1, 1]] + + # Fill out required parameters randomly, includes optimizer kwargs. + params = { + 'batch_size': 2, + 'z_shape': [2], + 'learning_rate': 1.0, + 'input_clip': 1.0, + 'add_summaries': False, + 'opt_kwargs': { + 'beta1': 0.1 + } + } + + input_z_shape = [params['batch_size']] + params['z_shape'] + + # Create dummy model functions that represent an underlying GANEstimator and + # the input training wrapper. 
Make sure that everything is wired up + # correctly in the internals of each dummy function. + def _generator(net, mode): + """The generator function will get the newly created z variable.""" + del mode + self.assertSequenceEqual(net.shape, input_z_shape) + gen_dummy_var = variable_scope.get_variable( + name='generator_dummy_variable', + initializer=array_ops.ones(input_z_shape)) + return net * gen_dummy_var + + def _discriminator(net, condition, mode): + """The discriminator function will get either the z variable or labels.""" + del condition, mode + try: + self.assertSequenceEqual(net.shape, true_labels.shape) + except AssertionError: + self.assertSequenceEqual(net.shape, input_z_shape) + return net + + def _loss(gan_model, features, labels, _): + """Make sure that features and labels are passed in from input.""" + self.assertTrue(np.array_equal(features, true_features)) + self.assertTrue(np.array_equal(labels, true_labels)) + return losses.absolute_difference(expected_z_output, + gan_model.generated_data) + + optimizer = training.AdamOptimizer + + # We are not loading checkpoints, so set the corresponding directory to a + # dummy directory. + tmp_dir = tempfile.mkdtemp() + config = run_config.RunConfig(model_dir=tmp_dir, + save_summary_steps=None, + save_checkpoints_steps=1, + save_checkpoints_secs=None) + + # Get the estimator. Disable warm start so that there is no attempted + # checkpoint reloading. + estimator = latent_gan_estimator.get_latent_gan_estimator( + _generator, _discriminator, _loss, optimizer, params, config, tmp_dir, + warmstart_options=None) + + # Train for a few steps. + def dummy_input(): + return true_features, true_labels + estimator.train(input_fn=dummy_input, steps=10) + + # Make sure the generator variables did not change, but the z variables did + # change. 
+ self.assertTrue(np.array_equal( + estimator.get_variable_value('Generator/generator_dummy_variable'), + np.ones(input_z_shape))) + self.assertTrue(np.array_equal( + estimator.get_variable_value('new_var_z_input'), + expected_z_output)) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_impl.py index f60e16bc04662b33bc0bb22b5acc8c7fcc7a03ba..2a485e7d47ff10cf34c1b44f8dcc6b1f33c9a05f 100644 --- a/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_impl.py +++ b/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_impl.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""A TFGAN-backed StarGAN Estimator.""" +"""A TF-GAN-backed StarGAN Estimator.""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_test.py b/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_test.py index 2ec7938c7c4051842c7e982b54c1213b6e841b79..c00ff4399748a77f88d9753df7592bf3859d754e 100644 --- a/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_test.py +++ b/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -"""Tests for TFGAN's stargan_estimator.py.""" +"""Tests for TF-GAN's stargan_estimator.py.""" from __future__ import absolute_import from __future__ import division @@ -80,7 +80,7 @@ class StarGetGANModelTest(test.TestCase, parameterized.TestCase): self.assertEqual(input_data, gan_model.input_data) self.assertIsNotNone(gan_model.generated_data) self.assertIsNotNone(gan_model.generated_data_domain_target) - self.assertEqual(1, len(gan_model.generator_variables)) + self.assertLen(gan_model.generator_variables, 1) self.assertIsNotNone(gan_model.generator_scope) self.assertIsNotNone(gan_model.generator_fn) if mode == model_fn_lib.ModeKeys.PREDICT: @@ -109,7 +109,7 @@ class StarGetGANModelTest(test.TestCase, parameterized.TestCase): gan_model.discriminator_input_data_domain_predication) self.assertIsNotNone( gan_model.discriminator_generated_data_domain_predication) - self.assertEqual(2, len(gan_model.discriminator_variables)) # 1 FC layer + self.assertLen(gan_model.discriminator_variables, 2) # 1 FC layer self.assertIsNotNone(gan_model.discriminator_scope) self.assertIsNotNone(gan_model.discriminator_fn) @@ -163,6 +163,7 @@ class GetEstimatorSpecTest(test.TestCase, parameterized.TestCase): @classmethod def setUpClass(cls): + super(GetEstimatorSpecTest, cls).setUpClass() cls._generator_optimizer = training.GradientDescentOptimizer(1.0) cls._discriminator_optimizer = training.GradientDescentOptimizer(1.0) diff --git a/tensorflow/contrib/gan/python/estimator/python/tpu_gan_estimator.py b/tensorflow/contrib/gan/python/estimator/python/tpu_gan_estimator.py new file mode 100644 index 0000000000000000000000000000000000000000..deb381f7be3f9545ed918813ee55aede946f22d4 --- /dev/null +++ b/tensorflow/contrib/gan/python/estimator/python/tpu_gan_estimator.py @@ -0,0 +1,28 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""`tf.Learn` components for `TPUGANEstimator`.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.gan.python.estimator.python import tpu_gan_estimator_impl +# pylint: disable=wildcard-import +from tensorflow.contrib.gan.python.estimator.python.tpu_gan_estimator_impl import * +# pylint: enable=wildcard-import +from tensorflow.python.util.all_util import remove_undocumented + +__all__ = tpu_gan_estimator_impl.__all__ +remove_undocumented(__name__, __all__) diff --git a/tensorflow/contrib/gan/python/estimator/python/tpu_gan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/tpu_gan_estimator_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..8f2a22c78a304c7cc66ef069a235483e9279b3b2 --- /dev/null +++ b/tensorflow/contrib/gan/python/estimator/python/tpu_gan_estimator_impl.py @@ -0,0 +1,423 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A TF-GAN-backed GAN Estimator that works on TPU.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples +from tensorflow.contrib.gan.python import train as tfgan_train +from tensorflow.contrib.gan.python.estimator.python import gan_estimator_impl as gan_estimator_lib +from tensorflow.contrib.tpu.python.tpu import tpu_estimator +from tensorflow.contrib.tpu.python.tpu import tpu_optimizer +from tensorflow.contrib.training.python.training import training +from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import metrics as metrics_lib +from tensorflow.python.ops.losses import losses + +__all__ = [ + 'TPUGANEstimator', +] + + +class TPUGANEstimator(tpu_estimator.TPUEstimator): + """An estimator for Generative Adversarial Networks (GANs) on TPU. + + This Estimator is backed by TFGAN. It is similar to `tfgan.GANEstimator`, + but works on TPU. + + Example: + + ```python + import tensorflow as tf + tfgan = tf.contrib.gan + + # See TFGAN's `train.py` for a description of the generator and + # discriminator API. + def generator_fn(generator_inputs): + ... + return generated_data + + def discriminator_fn(data, conditioning): + ... + return logits + + # Create GAN estimator. 
+ config = tpu_config.RunConfig(model_dir='/my/dir') + gan_estimator = tfgan.estimator.TPUGANEstimator( + generator_fn=generator_fn, + discriminator_fn=discriminator_fn, + generator_loss_fn=tfgan.losses.wasserstein_generator_loss, + discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss, + generator_optimizer=tf.train.AdamOptimizer(0.1, 0.5), + discriminator_optimizer=tf.train.AdamOptimizer(0.1, 0.5), + train_batch_size=4, + config=config) + + # Train estimator. + gan_estimator.train(train_input_fn, train_steps) + + # Evaluate resulting estimator. + gan_estimator.evaluate(eval_input_fn, eval_steps) + + # Generate samples from generator. + predictions = np.array([ + x['generated_data'] for x in gan_estimator.predict(predict_input_fn)]) + ``` + """ + + def __init__(self, + # Arguments to construct the `model_fn`. + generator_fn=None, + discriminator_fn=None, + generator_loss_fn=None, + discriminator_loss_fn=None, + generator_optimizer=None, + discriminator_optimizer=None, + get_eval_metric_ops_fn=None, + add_summaries=None, + joint_train=False, + gan_train_steps=tfgan_tuples.GANTrainSteps(1, 1), + # TPUEstimator options. + model_dir=None, + config=None, + params=None, + use_tpu=True, + train_batch_size=None, + eval_batch_size=None, + predict_batch_size=None, + batch_axis=None, + eval_on_tpu=True, + export_to_tpu=True, + warm_start_from=None): + """Initializes a TPUGANEstimator instance. + + Args: + generator_fn: A python function that takes a Tensor, Tensor list, or + Tensor dictionary as inputs and returns the outputs of the GAN + generator. See `TFGAN` for more details and examples. Additionally, if + it has an argument called `mode`, the Estimator's `mode` will be passed + in (ex TRAIN, EVAL, PREDICT). This is useful for things like batch + normalization. + discriminator_fn: A python function that takes the output of + `generator_fn` or real data in the GAN setup, and `generator_inputs`. + Outputs a Tensor in the range [-inf, inf]. 
See `TFGAN` for more details + and examples. + generator_loss_fn: The loss function on the generator. Takes a `GANModel` + tuple. + discriminator_loss_fn: The loss function on the discriminator. Takes a + `GANModel` tuple. + generator_optimizer: The optimizer for generator updates, or a function + that takes no arguments and returns an optimizer. This function will + be called when the default graph is the `GANEstimator`'s graph, so + utilities like `tf.contrib.framework.get_or_create_global_step` will + work. + discriminator_optimizer: Same as `generator_optimizer`, but for the + discriminator updates. + get_eval_metric_ops_fn: A function that takes a list of arguments and + returns a dict of metric results keyed by name. The output of this + function is passed into `tf.estimator.EstimatorSpec` during evaluation. + The arguments must be: + * generator_inputs + * generated_data + * real_data + * discriminator_real_outputs + * discriminator_gen_outputs + add_summaries: `None`, a single `SummaryType`, or a list of `SummaryType`. + This is ignored for jobs that run on TPU, such as the train job if + `use_tpu` is `True` or the eval job if `eval_on_tpu` is `True`. + joint_train: A Python boolean. If `True`, jointly train the generator and + the discriminator. If `False`, sequentially train them. See `train.py` + in TFGAN for more details on the differences between the two GAN + training methods. + gan_train_steps: A `tfgan.GANTrainSteps` named tuple describing the ratio + of generator to discriminator steps. For now, only supports 1:1 + training. + model_dir: Same as `TPUEstimator`: Directory to save model parameters, + graph and etc. This can also be used to load checkpoints from the + directory into a estimator to continue training a previously saved + model. If `None`, the model_dir in `config` will be used if set. If both + are set, they must be same. If both are `None`, a temporary directory + will be used. 
+ config: Same as `TPUEstimator`: An `tpu_config.RunConfig` configuration + object. Cannot be `None`. + params: Same as `TPUEstimator`: An optional `dict` of hyper parameters + that will be passed into `input_fn` and `model_fn`. Keys are names of + parameters, values are basic python types. There are reserved keys for + `TPUEstimator`, including 'batch_size'. + use_tpu: Same as `TPUEstimator`: A bool indicating whether TPU support is + enabled. Currently, TPU training and evaluation respect this bit, but + eval_on_tpu can override execution of eval. See below. Predict still + happens on CPU. + train_batch_size: Same as `TPUEstimator`: An int representing the global + training batch size. TPUEstimator transforms this global batch size to a + per-shard batch size, as params['batch_size'], when calling `input_fn` + and `model_fn`. Cannot be `None` if `use_tpu` is `True`. Must be + divisible by total number of replicas. + eval_batch_size: Same as `TPUEstimator`: An int representing evaluation + batch size. Must be divisible by total number of replicas. + predict_batch_size: Same as `TPUEstimator`: An int representing the + prediction batch size. Must be divisible by total number of replicas. + batch_axis: Same as `TPUEstimator`: A python tuple of int values + describing how each tensor produced by the Estimator `input_fn` should + be split across the TPU compute shards. For example, if your input_fn + produced (images, labels) where the images tensor is in `HWCN` format, + your shard dimensions would be [3, 0], where 3 corresponds to the `N` + dimension of your images Tensor, and 0 corresponds to the dimension + along which to split the labels to match up with the corresponding + images. If None is supplied, and per_host_input_for_training is True, + batches will be sharded based on the major dimension. If + tpu_config.per_host_input_for_training is False or `PER_HOST_V2`, + batch_axis is ignored. 
+ eval_on_tpu: Same as `TPUEstimator`: If False, evaluation runs on CPU or + GPU. In this case, the model_fn must return `EstimatorSpec` when called + with `mode` as `EVAL`. + export_to_tpu: Same as `TPUEstimator`: If True, `export_savedmodel()` + exports a metagraph for serving on TPU besides the one on CPU. + warm_start_from: Same as `TPUEstimator`: Optional string filepath to a + checkpoint or SavedModel to warm-start from, or a + `tf.estimator.WarmStartSettings` object to fully configure + warm-starting. If the string filepath is provided instead of a + `WarmStartSettings`, then all variables are warm-started, and it is + assumed that vocabularies and Tensor names are unchanged. + + Raises: + ValueError: If loss functions aren't callable. + ValueError: If `gan_train_steps` isn't a `tfgan_tuples.GANTrainSteps` + tuple. + ValueError: If `gan_train_steps` isn't 1:1 training. + """ + if not callable(generator_loss_fn): + raise ValueError('generator_loss_fn must be callable.') + if not callable(discriminator_loss_fn): + raise ValueError('discriminator_loss_fn must be callable.') + if not isinstance(gan_train_steps, tfgan_tuples.GANTrainSteps): + raise ValueError( + '`gan_train_steps` must be `tfgan_tuples.GANTrainSteps`. 
Instead, ' + 'was type: %s' % type(gan_train_steps)) + if (gan_train_steps.generator_train_steps != 1 or + gan_train_steps.discriminator_train_steps != 1): + raise ValueError('Estimator currently only supports 1:1 training.') + + if use_tpu: + generator_optimizer = _maybe_make_cross_shard_optimizer( + generator_optimizer) + discriminator_optimizer = _maybe_make_cross_shard_optimizer( + discriminator_optimizer) + + def _model_fn(features, labels, mode, params): + """GANEstimator model function.""" + del params # unused + if mode not in [model_fn_lib.ModeKeys.TRAIN, model_fn_lib.ModeKeys.EVAL, + model_fn_lib.ModeKeys.PREDICT]: + raise ValueError('Mode not recognized: %s' % mode) + real_data = labels # rename inputs for clarity + generator_inputs = features # rename inputs for clarity + + # Make GANModel, which encapsulates the GAN model architectures. + # TODO(joelshor): Switch TF-GAN over to TPU-compatible summaries, then + # remove `add_summaries` logic below. + is_on_tpu = _is_on_tpu(mode, use_tpu, eval_on_tpu) + gan_model = gan_estimator_lib._get_gan_model( # pylint:disable=protected-access + mode, generator_fn, discriminator_fn, real_data, generator_inputs, + add_summaries=None if is_on_tpu else add_summaries) + + # Make the TPUEstimatorSpec, which incorporates the GANModel, losses, eval + # metrics, and optimizers (if required). 
+ estimator_spec = _get_estimator_spec( + mode, gan_model, generator_loss_fn, discriminator_loss_fn, + get_eval_metric_ops_fn, generator_optimizer, discriminator_optimizer, + joint_train, is_on_tpu, gan_train_steps) + assert isinstance(estimator_spec, tpu_estimator.TPUEstimatorSpec) + return estimator_spec + + super(TPUGANEstimator, self).__init__( + model_fn=_model_fn, + model_dir=model_dir, + config=config, + params=params, + use_tpu=use_tpu, + train_batch_size=train_batch_size, + eval_batch_size=eval_batch_size, + predict_batch_size=predict_batch_size, + batch_axis=batch_axis, + eval_on_tpu=eval_on_tpu, + export_to_tpu=export_to_tpu, + warm_start_from=warm_start_from) + + +def _is_on_tpu(mode, use_tpu, eval_on_tpu): + if mode == model_fn_lib.ModeKeys.TRAIN: + return use_tpu + elif mode == model_fn_lib.ModeKeys.EVAL: + return eval_on_tpu + else: + return False + + +def _get_estimator_spec( + mode, gan_model, generator_loss_fn, discriminator_loss_fn, + get_eval_metric_ops_fn, generator_optimizer, discriminator_optimizer, + joint_train, is_on_tpu, gan_train_steps): + """Get the TPUEstimatorSpec for the current mode.""" + if mode == model_fn_lib.ModeKeys.PREDICT: + estimator_spec = tpu_estimator.TPUEstimatorSpec( + mode=mode, predictions={'generated_data': gan_model.generated_data}) + elif mode == model_fn_lib.ModeKeys.EVAL: + gan_loss = tfgan_tuples.GANLoss( + generator_loss=generator_loss_fn( + gan_model, add_summaries=not is_on_tpu), + discriminator_loss=discriminator_loss_fn( + gan_model, add_summaries=not is_on_tpu)) + # Eval losses for metrics must preserve batch dimension. 
+ gan_loss_no_reduction = tfgan_tuples.GANLoss( + generator_loss=generator_loss_fn( + gan_model, add_summaries=False, reduction=losses.Reduction.NONE), + discriminator_loss=discriminator_loss_fn( + gan_model, add_summaries=False, reduction=losses.Reduction.NONE)) + estimator_spec = _get_eval_estimator_spec( + gan_model, gan_loss, gan_loss_no_reduction, get_eval_metric_ops_fn) + else: # model_fn_lib.ModeKeys.TRAIN: + gan_loss = tfgan_tuples.GANLoss( + generator_loss=generator_loss_fn( + gan_model, add_summaries=not is_on_tpu), + discriminator_loss=discriminator_loss_fn( + gan_model, add_summaries=not is_on_tpu)) + + # Construct optimizers if arguments were callable. For TPUs, they must be + # `CrossShardOptimizer`. + g_callable = callable(generator_optimizer) + gopt = generator_optimizer() if g_callable else generator_optimizer + d_callable = callable(discriminator_optimizer) + dopt = discriminator_optimizer() if d_callable else discriminator_optimizer + + estimator_spec = _get_train_estimator_spec( + gan_model, gan_loss, gopt, dopt, joint_train, gan_train_steps) + + return estimator_spec + + +def _get_eval_estimator_spec(gan_model, gan_loss, gan_loss_no_reduction, + get_eval_metric_ops_fn): + """Return an TPUEstimatorSpec for the eval case.""" + # Make the metric function and tensor names. 
+ if get_eval_metric_ops_fn is not None: + def metric_fn( + generator_inputs, generated_data, real_data, discriminator_real_outputs, + discriminator_gen_outputs, generator_loss, discriminator_loss): + """`metric_fn` used in TPUEstimator to calculate metrics.""" + eval_metric_ops = { + 'generator_loss': metrics_lib.mean(generator_loss), + 'discriminator_loss': metrics_lib.mean(discriminator_loss), + } + custom_eval_metric_ops = get_eval_metric_ops_fn( + generator_inputs, generated_data, real_data, + discriminator_real_outputs, discriminator_gen_outputs) + if not isinstance(custom_eval_metric_ops, dict): + raise TypeError('`get_eval_metric_ops_fn` must return a dict, ' + 'received: {}'.format(custom_eval_metric_ops)) + eval_metric_ops.update(custom_eval_metric_ops) + return eval_metric_ops + tensors = { + 'generator_loss': gan_loss_no_reduction.generator_loss, + 'discriminator_loss': gan_loss_no_reduction.discriminator_loss, + 'generator_inputs': gan_model.generator_inputs, + 'generated_data': gan_model.generated_data, + 'real_data': gan_model.real_data, + 'discriminator_real_outputs': gan_model.discriminator_real_outputs, + 'discriminator_gen_outputs': gan_model.discriminator_gen_outputs, + } + else: + def metric_fn(generator_loss, discriminator_loss): + return { + 'generator_loss': metrics_lib.mean(generator_loss), + 'discriminator_loss': metrics_lib.mean(discriminator_loss), + } + tensors = { + 'generator_loss': gan_loss_no_reduction.generator_loss, + 'discriminator_loss': gan_loss_no_reduction.discriminator_loss, + } + + scalar_loss = gan_loss.generator_loss + gan_loss.discriminator_loss + return tpu_estimator.TPUEstimatorSpec( + mode=model_fn_lib.ModeKeys.EVAL, + predictions=gan_model.generated_data, + loss=scalar_loss, + eval_metrics=(metric_fn, tensors)) + + +def _get_train_estimator_spec( + gan_model, gan_loss, generator_optimizer, discriminator_optimizer, + joint_train, gan_train_steps): + """Return a TPUEstimatorSpec for the train case.""" + scalar_loss = 
gan_loss.generator_loss + gan_loss.discriminator_loss + + # Get generator and discriminator update ops. We split them so that update + # ops aren't accidentally run multiple times. For now, throw an error if + # there are update ops that aren't associated with either the generator or + # the discriminator. Might modify the `kwargs` dictionary. + gen_update_ops, dis_update_ops = tfgan_train._get_update_ops( # pylint:disable=protected-access + {}, gan_model.generator_scope.name, gan_model.discriminator_scope.name) + + def gen_train_op(): + with ops.name_scope('generator_train'): + return training.create_train_op( + total_loss=gan_loss.generator_loss, + optimizer=generator_optimizer, + variables_to_train=gan_model.generator_variables, + update_ops=gen_update_ops) + def dis_train_op(): + with ops.name_scope('discriminator_train'): + return training.create_train_op( + total_loss=gan_loss.discriminator_loss, + optimizer=discriminator_optimizer, + variables_to_train=gan_model.discriminator_variables, + update_ops=dis_update_ops) + + # Either optimize the generator and discriminator sequentially or jointly. + tpu_train_op = _combine_train_ops(gen_train_op, dis_train_op, joint_train, + gan_train_steps) + + return tpu_estimator.TPUEstimatorSpec( + loss=scalar_loss, + mode=model_fn_lib.ModeKeys.TRAIN, + train_op=tpu_train_op) + + +# TODO(joelshor): Add support for multiple D / G steps. 
+def _combine_train_ops(gen_train_op, dis_train_op, joint_train, + gan_train_steps): + """Combine generator and discriminator train ops into a single op.""" + del gan_train_steps + if joint_train: + tpu_train_op = control_flow_ops.group(gen_train_op(), dis_train_op(), + name='joint_train') + else: + with ops.control_dependencies([dis_train_op()]): + tpu_train_op = gen_train_op() + + return tpu_train_op + + +def _maybe_make_cross_shard_optimizer(opt): + if callable(opt): + if not isinstance(opt(), tpu_optimizer.CrossShardOptimizer): + return lambda: tpu_optimizer.CrossShardOptimizer(opt()) + elif not isinstance(opt, tpu_optimizer.CrossShardOptimizer): + return tpu_optimizer.CrossShardOptimizer(opt) + return opt diff --git a/tensorflow/contrib/gan/python/estimator/python/tpu_gan_estimator_test.py b/tensorflow/contrib/gan/python/estimator/python/tpu_gan_estimator_test.py new file mode 100644 index 0000000000000000000000000000000000000000..0d9e6489bdd1d89cc49bfedc2eed784999c31d2b --- /dev/null +++ b/tensorflow/contrib/gan/python/estimator/python/tpu_gan_estimator_test.py @@ -0,0 +1,319 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for TF-GAN's TPU Estimator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import shutil +import tempfile + +from absl.testing import parameterized +import numpy as np +import six + +from tensorflow.contrib import layers +from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples +from tensorflow.contrib.gan.python.estimator.python import tpu_gan_estimator_impl as estimator +from tensorflow.contrib.gan.python.losses.python import tuple_losses as losses +from tensorflow.contrib.tpu.python.tpu import tpu_config +from tensorflow.contrib.tpu.python.tpu import tpu_estimator +from tensorflow.contrib.tpu.python.tpu import tpu_optimizer +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.estimator.estimator import WarmStartSettings +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework.errors_impl import NotFoundError +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import metrics as metrics_lib +from tensorflow.python.ops import variable_scope +from tensorflow.python.platform import flags +from tensorflow.python.platform import test +from tensorflow.python.summary.writer import writer_cache +from tensorflow.python.training import learning_rate_decay +from tensorflow.python.training import training +from tensorflow.python.training import training_util + +FLAGS = flags.FLAGS + +flags.DEFINE_bool('use_tpu', False, 'Whether to run test on TPU or not.') + + +def generator_fn(noise, mode): + del mode + return layers.fully_connected(noise, tensor_shape.dimension_value( + noise.shape[1])) + + +def discriminator_fn(data, unused_conditioning, mode): + del unused_conditioning, mode + return layers.fully_connected(data, 1) 
+ + +def get_dummy_gan_model(): + # TODO(joelshor): Find a better way of creating a variable scope. + with variable_scope.variable_scope('generator') as gen_scope: + gen_var = variable_scope.get_variable('dummy_var', initializer=0.0) + with variable_scope.variable_scope('discriminator') as dis_scope: + dis_var = variable_scope.get_variable('dummy_var', initializer=0.0) + return tfgan_tuples.GANModel( + generator_inputs=None, + generated_data=array_ops.ones([3, 4]), + generator_variables=[gen_var], + generator_scope=gen_scope, + generator_fn=None, + real_data=array_ops.zeros([3, 4]), + discriminator_real_outputs=array_ops.ones([1, 2, 3]) * dis_var, + discriminator_gen_outputs=array_ops.ones([1, 2, 3]) * gen_var * dis_var, + discriminator_variables=[dis_var], + discriminator_scope=dis_scope, + discriminator_fn=None) + + +def get_metrics(generator_inputs, generated_data, real_data, + discriminator_real_outputs, discriminator_gen_outputs): + del generator_inputs, discriminator_real_outputs, discriminator_gen_outputs + return { + 'mse_custom_metric': metrics_lib.mean_squared_error( + real_data, generated_data) + } + + +class GetTPUEstimatorSpecTest(test.TestCase, parameterized.TestCase): + """Tests that the EstimatorSpec is constructed appropriately.""" + + @classmethod + def setUpClass(cls): + super(GetTPUEstimatorSpecTest, cls).setUpClass() + cls._generator_optimizer = tpu_optimizer.CrossShardOptimizer( + training.GradientDescentOptimizer(1.0)) + cls._discriminator_optimizer = tpu_optimizer.CrossShardOptimizer( + training.GradientDescentOptimizer(1.0)) + + @parameterized.named_parameters( + ('joint_train', model_fn_lib.ModeKeys.TRAIN, True), + ('train_sequential', model_fn_lib.ModeKeys.TRAIN, False), + ('eval', model_fn_lib.ModeKeys.EVAL, None), + ('predict', model_fn_lib.ModeKeys.PREDICT, None)) + def test_get_estimator_spec(self, mode, joint_train): + with ops.Graph().as_default(): + self._gan_model = get_dummy_gan_model() + spec = estimator._get_estimator_spec( + 
mode, + self._gan_model, + generator_loss_fn=losses.wasserstein_generator_loss, + discriminator_loss_fn=losses.wasserstein_discriminator_loss, + get_eval_metric_ops_fn=get_metrics, + generator_optimizer=self._generator_optimizer, + discriminator_optimizer=self._discriminator_optimizer, + joint_train=joint_train, + is_on_tpu=FLAGS.use_tpu, + gan_train_steps=tfgan_tuples.GANTrainSteps(1, 1)) + + self.assertIsInstance(spec, tpu_estimator.TPUEstimatorSpec) + self.assertEqual(mode, spec.mode) + if mode == model_fn_lib.ModeKeys.PREDICT: + self.assertEqual({'generated_data': self._gan_model.generated_data}, + spec.predictions) + elif mode == model_fn_lib.ModeKeys.TRAIN: + self.assertShapeEqual(np.array(0), spec.loss) # must be a scalar + self.assertIsNotNone(spec.train_op) + self.assertIsNotNone(spec.training_hooks) + elif mode == model_fn_lib.ModeKeys.EVAL: + self.assertEqual(self._gan_model.generated_data, spec.predictions) + self.assertShapeEqual(np.array(0), spec.loss) # must be a scalar + self.assertIsNotNone(spec.eval_metrics) + + +class TPUGANEstimatorIntegrationTest(test.TestCase, parameterized.TestCase): + + def setUp(self): + super(TPUGANEstimatorIntegrationTest, self).setUp() + self._model_dir = tempfile.mkdtemp() + self._config = tpu_config.RunConfig(model_dir=self._model_dir) + + def tearDown(self): + super(TPUGANEstimatorIntegrationTest, self).tearDown() + if self._model_dir: + writer_cache.FileWriterCache.clear() + shutil.rmtree(self._model_dir) + + def _test_complete_flow( + self, train_input_fn, eval_input_fn, predict_input_fn, prediction_size, + lr_decay=False, joint_train=True): + def make_opt(): + gstep = training_util.get_or_create_global_step() + lr = learning_rate_decay.exponential_decay(1.0, gstep, 10, 0.9) + return training.GradientDescentOptimizer(lr) + + gopt = make_opt if lr_decay else training.GradientDescentOptimizer(1.0) + dopt = make_opt if lr_decay else training.GradientDescentOptimizer(1.0) + est = estimator.TPUGANEstimator( + 
generator_fn=generator_fn, + discriminator_fn=discriminator_fn, + generator_loss_fn=losses.wasserstein_generator_loss, + discriminator_loss_fn=losses.wasserstein_discriminator_loss, + generator_optimizer=gopt, + discriminator_optimizer=dopt, + joint_train=joint_train, + get_eval_metric_ops_fn=get_metrics, + train_batch_size=4, + eval_batch_size=10, + predict_batch_size=8, + use_tpu=FLAGS.use_tpu, + config=self._config) + + # Train. + num_steps_train = 10 + est.train(train_input_fn, steps=num_steps_train) + + # Evaluate. + num_steps_eval = 2 + scores = est.evaluate(eval_input_fn, steps=num_steps_eval) + self.assertIn(ops.GraphKeys.GLOBAL_STEP, six.iterkeys(scores)) + self.assertIn('loss', six.iterkeys(scores)) + self.assertEqual(scores['discriminator_loss'] + scores['generator_loss'], + scores['loss']) + self.assertIn('mse_custom_metric', six.iterkeys(scores)) + + # Predict. + predictions = np.array([x['generated_data'] for x in + est.predict(predict_input_fn)]) + self.assertAllEqual(prediction_size, predictions.shape) + + @parameterized.named_parameters( + ('joint_train', True, False, False), + ('train_sequential', False, False, False), + ('lr_decay', False, True, False), + ('train_sequential_ds', False, False, True)) + def test_numpy_input_fn(self, joint_train, lr_decay, return_ds): + """Tests complete flow with numpy_input_fn.""" + input_dim = 4 + def train_input_fn(params): + data = np.zeros([input_dim], dtype=np.float32) + ds = (dataset_ops.Dataset + .from_tensors((data, data)) + .repeat() + .batch(params['batch_size'], drop_remainder=True)) + if return_ds: + return ds + else: + x, y = ds.make_one_shot_iterator().get_next() + return x, y + def eval_input_fn(params): + data = np.zeros([input_dim], dtype=np.float32) + ds = (dataset_ops.Dataset + .from_tensors((data, data)) + .repeat() + .batch(params['batch_size'], drop_remainder=True)) + if return_ds: + return ds + else: + x, y = ds.make_one_shot_iterator().get_next() + return x, y + predict_size = 10 + def 
predict_input_fn(params): + del params # unused + data = np.zeros([input_dim], dtype=np.float32) + ds = (dataset_ops.Dataset + .from_tensors(data) + .repeat(predict_size) + .batch(1, drop_remainder=True)) + return ds + + self._test_complete_flow( + train_input_fn=train_input_fn, + eval_input_fn=eval_input_fn, + predict_input_fn=predict_input_fn, + prediction_size=[predict_size, input_dim], + lr_decay=lr_decay, + joint_train=joint_train) + + +class TPUGANEstimatorWarmStartTest(test.TestCase): + + def setUp(self): + self._model_dir = self.get_temp_dir() + self._config = tpu_config.RunConfig(model_dir=self._model_dir) + self.new_variable_name = 'new_var' + self.new_variable_value = [1.0, 2.0, 3.0] + + def tearDown(self): + writer_cache.FileWriterCache.clear() + + def _test_warm_start(self, warm_start_from=None): + """Tests whether WarmStartSettings work as intended.""" + def generator_with_new_variable(noise_dict, mode): + variable_scope.get_variable(name=self.new_variable_name, + initializer=self.new_variable_value, + trainable=True) + return generator_fn(noise_dict, mode) + + est = estimator.TPUGANEstimator( + generator_fn=generator_fn, + discriminator_fn=discriminator_fn, + generator_loss_fn=losses.wasserstein_generator_loss, + discriminator_loss_fn=losses.wasserstein_discriminator_loss, + generator_optimizer=training.GradientDescentOptimizer(1.0), + discriminator_optimizer=training.GradientDescentOptimizer(1.0), + train_batch_size=4, + use_tpu=FLAGS.use_tpu, + config=self._config) + + def train_input_fn(params): + data = np.zeros([params['batch_size'], 4], dtype=np.float32) + return data, data + + est.train(train_input_fn, steps=1) + + est_warm = estimator.TPUGANEstimator( + generator_fn=generator_with_new_variable, + discriminator_fn=discriminator_fn, + generator_loss_fn=losses.wasserstein_generator_loss, + discriminator_loss_fn=losses.wasserstein_discriminator_loss, + generator_optimizer=training.GradientDescentOptimizer(1.0), + 
discriminator_optimizer=training.GradientDescentOptimizer(1.0), + config=tpu_config.RunConfig( + model_dir=None if warm_start_from else self._model_dir), + train_batch_size=4, + use_tpu=FLAGS.use_tpu, + warm_start_from=warm_start_from) + + est_warm.train(train_input_fn, steps=1) + + return est_warm + + def test_warm_start_error(self): + """Test if exception when reloading different estimators.""" + with self.assertRaises(NotFoundError): + self._test_warm_start() + + def test_warm_start_success(self): + """Test if GANEstimator allows explicit warm start variable assignment.""" + # Regex matches all variable names in ckpt except for new_var. + var_regex = '^(?!.*%s.*)' % self.new_variable_name + warmstart = WarmStartSettings(ckpt_to_initialize_from=self._model_dir, + vars_to_warm_start=var_regex) + est_warm = self._test_warm_start(warm_start_from=warmstart) + full_variable_name = 'Generator/%s' % self.new_variable_name + self.assertIn(full_variable_name, est_warm.get_variable_names()) + equal_vals = np.array_equal(est_warm.get_variable_value(full_variable_name), + self.new_variable_value) + self.assertTrue(equal_vals) + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/gan/python/eval/__init__.py b/tensorflow/contrib/gan/python/eval/__init__.py index f86b8513053a45f9830411f7df2c32d1f36a97b2..92e9abf8a35de1999eb800e169f32220fe47f8cd 100644 --- a/tensorflow/contrib/gan/python/eval/__init__.py +++ b/tensorflow/contrib/gan/python/eval/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""TFGAN evaluation module. +"""TF-GAN evaluation module. This module supports techniques such as Inception Score, Frechet Inception distance, and Sliced Wasserstein distance. 
diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics.py index 1c872626a957279132772ae27df7a66a2564e9a5..a52e899114b62cb29752f72aa59f142f4a428aa1 100644 --- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics.py +++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Model evaluation tools for TFGAN.""" +"""Model evaluation tools for TF-GAN.""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py index a71ee53311c1c057a5b41be0331bf56ce1a82f74..ff19ce2f78e9c86400089e454c88450f01c41764 100644 --- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py +++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Model evaluation tools for TFGAN. +"""Model evaluation tools for TF-GAN. These methods come from https://arxiv.org/abs/1606.03498, https://arxiv.org/abs/1706.08500, and https://arxiv.org/abs/1801.01401. 
@@ -41,9 +41,9 @@ from tensorflow.python.framework import importer from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import functional_ops from tensorflow.python.ops import image_ops from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import map_fn from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_impl from tensorflow.python.ops import nn_ops @@ -346,7 +346,7 @@ def classifier_score(images, classifier_fn, num_batches=1): images, num_or_size_splits=num_batches) # Compute the classifier splits using the memory-efficient `map_fn`. - logits = functional_ops.map_fn( + logits = map_fn.map_fn( fn=classifier_fn, elems=array_ops.stack(generated_images_list), parallel_iterations=1, @@ -387,7 +387,7 @@ def classifier_score_from_logits(logits): # Use maximum precision for best results. logits_dtype = logits.dtype if logits_dtype != dtypes.float64: - logits = math_ops.to_double(logits) + logits = math_ops.cast(logits, dtypes.float64) p = nn_ops.softmax(logits) q = math_ops.reduce_mean(p, axis=0) @@ -505,12 +505,12 @@ def frechet_classifier_distance(real_images, # Compute the activations using the memory-efficient `map_fn`. 
def compute_activations(elems): - return functional_ops.map_fn(fn=classifier_fn, - elems=elems, - parallel_iterations=1, - back_prop=False, - swap_memory=True, - name='RunClassifier') + return map_fn.map_fn(fn=classifier_fn, + elems=elems, + parallel_iterations=1, + back_prop=False, + swap_memory=True, + name='RunClassifier') real_a = compute_activations(real_imgs) gen_a = compute_activations(generated_imgs) @@ -562,8 +562,8 @@ def mean_only_frechet_classifier_distance_from_activations( activations_dtype = real_activations.dtype if activations_dtype != dtypes.float64: - real_activations = math_ops.to_double(real_activations) - generated_activations = math_ops.to_double(generated_activations) + real_activations = math_ops.cast(real_activations, dtypes.float64) + generated_activations = math_ops.cast(generated_activations, dtypes.float64) # Compute means of activations. m = math_ops.reduce_mean(real_activations, 0) @@ -623,8 +623,8 @@ def diagonal_only_frechet_classifier_distance_from_activations( activations_dtype = real_activations.dtype if activations_dtype != dtypes.float64: - real_activations = math_ops.to_double(real_activations) - generated_activations = math_ops.to_double(generated_activations) + real_activations = math_ops.cast(real_activations, dtypes.float64) + generated_activations = math_ops.cast(generated_activations, dtypes.float64) # Compute mean and covariance matrices of activations. m, var = nn_impl.moments(real_activations, axes=[0]) @@ -698,15 +698,16 @@ def frechet_classifier_distance_from_activations(real_activations, activations_dtype = real_activations.dtype if activations_dtype != dtypes.float64: - real_activations = math_ops.to_double(real_activations) - generated_activations = math_ops.to_double(generated_activations) + real_activations = math_ops.cast(real_activations, dtypes.float64) + generated_activations = math_ops.cast(generated_activations, dtypes.float64) # Compute mean and covariance matrices of activations. 
m = math_ops.reduce_mean(real_activations, 0) m_w = math_ops.reduce_mean(generated_activations, 0) - num_examples_real = math_ops.to_double(array_ops.shape(real_activations)[0]) - num_examples_generated = math_ops.to_double( - array_ops.shape(generated_activations)[0]) + num_examples_real = math_ops.cast( + array_ops.shape(real_activations)[0], dtypes.float64) + num_examples_generated = math_ops.cast( + array_ops.shape(generated_activations)[0], dtypes.float64) # sigma = (1 / (n - 1)) * (X - mu) (X - mu)^T real_centered = real_activations - m @@ -794,9 +795,9 @@ def kernel_classifier_distance(real_images, on a classifier. num_classifier_batches: Number of batches to split images in to in order to efficiently run them through the classifier network. - max_estimator_block_size: integer, default 1024. The distance estimator - splits samples into blocks for computational efficiency. Larger values are - more computationally expensive but decrease the variance of the distance + max_block_size: integer, default 1024. The distance estimator splits samples + into blocks for computational efficiency. Larger values are more + computationally expensive but decrease the variance of the distance estimate. dtype: if not None, coerce activations to this dtype before computations. @@ -871,9 +872,9 @@ def kernel_classifier_distance_and_std(real_images, on a classifier. num_classifier_batches: Number of batches to split images in to in order to efficiently run them through the classifier network. - max_estimator_block_size: integer, default 1024. The distance estimator - splits samples into blocks for computational efficiency. Larger values are - more computationally expensive but decrease the variance of the distance + max_block_size: integer, default 1024. The distance estimator splits samples + into blocks for computational efficiency. Larger values are more + computationally expensive but decrease the variance of the distance estimate. 
Having a smaller block size also gives a better estimate of the standard error. dtype: if not None, coerce activations to this dtype before computations. @@ -894,7 +895,7 @@ def kernel_classifier_distance_and_std(real_images, # Compute the activations using the memory-efficient `map_fn`. def compute_activations(elems): - return functional_ops.map_fn( + return map_fn.map_fn( fn=classifier_fn, elems=elems, parallel_iterations=1, @@ -910,7 +911,7 @@ def kernel_classifier_distance_and_std(real_images, gen_a = array_ops.concat(array_ops.unstack(gen_a), 0) return kernel_classifier_distance_and_std_from_activations( - real_a, gen_a, max_block_size=max_block_size) + real_a, gen_a, max_block_size, dtype) kernel_inception_distance_and_std = functools.partial( @@ -967,14 +968,14 @@ def kernel_classifier_distance_from_activations(real_activations, into blocks for computational efficiency. Larger values are more computationally expensive but decrease the variance of the distance estimate. - dtype: if not None, coerce activations to this dtype before computations. + dtype: If not None, coerce activations to this dtype before computations. Returns: The Kernel Inception Distance. A floating-point scalar of the same type as the output of the activations. """ return kernel_classifier_distance_and_std_from_activations( - real_activations, generated_activations, max_block_size=max_block_size)[0] + real_activations, generated_activations, max_block_size, dtype)[0] def kernel_classifier_distance_and_std_from_activations(real_activations, @@ -1029,7 +1030,7 @@ def kernel_classifier_distance_and_std_from_activations(real_activations, computationally expensive but decrease the variance of the distance estimate. Having a smaller block size also gives a better estimate of the standard error. - dtype: if not None, coerce activations to this dtype before computations. + dtype: If not None, coerce activations to this dtype before computations. Returns: The Kernel Inception Distance. 
A floating-point scalar of the same type @@ -1080,7 +1081,7 @@ def kernel_classifier_distance_and_std_from_activations(real_activations, dim = math_ops.cast(real_activations.shape[1], dtype) def compute_kid_block(i): - 'Compute the ith block of the KID estimate.' + """Computes the ith block of the KID estimate.""" r_s = inds_r[i] r_e = inds_r[i + 1] r = real_activations[r_s:r_e] @@ -1098,7 +1099,7 @@ def kernel_classifier_distance_and_std_from_activations(real_activations, (math_ops.reduce_sum(k_rr) - math_ops.trace(k_rr)) / (m * (m - 1)) + (math_ops.reduce_sum(k_gg) - math_ops.trace(k_gg)) / (n * (n - 1))) - ests = functional_ops.map_fn( + ests = map_fn.map_fn( compute_kid_block, math_ops.range(n_blocks), dtype=dtype, back_prop=False) mn = math_ops.reduce_mean(ests) diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py index dbff1d2a367e10adc607dafb4c571bb3607a3963..bd17571a0535a3c8e9dfee24a8da16eb2e72f165 100644 --- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py +++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for TFGAN classifier_metrics.""" +"""Tests for TF-GAN classifier_metrics.""" from __future__ import absolute_import from __future__ import division @@ -234,7 +234,7 @@ class ClassifierMetricsTest(test.TestCase, parameterized.TestCase): else: logits = classifier_metrics.run_inception(img, _get_dummy_graphdef()) - self.assertTrue(isinstance(logits, ops.Tensor)) + self.assertIsInstance(logits, ops.Tensor) logits.shape.assert_is_compatible_with([batch_size, 1001]) # Check that none of the model variables are trainable. 
@@ -258,7 +258,7 @@ class ClassifierMetricsTest(test.TestCase, parameterized.TestCase): img, _get_dummy_graphdef(), output_tensor=classifier_metrics.INCEPTION_FINAL_POOL) - self.assertTrue(isinstance(pool, ops.Tensor)) + self.assertIsInstance(pool, ops.Tensor) pool.shape.assert_is_compatible_with([batch_size, 2048]) # Check that none of the model variables are trainable. @@ -276,8 +276,8 @@ class ClassifierMetricsTest(test.TestCase, parameterized.TestCase): classifier_metrics.INCEPTION_FINAL_POOL ]) - self.assertTrue(isinstance(logits, ops.Tensor)) - self.assertTrue(isinstance(pool, ops.Tensor)) + self.assertIsInstance(logits, ops.Tensor) + self.assertIsInstance(pool, ops.Tensor) logits.shape.assert_is_compatible_with([batch_size, 1001]) pool.shape.assert_is_compatible_with([batch_size, 2048]) @@ -290,7 +290,7 @@ class ClassifierMetricsTest(test.TestCase, parameterized.TestCase): classifier_metrics.inception_score, array_ops.zeros([6, 299, 299, 3]), num_batches=3) - self.assertTrue(isinstance(score, ops.Tensor)) + self.assertIsInstance(score, ops.Tensor) score.shape.assert_has_rank(0) # Check that none of the model variables are trainable. @@ -302,7 +302,7 @@ class ClassifierMetricsTest(test.TestCase, parameterized.TestCase): distance = _run_with_mock( classifier_metrics.frechet_inception_distance, img, img) - self.assertTrue(isinstance(distance, ops.Tensor)) + self.assertIsInstance(distance, ops.Tensor) distance.shape.assert_has_rank(0) # Check that none of the model variables are trainable. @@ -314,7 +314,7 @@ class ClassifierMetricsTest(test.TestCase, parameterized.TestCase): distance = _run_with_mock(classifier_metrics.kernel_inception_distance, img, img) - self.assertTrue(isinstance(distance, ops.Tensor)) + self.assertIsInstance(distance, ops.Tensor) distance.shape.assert_has_rank(0) # Check that none of the model variables are trainable. 
diff --git a/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein.py b/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein.py index 523968bed91f1021ae629bf52c405cf5c2d7b917..326fcb3cdbf2eda66207f134cd2926f09a216a99 100644 --- a/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein.py +++ b/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Model evaluation tools for TFGAN.""" +"""Model evaluation tools for TF-GAN.""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/gan/python/eval/python/summaries.py b/tensorflow/contrib/gan/python/eval/python/summaries.py index ecfdb39499b1e824e02415c0db1de3157e4f3216..1b202dfc97304ddc7ced42d65366aaf419439392 100644 --- a/tensorflow/contrib/gan/python/eval/python/summaries.py +++ b/tensorflow/contrib/gan/python/eval/python/summaries.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Common TFGAN summaries.""" +"""Common TF-GAN summaries.""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/gan/python/eval/python/summaries_impl.py b/tensorflow/contrib/gan/python/eval/python/summaries_impl.py index f9995bb19d0d09eaf6fd96d039b0bba1d3a7055c..c7bbd65bbff41c25327733ae1f17a090fb69cb52 100644 --- a/tensorflow/contrib/gan/python/eval/python/summaries_impl.py +++ b/tensorflow/contrib/gan/python/eval/python/summaries_impl.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -"""Common TFGAN summaries.""" +"""Common TF-GAN summaries.""" from __future__ import absolute_import from __future__ import division @@ -22,7 +22,7 @@ from tensorflow.contrib.gan.python import namedtuples from tensorflow.contrib.gan.python.eval.python import eval_utils from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops -from tensorflow.python.ops import functional_ops +from tensorflow.python.ops import map_fn from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops.losses import util as loss_util @@ -261,7 +261,7 @@ def add_stargan_image_summaries(stargan_model, summary.image( 'stargan_image_generation', - functional_ops.map_fn( + map_fn.map_fn( _build_image, stargan_model.input_data[:num_images], parallel_iterations=num_images, diff --git a/tensorflow/contrib/gan/python/eval/python/summaries_test.py b/tensorflow/contrib/gan/python/eval/python/summaries_test.py index 54a6f8d4d9086ad7fc8db31032677628561e48e8..53fc7cb8ede698c2d8590c7fd3016a884cef9be9 100644 --- a/tensorflow/contrib/gan/python/eval/python/summaries_test.py +++ b/tensorflow/contrib/gan/python/eval/python/summaries_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -"""Tests for TFGAN summaries.""" +"""Tests for TF-GAN summaries.""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/gan/python/features/__init__.py b/tensorflow/contrib/gan/python/features/__init__.py index 4816daf760143af9f1502873b123ffad8e5ec8ce..410c3a02052cd3a07a36a0ba332a80b3c2705d89 100644 --- a/tensorflow/contrib/gan/python/features/__init__.py +++ b/tensorflow/contrib/gan/python/features/__init__.py @@ -27,11 +27,13 @@ from __future__ import print_function from tensorflow.contrib.gan.python.features.python import clip_weights from tensorflow.contrib.gan.python.features.python import conditioning_utils from tensorflow.contrib.gan.python.features.python import random_tensor_pool +from tensorflow.contrib.gan.python.features.python import spectral_normalization from tensorflow.contrib.gan.python.features.python import virtual_batchnorm from tensorflow.contrib.gan.python.features.python.clip_weights import * from tensorflow.contrib.gan.python.features.python.conditioning_utils import * from tensorflow.contrib.gan.python.features.python.random_tensor_pool import * +from tensorflow.contrib.gan.python.features.python.spectral_normalization import * from tensorflow.contrib.gan.python.features.python.virtual_batchnorm import * # pylint: enable=unused-import,wildcard-import @@ -40,5 +42,6 @@ from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = clip_weights.__all__ _allowed_symbols += conditioning_utils.__all__ _allowed_symbols += random_tensor_pool.__all__ +_allowed_symbols += spectral_normalization.__all__ _allowed_symbols += virtual_batchnorm.__all__ remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/gan/python/features/python/conditioning_utils_impl.py b/tensorflow/contrib/gan/python/features/python/conditioning_utils_impl.py index 
e2594faf85bcf91cbe09f266e4d4211d20bdee17..364fa4eb461c62784803f0c309e3b7c5855df199 100644 --- a/tensorflow/contrib/gan/python/features/python/conditioning_utils_impl.py +++ b/tensorflow/contrib/gan/python/features/python/conditioning_utils_impl.py @@ -64,6 +64,9 @@ def condition_tensor(tensor, conditioning): """ tensor.shape[1:].assert_is_fully_defined() num_features = tensor.shape[1:].num_elements() + if conditioning.shape.ndims < 2: + raise ValueError('conditioning must be at least 2D, but saw shape: %s' + % conditioning.shape) mapped_conditioning = layers.linear( layers.flatten(conditioning), num_features) diff --git a/tensorflow/contrib/gan/python/features/python/conditioning_utils_test.py b/tensorflow/contrib/gan/python/features/python/conditioning_utils_test.py index 0aad769793761be69ee9d1e3416e44c7b3d8cea0..f5c7d53cf2c9aa08ba0074950983ef3ecd90168b 100644 --- a/tensorflow/contrib/gan/python/features/python/conditioning_utils_test.py +++ b/tensorflow/contrib/gan/python/features/python/conditioning_utils_test.py @@ -45,7 +45,7 @@ class ConditioningUtilsTest(test.TestCase): array_ops.placeholder(dtypes.float32, (5, None)), array_ops.placeholder(dtypes.float32, (5, 1))) - with self.assertRaisesRegexp(ValueError, 'expected min_ndim=2'): + with self.assertRaisesRegexp(ValueError, 'at least 2D'): conditioning_utils.condition_tensor( array_ops.placeholder(dtypes.float32, (5, 2)), array_ops.placeholder(dtypes.float32, (5))) diff --git a/tensorflow/contrib/gan/python/features/python/spectral_normalization.py b/tensorflow/contrib/gan/python/features/python/spectral_normalization.py new file mode 100644 index 0000000000000000000000000000000000000000..54d3d0a218dec3588844333cd47e1f92489d8df9 --- /dev/null +++ b/tensorflow/contrib/gan/python/features/python/spectral_normalization.py @@ -0,0 +1,32 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Keras-like layers and utilities that implement Spectral Normalization. + +Based on "Spectral Normalization for Generative Adversarial Networks" by Miyato, +et al in ICLR 2018. https://openreview.net/pdf?id=B1QRgziT- +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.gan.python.features.python import spectral_normalization_impl +# pylint: disable=wildcard-import +from tensorflow.contrib.gan.python.features.python.spectral_normalization_impl import * +# pylint: enable=wildcard-import +from tensorflow.python.util.all_util import remove_undocumented + +__all__ = spectral_normalization_impl.__all__ +remove_undocumented(__name__, __all__) diff --git a/tensorflow/contrib/gan/python/features/python/spectral_normalization_impl.py b/tensorflow/contrib/gan/python/features/python/spectral_normalization_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..0cc653f0a7907f407e66add5537d1e0a5adb6d8b --- /dev/null +++ b/tensorflow/contrib/gan/python/features/python/spectral_normalization_impl.py @@ -0,0 +1,315 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Keras-like layers and utilities that implement Spectral Normalization.
+
+Based on "Spectral Normalization for Generative Adversarial Networks" by Miyato,
+et al in ICLR 2018. https://openreview.net/pdf?id=B1QRgziT-
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import contextlib
+import numbers
+import re
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.keras.engine import base_layer_utils as keras_base_layer_utils
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.platform import tf_logging as logging
+
+__all__ = [
+    'compute_spectral_norm', 'spectral_normalize', 'spectral_norm_regularizer',
+    'spectral_normalization_custom_getter', 'keras_spectral_normalization'
+]
+
+# tf.bfloat16 should work, but tf.matmul converts those to tf.float32 which then
+# can't directly be assigned back to the tf.bfloat16 variable.
+_OK_DTYPES_FOR_SPECTRAL_NORM = (dtypes.float16, dtypes.float32, dtypes.float64)
+_PERSISTED_U_VARIABLE_SUFFIX = 'spectral_norm_u'
+
+
+def compute_spectral_norm(w_tensor, power_iteration_rounds=1, name=None):
+  """Estimates the largest singular value in the weight tensor.
+
+  Args:
+    w_tensor: The weight matrix whose spectral norm should be computed.
+    power_iteration_rounds: The number of iterations of the power method to
+      perform. A higher number yields a better approximation.
+    name: An optional scope name.
+
+  Returns:
+    The largest singular value (the spectral norm) of w.
+  """
+  with variable_scope.variable_scope(name, 'spectral_norm'):
+    # The paper says to flatten convnet kernel weights from
+    # (C_out, C_in, KH, KW) to (C_out, C_in * KH * KW). But TensorFlow's Conv2D
+    # kernel weight shape is (KH, KW, C_in, C_out), so it should be reshaped to
+    # (KH * KW * C_in, C_out), and similarly for other layers that put output
+    # channels as last dimension.
+    # n.b. this means that w here is equivalent to w.T in the paper.
+    w = array_ops.reshape(w_tensor, (-1, w_tensor.get_shape()[-1]))
+
+    # Persisted approximation of first left singular vector of matrix `w`.
+    u_var = variable_scope.get_variable(
+        _PERSISTED_U_VARIABLE_SUFFIX,
+        shape=(w.shape[0], 1),
+        dtype=w.dtype,
+        initializer=init_ops.random_normal_initializer(),
+        trainable=False)
+    u = u_var
+
+    # Use power iteration method to approximate spectral norm.
+    for _ in range(power_iteration_rounds):
+      # `v` approximates the first right singular vector of matrix `w`.
+      v = nn.l2_normalize(math_ops.matmul(array_ops.transpose(w), u))
+      u = nn.l2_normalize(math_ops.matmul(w, v))
+
+    # Update persisted approximation.
+    with ops.control_dependencies([u_var.assign(u, name='update_u')]):
+      u = array_ops.identity(u)
+
+    u = array_ops.stop_gradient(u)
+    v = array_ops.stop_gradient(v)
+
+    # Largest singular value of `w`.
+    spectral_norm = math_ops.matmul(
+        math_ops.matmul(array_ops.transpose(u), w), v)
+    spectral_norm.shape.assert_is_fully_defined()
+    spectral_norm.shape.assert_is_compatible_with([1, 1])
+
+    return spectral_norm[0][0]
+
+
+def spectral_normalize(w, power_iteration_rounds=1, name=None):
+  """Normalizes a weight matrix by its spectral norm.
+
+  Args:
+    w: The weight matrix to be normalized.
+    power_iteration_rounds: The number of iterations of the power method to
+      perform. A higher number yields a better approximation.
+    name: An optional scope name.
+
+  Returns:
+    A normalized weight matrix tensor.
+  """
+  with variable_scope.variable_scope(name, 'spectral_normalize'):
+    w_normalized = w / compute_spectral_norm(
+        w, power_iteration_rounds=power_iteration_rounds)
+    return array_ops.reshape(w_normalized, w.get_shape())
+
+
+def spectral_norm_regularizer(scale, power_iteration_rounds=1, scope=None):
+  """Returns a function that can be used to apply spectral norm regularization.
+
+  Small spectral norms enforce a small Lipschitz constant, which is necessary
+  for Wasserstein GANs.
+
+  Args:
+    scale: A scalar multiplier. 0.0 disables the regularizer.
+    power_iteration_rounds: The number of iterations of the power method to
+      perform. A higher number yields a better approximation.
+    scope: An optional scope name.
+
+  Returns:
+    A function with the signature `sn(weights)` that applies spectral norm
+    regularization.
+
+  Raises:
+    ValueError: If scale is negative or if scale is not a float.
+  """
+  if isinstance(scale, numbers.Integral):
+    raise ValueError('scale cannot be an integer: %s' % scale)
+  if isinstance(scale, numbers.Real):
+    if scale < 0.0:
+      raise ValueError(
+          'Setting a scale less than 0 on a regularizer: %g' % scale)
+    if scale == 0.0:
+      logging.info('Scale of 0 disables regularizer.')
+      return lambda _: None
+
+  def sn(weights, name=None):
+    """Applies spectral norm regularization to weights."""
+    with ops.name_scope(scope, 'SpectralNormRegularizer', [weights]) as name:
+      scale_t = ops.convert_to_tensor(
+          scale, dtype=weights.dtype.base_dtype, name='scale')
+      return math_ops.multiply(
+          scale_t,
+          compute_spectral_norm(
+              weights, power_iteration_rounds=power_iteration_rounds),
+          name=name)
+
+  return sn
+
+
+def _default_name_filter(name):
+  """A filter function to identify common names of weight variables.
+
+  Args:
+    name: The variable name.
+
+  Returns:
+    Whether `name` is a standard name for a weight/kernel variables used in the
+    Keras, tf.layers, tf.contrib.layers or tf.contrib.slim libraries.
+  """
+  match = re.match(r'(.*\/)?(depthwise_|pointwise_)?(weights|kernel)$', name)
+  return match is not None
+
+
+def spectral_normalization_custom_getter(name_filter=_default_name_filter,
+                                         power_iteration_rounds=1):
+  """Custom getter that performs Spectral Normalization on a weight tensor.
+
+  Specifically it divides the weight tensor by its largest singular value. This
+  is intended to stabilize GAN training, by making the discriminator satisfy a
+  local 1-Lipschitz constraint.
+
+  Based on [Spectral Normalization for Generative Adversarial Networks][sn-gan].
+
+  [sn-gan]: https://openreview.net/forum?id=B1QRgziT-
+
+  To reproduce an SN-GAN, apply this custom_getter to every weight tensor of
+  your discriminator. The last dimension of the weight tensor must be the number
+  of output channels.
+
+  Apply this to layers by supplying this as the `custom_getter` of a
+  `tf.variable_scope`. For example:
+
+    with tf.variable_scope(
+        'discriminator', custom_getter=spectral_normalization_custom_getter()):
+      net = discriminator_fn(net)
+
+  IMPORTANT: Keras does not respect the custom_getter supplied by the
+  VariableScope, so Keras users should use `keras_spectral_normalization`
+  instead of (or in addition to) this approach.
+
+  It is important to carefully select to which weights you want to apply
+  Spectral Normalization. In general you want to normalize the kernels of
+  convolution and dense layers, but you do not want to normalize biases. You
+  also want to avoid normalizing batch normalization (and similar) variables,
+  but in general such layers play poorly with Spectral Normalization, since the
+  gamma can cancel out the normalization in other layers. By default we supply a
+  filter that matches the kernel variable names of the dense and convolution
+  layers of the tf.layers, tf.contrib.layers, tf.keras and tf.contrib.slim
+  libraries. If you are using anything else you'll need a custom `name_filter`.
+
+  This custom getter internally creates a variable used to compute the spectral
+  norm by power iteration. It will update every time the variable is accessed,
+  which means the normalized discriminator weights may change slightly whilst
+  training the generator. Whilst unusual, this matches how the paper's authors
+  implement it, and in general additional rounds of power iteration can't hurt.
+
+  Args:
+    name_filter: Optionally, a method that takes a Variable name as input and
+      returns whether this Variable should be normalized.
+    power_iteration_rounds: The number of iterations of the power method to
+      perform per step. A higher number yields a better approximation of the
+      true spectral norm.
+
+  Returns:
+    A custom getter function that applies Spectral Normalization to all
+    Variables whose names match `name_filter`.
+
+  Raises:
+    ValueError: If name_filter is not callable.
+  """
+  if not callable(name_filter):
+    raise ValueError('name_filter must be callable')
+
+  def _internal_getter(getter, name, *args, **kwargs):
+    """A custom getter function that applies Spectral Normalization.
+
+    Args:
+      getter: The true getter to call.
+      name: Name of new/existing variable, in the same format as
+        tf.get_variable.
+      *args: Other positional arguments, in the same format as tf.get_variable.
+      **kwargs: Keyword arguments, in the same format as tf.get_variable.
+
+    Returns:
+      The return value of `getter(name, *args, **kwargs)`, spectrally
+      normalized.
+
+    Raises:
+      ValueError: If used incorrectly, or if `dtype` is not supported.
+    """
+    if not name_filter(name):
+      return getter(name, *args, **kwargs)
+
+    if name.endswith(_PERSISTED_U_VARIABLE_SUFFIX):
+      raise ValueError(
+          'Cannot apply Spectral Normalization to internal variables created '
+          'for Spectral Normalization. Tried to normalized variable [%s]' %
+          name)
+
+    if kwargs['dtype'] not in _OK_DTYPES_FOR_SPECTRAL_NORM:
+      raise ValueError('Disallowed data type {}'.format(kwargs['dtype']))
+
+    # This layer's weight Variable/PartitionedVariable.
+    w_tensor = getter(name, *args, **kwargs)
+
+    if len(w_tensor.get_shape()) < 2:
+      raise ValueError(
+          'Spectral norm can only be applied to multi-dimensional tensors')
+
+    return spectral_normalize(
+        w_tensor,
+        power_iteration_rounds=power_iteration_rounds,
+        name=(name + '/spectral_normalize'))
+
+  return _internal_getter
+
+
+@contextlib.contextmanager
+def keras_spectral_normalization(name_filter=_default_name_filter,
+                                 power_iteration_rounds=1):
+  """A context manager that enables Spectral Normalization for Keras.
+
+  Keras doesn't respect the `custom_getter` in the VariableScope, so this is a
+  bit of a hack to make things work.
+
+  Usage:
+    with keras_spectral_normalization():
+      net = discriminator_fn(net)
+
+  Args:
+    name_filter: Optionally, a method that takes a Variable name as input and
+      returns whether this Variable should be normalized.
+    power_iteration_rounds: The number of iterations of the power method to
+      perform per step. A higher number yields a better approximation of the
+      true spectral norm.
+
+  Yields:
+    A context manager that wraps the standard Keras variable creation method
+    with the `spectral_normalization_custom_getter`.
+  """
+  original_make_variable = keras_base_layer_utils.make_variable
+  sn_getter = spectral_normalization_custom_getter(
+      name_filter=name_filter, power_iteration_rounds=power_iteration_rounds)
+
+  def make_variable_wrapper(name, *args, **kwargs):
+    return sn_getter(original_make_variable, name, *args, **kwargs)
+
+  keras_base_layer_utils.make_variable = make_variable_wrapper
+  try:  # Undo the monkey patch even if the `with` body raises.
+    yield
+  finally:
+    keras_base_layer_utils.make_variable = original_make_variable
diff --git a/tensorflow/contrib/gan/python/features/python/spectral_normalization_test.py b/tensorflow/contrib/gan/python/features/python/spectral_normalization_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ea21f70ec01950cfef5e4fa851c78b219d6062f
--- /dev/null
+++ b/tensorflow/contrib/gan/python/features/python/spectral_normalization_test.py
@@ -0,0 +1,354 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for features.spectral_normalization."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib import slim
+from tensorflow.contrib.gan.python.features.python import spectral_normalization_impl as spectral_normalization
+from tensorflow.contrib.layers.python.layers import layers as contrib_layers
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.keras.layers import convolutional as keras_convolutional
+from tensorflow.python.keras.layers import core as keras_core
+from tensorflow.python.layers import convolutional as layers_convolutional
+from tensorflow.python.layers import core as layers_core
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import linalg_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+class SpectralNormalizationTest(test.TestCase):
+
+  def testComputeSpectralNorm(self):
+    weights = variable_scope.get_variable(
+        'w', dtype=dtypes.float32, shape=[2, 3, 50, 100])
+    weights = math_ops.multiply(weights, 10.0)
+    s = linalg_ops.svd(
+        array_ops.reshape(weights, [-1, weights.shape[-1]]), compute_uv=False)
+    true_sn = s[..., 0]
+    estimated_sn = spectral_normalization.compute_spectral_norm(weights)
+
+    with self.cached_session() as sess:
+      sess.run(variables.global_variables_initializer())
+      np_true_sn = sess.run(true_sn)
+      for i in range(50):
+        est = sess.run(estimated_sn)
+        if i < 1:
+          np_est_1 = est
+        if i < 5:
+          np_est_5 = est
+        if i < 10:
+          np_est_10 = est
+        np_est_50 = est
+
+      # Check that the estimate improves with more iterations.
+      self.assertAlmostEqual(np_true_sn, np_est_50, 0)
+      self.assertGreater(
+          abs(np_true_sn - np_est_10), abs(np_true_sn - np_est_50))
+      self.assertGreater(
+          abs(np_true_sn - np_est_5), abs(np_true_sn - np_est_10))
+      self.assertGreater(abs(np_true_sn - np_est_1), abs(np_true_sn - np_est_5))
+
+  def testSpectralNormalize(self):
+    weights = variable_scope.get_variable(
+        'w', dtype=dtypes.float32, shape=[2, 3, 50, 100])
+    weights = math_ops.multiply(weights, 10.0)
+    normalized_weights = spectral_normalization.spectral_normalize(
+        weights, power_iteration_rounds=1)
+
+    unnormalized_sigma = linalg_ops.svd(
+        array_ops.reshape(weights, [-1, weights.shape[-1]]),
+        compute_uv=False)[..., 0]
+    normalized_sigma = linalg_ops.svd(
+        array_ops.reshape(normalized_weights, [-1, weights.shape[-1]]),
+        compute_uv=False)[..., 0]
+
+    with self.cached_session() as sess:
+      sess.run(variables.global_variables_initializer())
+      s0 = sess.run(unnormalized_sigma)
+
+      for i in range(50):
+        sigma = sess.run(normalized_sigma)
+        if i < 1:
+          s1 = sigma
+        if i < 5:
+          s5 = sigma
+        if i < 10:
+          s10 = sigma
+        s50 = sigma
+
+      self.assertAlmostEqual(1., s50, 0)
+      self.assertGreater(abs(s10 - 1.), abs(s50 - 1.))
+      self.assertGreater(abs(s5 - 1.), abs(s10 - 1.))
+      self.assertGreater(abs(s1 - 1.), abs(s5 - 1.))
+      self.assertGreater(abs(s0 - 1.), abs(s1 - 1.))
+
+  def _testLayerHelper(self, build_layer_fn, w_shape, b_shape, is_keras=False):
+    x = array_ops.placeholder(dtypes.float32, shape=[2, 10, 10, 3])
+
+    w_initial = np.random.randn(*w_shape) * 10
+    w_initializer = init_ops.constant_initializer(w_initial)
+    b_initial = np.random.randn(*b_shape)
+    b_initializer = init_ops.constant_initializer(b_initial)
+
+    if is_keras:
+      context_manager = spectral_normalization.keras_spectral_normalization()
+    else:
+      getter = spectral_normalization.spectral_normalization_custom_getter()
+      context_manager = variable_scope.variable_scope('', custom_getter=getter)
+
+    with context_manager:
+      (net,
+       expected_normalized_vars, expected_not_normalized_vars) = build_layer_fn(
+           x, w_initializer, b_initializer)
+
+    x_data = np.random.rand(*x.shape)
+
+    with self.cached_session() as sess:
+      sess.run(variables.global_variables_initializer())
+
+      # Before running a forward pass we still expect the variables values to
+      # differ from the initial value because of the normalizer.
+      w_befores = []
+      for name, var in expected_normalized_vars.items():
+        w_before = sess.run(var)
+        w_befores.append(w_before)
+        self.assertFalse(
+            np.allclose(w_initial, w_before),
+            msg=('%s appears not to be normalized. Before: %s After: %s' %
+                 (name, w_initial, w_before)))
+
+      # Not true for the unnormalized variables.
+      for name, var in expected_not_normalized_vars.items():
+        b_before = sess.run(var)
+        self.assertTrue(
+            np.allclose(b_initial, b_before),
+            msg=('%s appears to be unexpectedly normalized. '
+                 'Before: %s After: %s' % (name, b_initial, b_before)))
+
+      # Run a bunch of forward passes.
+      for _ in range(1000):
+        _ = sess.run(net, feed_dict={x: x_data})
+
+      # We expect this to have improved the estimate of the spectral norm,
+      # which should have changed the variable values and brought them close
+      # to the true Spectral Normalized values.
+      _, s, _ = np.linalg.svd(w_initial.reshape([-1, 3]))
+      exactly_normalized = w_initial / s[0]
+      for w_before, (name, var) in zip(w_befores,
+                                       expected_normalized_vars.items()):
+        w_after = sess.run(var)
+        self.assertFalse(
+            np.allclose(w_before, w_after, rtol=1e-8, atol=1e-8),
+            msg=('%s did not improve over many iterations. '
+                 'Before: %s After: %s' % (name, w_before, w_after)))
+        self.assertAllClose(
+            exactly_normalized,
+            w_after,
+            rtol=1e-4,
+            atol=1e-4,
+            msg=('Estimate of spectral norm for %s was inaccurate. '
+                 'Normalized matrices do not match. '
+                 'Estimate: %s Actual: %s' % (name, w_after,
+                                              exactly_normalized)))
+
+  def testConv2D_Layers(self):
+
+    def build_layer_fn(x, w_initializer, b_initializer):
+      layer = layers_convolutional.Conv2D(
+          filters=3,
+          kernel_size=3,
+          padding='same',
+          kernel_initializer=w_initializer,
+          bias_initializer=b_initializer)
+      net = layer.apply(x)
+      expected_normalized_vars = {'tf.layers.Conv2d.kernel': layer.kernel}
+      expected_not_normalized_vars = {'tf.layers.Conv2d.bias': layer.bias}
+
+      return net, expected_normalized_vars, expected_not_normalized_vars
+
+    self._testLayerHelper(build_layer_fn, (3, 3, 3, 3), (3,))
+
+  def testConv2D_ContribLayers(self):
+
+    def build_layer_fn(x, w_initializer, b_initializer):
+      var_collection = {
+          'weights': ['CONTRIB_LAYERS_CONV2D_WEIGHTS'],
+          'biases': ['CONTRIB_LAYERS_CONV2D_BIASES']
+      }
+      net = contrib_layers.conv2d(
+          x,
+          3,
+          3,
+          weights_initializer=w_initializer,
+          biases_initializer=b_initializer,
+          variables_collections=var_collection)
+      weight_vars = ops.get_collection('CONTRIB_LAYERS_CONV2D_WEIGHTS')
+      self.assertEqual(1, len(weight_vars))
+      bias_vars = ops.get_collection('CONTRIB_LAYERS_CONV2D_BIASES')
+      self.assertEqual(1, len(bias_vars))
+      expected_normalized_vars = {
+          'contrib.layers.conv2d.weights': weight_vars[0]
+      }
+      expected_not_normalized_vars = {
+          'contrib.layers.conv2d.bias': bias_vars[0]
+      }
+
+      return net, expected_normalized_vars, expected_not_normalized_vars
+
+    self._testLayerHelper(build_layer_fn, (3, 3, 3, 3), (3,))
+
+  def testConv2D_Slim(self):
+
+    def build_layer_fn(x, w_initializer, b_initializer):
+      var_collection = {
+          'weights': ['SLIM_CONV2D_WEIGHTS'],
+          'biases': ['SLIM_CONV2D_BIASES']
+      }
+      net = slim.conv2d(
+          x,
+          3,
+          3,
+          weights_initializer=w_initializer,
+          biases_initializer=b_initializer,
+          variables_collections=var_collection)
+      weight_vars = ops.get_collection('SLIM_CONV2D_WEIGHTS')
+      self.assertEqual(1, len(weight_vars))
+      bias_vars = ops.get_collection('SLIM_CONV2D_BIASES')
+      self.assertEqual(1, len(bias_vars))
+      expected_normalized_vars = {'slim.conv2d.weights': weight_vars[0]}
+      expected_not_normalized_vars = {'slim.conv2d.bias': bias_vars[0]}
+
+      return net, expected_normalized_vars, expected_not_normalized_vars
+
+    self._testLayerHelper(build_layer_fn, (3, 3, 3, 3), (3,))
+
+  def testConv2D_Keras(self):
+
+    def build_layer_fn(x, w_initializer, b_initializer):
+      layer = keras_convolutional.Conv2D(
+          filters=3,
+          kernel_size=3,
+          padding='same',
+          kernel_initializer=w_initializer,
+          bias_initializer=b_initializer)
+      net = layer.apply(x)
+      expected_normalized_vars = {'keras.layers.Conv2d.kernel': layer.kernel}
+      expected_not_normalized_vars = {'keras.layers.Conv2d.bias': layer.bias}
+
+      return net, expected_normalized_vars, expected_not_normalized_vars
+
+    self._testLayerHelper(build_layer_fn, (3, 3, 3, 3), (3,), is_keras=True)
+
+  def testFC_Layers(self):
+
+    def build_layer_fn(x, w_initializer, b_initializer):
+      x = layers_core.Flatten()(x)
+      layer = layers_core.Dense(
+          units=3,
+          kernel_initializer=w_initializer,
+          bias_initializer=b_initializer)
+      net = layer.apply(x)
+      expected_normalized_vars = {'tf.layers.Dense.kernel': layer.kernel}
+      expected_not_normalized_vars = {'tf.layers.Dense.bias': layer.bias}
+
+      return net, expected_normalized_vars, expected_not_normalized_vars
+
+    self._testLayerHelper(build_layer_fn, (300, 3), (3,))
+
+  def testFC_ContribLayers(self):
+
+    def build_layer_fn(x, w_initializer, b_initializer):
+      var_collection = {
+          'weights': ['CONTRIB_LAYERS_FC_WEIGHTS'],
+          'biases': ['CONTRIB_LAYERS_FC_BIASES']
+      }
+      x = contrib_layers.flatten(x)
+      net = contrib_layers.fully_connected(
+          x,
+          3,
+          weights_initializer=w_initializer,
+          biases_initializer=b_initializer,
+          variables_collections=var_collection)
+      weight_vars = ops.get_collection('CONTRIB_LAYERS_FC_WEIGHTS')
+      self.assertEqual(1, len(weight_vars))
+      bias_vars = ops.get_collection('CONTRIB_LAYERS_FC_BIASES')
+      self.assertEqual(1, len(bias_vars))
+      expected_normalized_vars = {
+          'contrib.layers.fully_connected.weights': weight_vars[0]
+      }
+      expected_not_normalized_vars = {
+          'contrib.layers.fully_connected.bias': bias_vars[0]
+      }
+
+      return net, expected_normalized_vars, expected_not_normalized_vars
+
+    self._testLayerHelper(build_layer_fn, (300, 3), (3,))
+
+  def testFC_Slim(self):
+
+    def build_layer_fn(x, w_initializer, b_initializer):
+      var_collection = {
+          'weights': ['SLIM_FC_WEIGHTS'],
+          'biases': ['SLIM_FC_BIASES']
+      }
+      x = slim.flatten(x)
+      net = slim.fully_connected(
+          x,
+          3,
+          weights_initializer=w_initializer,
+          biases_initializer=b_initializer,
+          variables_collections=var_collection)
+      weight_vars = ops.get_collection('SLIM_FC_WEIGHTS')
+      self.assertEqual(1, len(weight_vars))
+      bias_vars = ops.get_collection('SLIM_FC_BIASES')
+      self.assertEqual(1, len(bias_vars))
+      expected_normalized_vars = {
+          'slim.fully_connected.weights': weight_vars[0]
+      }
+      expected_not_normalized_vars = {'slim.fully_connected.bias': bias_vars[0]}
+
+      return net, expected_normalized_vars, expected_not_normalized_vars
+
+    self._testLayerHelper(build_layer_fn, (300, 3), (3,))
+
+  def testFC_Keras(self):
+
+    def build_layer_fn(x, w_initializer, b_initializer):
+      x = keras_core.Flatten()(x)
+      layer = keras_core.Dense(
+          units=3,
+          kernel_initializer=w_initializer,
+          bias_initializer=b_initializer)
+      net = layer.apply(x)
+      expected_normalized_vars = {'keras.layers.Dense.kernel': layer.kernel}
+      expected_not_normalized_vars = {'keras.layers.Dense.bias': layer.bias}
+
+      return net, expected_normalized_vars, expected_not_normalized_vars
+
+    self._testLayerHelper(build_layer_fn, (300, 3), (3,), is_keras=True)
+
+
+if __name__ == '__main__':
+  test.main()
diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl.py b/tensorflow/contrib/gan/python/losses/python/losses_impl.py
index a0a86c6337eefa756a209635faa70db686a36247..1f1ae2df4d6def618e86aced3296ac89c836eab7 100644
--- 
a/tensorflow/contrib/gan/python/losses/python/losses_impl.py +++ b/tensorflow/contrib/gan/python/losses/python/losses_impl.py @@ -28,7 +28,7 @@ wasserstein_gradient_penalty All losses must be able to accept 1D or 2D Tensors, so as to be compatible with patchGAN style losses (https://arxiv.org/abs/1611.07004). -To make these losses usable in the TFGAN framework, please create a tuple +To make these losses usable in the TF-GAN framework, please create a tuple version of the losses with `losses_utils.py`. """ @@ -38,6 +38,7 @@ from __future__ import print_function from tensorflow.contrib.framework.python.ops import variables as contrib_variables_lib +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops @@ -69,6 +70,10 @@ __all__ = [ ] +def _to_float(tensor): + return math_ops.cast(tensor, dtypes.float32) + + # Wasserstein losses from `Wasserstein GAN` (https://arxiv.org/abs/1701.07875). 
def wasserstein_generator_loss( discriminator_gen_outputs, @@ -98,7 +103,7 @@ def wasserstein_generator_loss( """ with ops.name_scope(scope, 'generator_wasserstein_loss', ( discriminator_gen_outputs, weights)) as scope: - discriminator_gen_outputs = math_ops.to_float(discriminator_gen_outputs) + discriminator_gen_outputs = _to_float(discriminator_gen_outputs) loss = - discriminator_gen_outputs loss = losses.compute_weighted_loss( @@ -144,8 +149,8 @@ def wasserstein_discriminator_loss( with ops.name_scope(scope, 'discriminator_wasserstein_loss', ( discriminator_real_outputs, discriminator_gen_outputs, real_weights, generated_weights)) as scope: - discriminator_real_outputs = math_ops.to_float(discriminator_real_outputs) - discriminator_gen_outputs = math_ops.to_float(discriminator_gen_outputs) + discriminator_real_outputs = _to_float(discriminator_real_outputs) + discriminator_gen_outputs = _to_float(discriminator_gen_outputs) discriminator_real_outputs.shape.assert_is_compatible_with( discriminator_gen_outputs.shape) @@ -320,7 +325,7 @@ def wasserstein_gradient_penalty( generated_data: Output of the generator. generator_inputs: Exact argument to pass to the generator, which is used as optional conditioning to the discriminator. - discriminator_fn: A discriminator function that conforms to TFGAN API. + discriminator_fn: A discriminator function that conforms to TF-GAN API. discriminator_scope: If not `None`, reuse discriminators from this scope. epsilon: A small positive number added for numerical stability when computing the gradient norm. 
@@ -647,7 +652,7 @@ def least_squares_generator_loss( """ with ops.name_scope(scope, 'lsq_generator_loss', (discriminator_gen_outputs, real_label)) as scope: - discriminator_gen_outputs = math_ops.to_float(discriminator_gen_outputs) + discriminator_gen_outputs = _to_float(discriminator_gen_outputs) loss = math_ops.squared_difference( discriminator_gen_outputs, real_label) / 2.0 loss = losses.compute_weighted_loss( @@ -702,8 +707,8 @@ def least_squares_discriminator_loss( """ with ops.name_scope(scope, 'lsq_discriminator_loss', (discriminator_gen_outputs, real_label)) as scope: - discriminator_real_outputs = math_ops.to_float(discriminator_real_outputs) - discriminator_gen_outputs = math_ops.to_float(discriminator_gen_outputs) + discriminator_real_outputs = _to_float(discriminator_real_outputs) + discriminator_gen_outputs = _to_float(discriminator_gen_outputs) discriminator_real_outputs.shape.assert_is_compatible_with( discriminator_gen_outputs.shape) diff --git a/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py b/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py index 221c70c38bd432a6be7f6cda9c6700aa2255821f..76e57df7f646547037b3461ac44f7ee5b971406c 100644 --- a/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py +++ b/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""TFGAN utilities for loss functions that accept GANModel namedtuples. +"""TF-GAN utilities for loss functions that accept GANModel namedtuples. The losses and penalties in this file all correspond to losses in `losses_impl.py`. 
Losses in that file take individual arguments, whereas in this diff --git a/tensorflow/contrib/gan/python/namedtuples.py b/tensorflow/contrib/gan/python/namedtuples.py index 969b68449d9c82f9f9144a8657cd8932b38fd0f7..73dfee4fdeec87cf0bac5eb675fd02a64a9ad7f5 100644 --- a/tensorflow/contrib/gan/python/namedtuples.py +++ b/tensorflow/contrib/gan/python/namedtuples.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Named tuples for TFGAN. +"""Named tuples for TF-GAN. -TFGAN training occurs in four steps, and each step communicates with the next -step via one of these named tuples. At each step, you can either use a TFGAN +TF-GAN training occurs in four steps, and each step communicates with the next +step via one of these named tuples. At each step, you can either use a TF-GAN helper function in `train.py`, or you can manually construct a tuple. """ diff --git a/tensorflow/contrib/gan/python/train.py b/tensorflow/contrib/gan/python/train.py index 4c7bee41b33ce1fee46d374ca5fd1c0b603762f9..f36a5d346e0f27fbbc480e876380db51ed559c09 100644 --- a/tensorflow/contrib/gan/python/train.py +++ b/tensorflow/contrib/gan/python/train.py @@ -12,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""The TFGAN project provides a lightweight GAN training/testing framework. +"""The TF-GAN project provides a lightweight GAN training/testing framework. This file contains the core helper functions to create and train a GAN model. See the README or examples in `tensorflow_models` for details on how to use. 
-TFGAN training occurs in four steps: +TF-GAN training occurs in four steps: 1) Create a model 2) Add a loss 3) Create train ops @@ -645,9 +645,10 @@ def gan_loss( type(model)) # Optionally create pooled model. - pooled_model = ( - _tensor_pool_adjusted_model(model, tensor_pool_fn) - if tensor_pool_fn else model) + if tensor_pool_fn: + pooled_model = _tensor_pool_adjusted_model(model, tensor_pool_fn) + else: + pooled_model = model # Create standard losses. gen_loss = generator_loss_fn(model, add_summaries=add_summaries) @@ -665,10 +666,11 @@ def gan_loss( if _use_aux_loss(mutual_information_penalty_weight): gen_info_loss = tfgan_losses.mutual_information_penalty( model, add_summaries=add_summaries) - dis_info_loss = ( - gen_info_loss - if tensor_pool_fn is None else tfgan_losses.mutual_information_penalty( - pooled_model, add_summaries=add_summaries)) + if tensor_pool_fn is None: + dis_info_loss = gen_info_loss + else: + dis_info_loss = tfgan_losses.mutual_information_penalty( + pooled_model, add_summaries=add_summaries) gen_loss += mutual_information_penalty_weight * gen_info_loss dis_loss += mutual_information_penalty_weight * dis_info_loss if _use_aux_loss(aux_cond_generator_weight): @@ -929,7 +931,7 @@ def gan_train_ops( **kwargs): """Returns GAN train ops. - The highest-level call in TFGAN. It is composed of functions that can also + The highest-level call in TF-GAN. It is composed of functions that can also be called, should a user require more control over some part of the GAN training process. 
diff --git a/tensorflow/contrib/gdr/BUILD b/tensorflow/contrib/gdr/BUILD index e534fdc17749974ebe713c2730682bea6d7a85e4..bf8b66dcfa5e44a03107cdf1ef8b04e1dbff4a9c 100644 --- a/tensorflow/contrib/gdr/BUILD +++ b/tensorflow/contrib/gdr/BUILD @@ -17,11 +17,6 @@ filegroup( ]), ) -load( - "//tensorflow:tensorflow.bzl", - "tf_cuda_library", -) - # For platform specific build config load( "//tensorflow/core:platform/default/build_config.bzl", @@ -37,7 +32,7 @@ tf_proto_library_cc( ], ) -tf_cuda_library( +cc_library( name = "gdr_memory_manager", srcs = ["gdr_memory_manager.cc"], hdrs = ["gdr_memory_manager.h"], @@ -58,7 +53,7 @@ tf_cuda_library( ], ) -tf_cuda_library( +cc_library( name = "gdr_worker", srcs = ["gdr_worker.cc"], hdrs = ["gdr_worker.h"], @@ -66,7 +61,6 @@ tf_cuda_library( ":gdr_memory_manager", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", - "//tensorflow/core:gpu_runtime", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core/distributed_runtime:graph_mgr", @@ -100,15 +94,37 @@ cc_library( ], ) +cc_library( + name = "gdr_collective_executor_mgr", + srcs = ["gdr_collective_executor_mgr.cc"], + hdrs = ["gdr_collective_executor_mgr.h"], + deps = [ + ":gdr_memory_manager", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:lib_internal", + "//tensorflow/core/distributed_runtime:cancellable_call", + "//tensorflow/core/distributed_runtime:collective_param_resolver_distributed", + "//tensorflow/core/distributed_runtime:device_resolver_distributed", + "//tensorflow/core/distributed_runtime:request_id", + "//tensorflow/core/distributed_runtime:rpc_collective_executor_mgr", + "//tensorflow/core/distributed_runtime:worker_cache", + ], +) + cc_library( name = "gdr_server_lib", srcs = ["gdr_server_lib.cc"], hdrs = ["gdr_server_lib.h"], linkstatic = 1, # Seems to be needed since alwayslink is broken in bazel deps = [ + ":gdr_collective_executor_mgr", ":gdr_memory_manager", 
":gdr_rendezvous_mgr", ":gdr_worker", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core/distributed_runtime:collective_param_resolver_distributed", + "//tensorflow/core/distributed_runtime:device_resolver_distributed", "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", ], alwayslink = 1, diff --git a/tensorflow/contrib/gdr/README.md b/tensorflow/contrib/gdr/README.md index 8242d93f129904828a11b61d48f2df8fb0f88bc3..711adc865f37fc84550e4b45d9f0c7fff421a0dc 100644 --- a/tensorflow/contrib/gdr/README.md +++ b/tensorflow/contrib/gdr/README.md @@ -114,7 +114,16 @@ Caveats In current implementation, only tensors that reside in host memory or in GPU memory such that the GPU is adjacent to an RDMA capable NIC will use direct RDMA as its transport. When RDMA is available but not GDR, a temporary tensor copy on host memory will be used as RDMA source/destination (and copied from/to the target device). When there is no RDMA device present, it can even fallback to the original gRPC runtime. While it is theoretically possible to mix GDR enabled TF with non-GDR deployments in the same job, make sure the environment is properly setup so the GDR mode is enabled whenever possible (i.e. do not fall back to gRPC when it is not absolutely necessary). -In the original design (as in the reference), tensor buffers are only registered to NIC when we could determine that the tensor will be either a source of Send or a sink of Recv across physical machine boundary. However, to implement the precise allocations, we need to change all the devices to possibly return a NIC compatible allocator. As GDR is currently in contrib, we would like to avoid the unnecessary code disruption to the TF core, so we allocate all tensors from NIC-registered buffers using a BFC allocator. 
This behaviour is similar to the effect of enabling the extra GPU option `force_gpu_compatible`, which allocate all host tensors in GPU-registered buffers no matter they will be transferred from/to GPUs or not. +In the original design (as in the reference), tensor buffers are only registered +to NIC when we could determine that the tensor will be either a source of Send +or a sink of Recv across physical machine boundary. However, to implement the +precise allocations, we need to change all the devices to possibly return a NIC +compatible allocator. As GDR is currently in contrib, we would like to avoid the +unnecessary code disruption to the TF core, so we allocate all tensors from +NIC-registered buffers using a BFC allocator. This behavior is similar to the +effect of enabling the extra GPU option `force_gpu_compatible`, which allocate +all host tensors in GPU-registered buffers no matter they will be transferred +from/to GPUs or not. Reference === diff --git a/tensorflow/contrib/gdr/gdr.proto b/tensorflow/contrib/gdr/gdr.proto index c0b89245b150bfa49cb527d25b6e1f324f353b25..bd438787c3374be6ead4f6233101fd1f548643ea 100644 --- a/tensorflow/contrib/gdr/gdr.proto +++ b/tensorflow/contrib/gdr/gdr.proto @@ -9,5 +9,4 @@ message RemoteMemoryRegion { uint64 addr = 3; uint32 rkey = 4; uint32 tensor_key = 5; - uint64 checksum = 6; } diff --git a/tensorflow/contrib/gdr/gdr_collective_executor_mgr.cc b/tensorflow/contrib/gdr/gdr_collective_executor_mgr.cc new file mode 100644 index 0000000000000000000000000000000000000000..b84710d26eb8a64bf2f86b9f920551a8a8dbb233 --- /dev/null +++ b/tensorflow/contrib/gdr/gdr_collective_executor_mgr.cc @@ -0,0 +1,160 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/gdr/gdr_collective_executor_mgr.h" + +#include "tensorflow/core/common_runtime/base_collective_executor.h" +#include "tensorflow/core/common_runtime/collective_executor_mgr.h" +#include "tensorflow/core/common_runtime/collective_rma_local.h" +#include "tensorflow/core/common_runtime/dma_helper.h" +#include "tensorflow/core/distributed_runtime/cancellable_call.h" +#include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h" +#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h" +#include "tensorflow/core/distributed_runtime/request_id.h" +#include "tensorflow/core/distributed_runtime/worker_cache.h" +#include "tensorflow/core/lib/random/random.h" + +namespace tensorflow { + +class WorkerCacheInterface; + +namespace { + +class RecvBufCall : public CancellableCall { + public: + RecvBufCall(int64 step_id, const string& peer_device, const string& peer_task, + const string& key, Device* to_device, + DeviceContext* to_device_ctx, + const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor, + const DeviceLocality& client_locality, + const DeviceLocality& server_locality, + CancellationManager* cancel_mgr, WorkerCacheInterface* wc) + : CancellableCall(cancel_mgr, peer_task, wc) { + req_.set_step_id(step_id); + req_.set_buf_rendezvous_key(key); + *req_.mutable_client_locality() = client_locality; + *req_.mutable_server_locality() = server_locality; + req_.set_num_bytes(to_tensor->TotalBytes()); + 
req_.set_buf_ptr(reinterpret_cast(DMAHelper::base(to_tensor))); + req_.set_src_device(peer_device); + req_.set_dst_device(to_device->name()); + req_.set_request_id(GetUniqueRequestId()); + } + + ~RecvBufCall() override {} + + void IssueCall(const StatusCallback& done) override { + wi_->RecvBufAsync(&opts_, &req_, &resp_, done); + } + + RecvBufRequest req_; + RecvBufResponse resp_; +}; + +class CollectiveRemoteAccessDistributed : public CollectiveRemoteAccessLocal { + public: + CollectiveRemoteAccessDistributed(const DeviceMgr* dev_mgr, + DeviceResolverInterface* dev_resolver, + WorkerCacheInterface* worker_cache, + int64 step_id, + RemoteMemoryManager* remote_memory_manager) + : CollectiveRemoteAccessLocal(dev_mgr, dev_resolver, step_id), + worker_cache_(worker_cache), + remote_memory_manager_(remote_memory_manager) {} + + ~CollectiveRemoteAccessDistributed() override {} + + void RecvFromPeer(const string& peer_device, const string& peer_task, + bool peer_is_local, const string& key, Device* to_device, + DeviceContext* to_device_ctx, + const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor, + const DeviceLocality& client_locality, + int dev_to_dev_stream_index, + const StatusCallback& done) override { + if (peer_is_local) { + CollectiveRemoteAccessLocal::RecvFromPeer( + peer_device, peer_task, peer_is_local, key, to_device, to_device_ctx, + to_alloc_attr, to_tensor, client_locality, dev_to_dev_stream_index, + done); + return; + } + + // State that needs to be threaded through a couple of async calls + // in order to make this function completely non-blocking. + struct State { + DeviceLocality server_locality; + std::unique_ptr call; + }; + State* state = new State; + + // Logic to be executed on the RecvBufAsync callback. 
+ auto recv_buf_callback = [this, state, peer_task, to_device, to_alloc_attr, + to_device_ctx, to_tensor, dev_to_dev_stream_index, + done](const Status& s) { + if (s.ok()) { + remote_memory_manager_->TensorFromTransportOptions( + to_tensor, state->call->resp_.transport_options(), to_device, + to_device_ctx, to_alloc_attr.on_host(), done); + } + if (!s.ok() && errors::IsFailedPrecondition(s)) { + dev_resolver_->ClearTask(peer_task); + } + + delete state; + }; + + // Logic to execute once we have the device locality for the server-side + // device. + auto dev_locality_callback = [this, state, peer_device, peer_task, key, + to_device, to_device_ctx, to_alloc_attr, + to_tensor, client_locality, + recv_buf_callback](const Status& s) { + if (!s.ok()) { + recv_buf_callback(s); + } else { + state->call.reset(new RecvBufCall( + step_id_, peer_device, peer_task, key, to_device, to_device_ctx, + to_alloc_attr, to_tensor, client_locality, state->server_locality, + &cancel_mgr_, worker_cache_)); + state->call->Start(recv_buf_callback); + } + }; + + dev_resolver_->GetLocalityAsync( + peer_device, peer_task, &state->server_locality, dev_locality_callback); + } + + void StartAbort(const Status& s) override { + CollectiveRemoteAccessLocal::StartAbort(s); + cancel_mgr_.StartCancel(); + } + + protected: + WorkerCacheInterface* worker_cache_; // Not owned + CancellationManager cancel_mgr_; + RemoteMemoryManager* remote_memory_manager_; +}; + +} // namespace + +CollectiveExecutor* GdrCollectiveExecutorMgr::Create(int64 step_id) { + CollectiveRemoteAccessDistributed* rma = + new CollectiveRemoteAccessDistributed(dev_mgr_, dev_resolver_.get(), + worker_cache_, step_id, + remote_memory_manager_); + return new BaseCollectiveExecutor(this, rma, step_id, dev_mgr_, + &gpu_ring_order_); +} + +} // namespace tensorflow diff --git a/tensorflow/contrib/gdr/gdr_collective_executor_mgr.h b/tensorflow/contrib/gdr/gdr_collective_executor_mgr.h new file mode 100644 index 
0000000000000000000000000000000000000000..1417e51e82c31035f058e8e9b546e04fb0ad97b8 --- /dev/null +++ b/tensorflow/contrib/gdr/gdr_collective_executor_mgr.h @@ -0,0 +1,56 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_GDR_GDR_COLLECTIVE_EXECUTOR_MGR_H_ +#define TENSORFLOW_CONTRIB_GDR_GDR_COLLECTIVE_EXECUTOR_MGR_H_ + +#include "tensorflow/contrib/gdr/gdr_memory_manager.h" +#include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h" +#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h" +#include "tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h" +#include "tensorflow/core/framework/collective.h" + +namespace tensorflow { +class ConfigProto; +class DeviceMgr; +class WorkerCacheInterface; +class StepSequenceRequest; +class StepSequenceResponse; + +// An implementation of CollectiveExecutorMgr for a distributed environment +// that uses WorkerInterface::RecvBufAsync to route data transfers over RDMA. 
+class GdrCollectiveExecutorMgr : public RpcCollectiveExecutorMgr { + public: + GdrCollectiveExecutorMgr( + const ConfigProto& config, const DeviceMgr* dev_mgr, + std::unique_ptr dev_resolver, + std::unique_ptr param_resolver, + WorkerCacheInterface* worker_cache, const string& task_name, + RemoteMemoryManager* remote_memory_manager) + : RpcCollectiveExecutorMgr(config, dev_mgr, std::move(dev_resolver), + std::move(param_resolver), worker_cache, + task_name), + remote_memory_manager_(remote_memory_manager) {} + + ~GdrCollectiveExecutorMgr() override {} + + protected: + virtual CollectiveExecutor* Create(int64 step_id) override; + + private: + RemoteMemoryManager* remote_memory_manager_; // Not owned. +}; + +} // namespace tensorflow +#endif // TENSORFLOW_CONTRIB_GDR_GDR_COLLECTIVE_EXECUTOR_MGR_H_ diff --git a/tensorflow/contrib/gdr/gdr_memory_manager.cc b/tensorflow/contrib/gdr/gdr_memory_manager.cc index 53587fcf3050f313c85485f77ce411cba7faccff..7321e973191c4cc45f88735c6be7f2f67fe71c39 100644 --- a/tensorflow/contrib/gdr/gdr_memory_manager.cc +++ b/tensorflow/contrib/gdr/gdr_memory_manager.cc @@ -26,17 +26,14 @@ limitations under the License. 
#include #include #include -#include #include "tensorflow/contrib/gdr/gdr.pb.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/dma_helper.h" -#include "tensorflow/core/common_runtime/process_state.h" -#if GOOGLE_CUDA #include "tensorflow/core/common_runtime/gpu/gpu_process_state.h" -#include "tensorflow/core/common_runtime/gpu/gpu_util.h" -#endif // GOOGLE_CUDA +#include "tensorflow/core/common_runtime/process_state.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/numa.h" @@ -76,15 +73,14 @@ int TryToReadNumaNode(ibv_device* device) { std::ifstream ifs(filename.c_str()); string content; - CHECK(std::getline(ifs, content)); + const auto& ret = std::getline(ifs, content); + if (!ret) { + return port::kNUMANoAffinity; + } int32 value; if (strings::safe_strto32(content, &value)) { if (value < 0) { - LOG(INFO) << "Successful NUMA node read from SysFS had negative value (" - << value - << "), but there must be at least one NUMA node" - ", so returning NUMA node zero"; return port::kNUMANoAffinity; } LOG(INFO) << "NUMA node for device: " << device->name << " is " << value; @@ -114,7 +110,7 @@ class GdrMemoryManager : public RemoteMemoryManager { public: GdrMemoryManager(const string& host, const string& port); - virtual ~GdrMemoryManager(); + virtual ~GdrMemoryManager() {} virtual Status Init() override; @@ -140,7 +136,7 @@ class GdrMemoryManager : public RemoteMemoryManager { return ptr < reinterpret_cast(other->addr) + other->length; } - ibv_mr* FindMemoryRegion(void* addr, size_t length); + ibv_mr* FindMemoryRegion(const Tensor* tensor); void InsertMemoryRegion(void* addr, size_t length, const std::string& allocator_name); @@ -152,7 +148,6 @@ class GdrMemoryManager : public RemoteMemoryManager { const string port_; RdmaEndpointPtr listening_; std::atomic 
stopped_; - int epfd_; int numa_node_; // Server side endpoints @@ -163,15 +158,19 @@ class GdrMemoryManager : public RemoteMemoryManager { std::atomic next_key_; // Server side on-the-fly tensor buffers - mutex server_mu_; - std::map tensor_buffers_ - GUARDED_BY(server_mu_); + mutex buf_mu_; + std::map tensor_buffers_ GUARDED_BY(buf_mu_); // Client side endpoints mutex client_mu_; std::map, RdmaEndpointPtr> clients_ GUARDED_BY(client_mu_); + // Client side callbacks + mutex callback_mu_; + std::map tensor_callbacks_ + GUARDED_BY(callback_mu_); + // Managed memory regions mutex alloc_mu_; std::vector mrs_ GUARDED_BY(alloc_mu_); @@ -184,16 +183,9 @@ GdrMemoryManager::GdrMemoryManager(const string& host, const string& port) port_(port), listening_(nullptr, EndpointDeleter), stopped_(true), - next_key_(0) {} - -GdrMemoryManager::~GdrMemoryManager() { close(epfd_); } + next_key_(static_cast(random::New64())) {} Status GdrMemoryManager::Init() { - epfd_ = epoll_create1(0); - if (epfd_ == -1) { - return errors::Unavailable(strerror(errno), ": ", "epoll_create"); - } - rdma_addrinfo* addrinfo; rdma_addrinfo hints = {}; hints.ai_port_space = RDMA_PS_TCP; @@ -206,7 +198,7 @@ Status GdrMemoryManager::Init() { ibv_qp_init_attr init_attr = {}; init_attr.qp_type = IBV_QPT_RC; - init_attr.cap.max_recv_wr = 32; + init_attr.cap.max_recv_wr = 1024; init_attr.cap.max_send_wr = 1; init_attr.cap.max_recv_sge = 1; init_attr.cap.max_send_sge = 1; @@ -239,14 +231,6 @@ Status GdrMemoryManager::Init() { "cannot set server to non-blocking mode"); } - epoll_event event = {}; - event.events = EPOLLIN | EPOLLPRI; - event.data.ptr = listening_.get(); - if (epoll_ctl(epfd_, EPOLL_CTL_ADD, listening_->channel->fd, &event)) { - return errors::Unavailable(strerror(errno), ": ", - "cannot add server to epoll"); - } - numa_node_ = TryToReadNumaNode(listening_->verbs->device); SubAllocator::Visitor alloc_visitor = [this](void* ptr, int numa_node, @@ -265,121 +249,114 @@ Status GdrMemoryManager::Init() 
{ ProcessState::singleton()->AddCPUFreeVisitor(free_visitor); LOG(INFO) << "Instrumenting CPU allocator(s)"; -#if GOOGLE_CUDA for (int numa_idx = 0; numa_idx < port::NUMANumNodes(); ++numa_idx) { GPUProcessState::singleton()->AddCUDAHostAllocVisitor(numa_idx, alloc_visitor); GPUProcessState::singleton()->AddCUDAHostFreeVisitor(numa_idx, free_visitor); } + if (IsGDRAvailable()) { SubAllocator::Visitor cuda_alloc_visitor = [this](void* ptr, int gpu_id, size_t num_bytes) { VLOG(2) << "Registering RDMA capable memory region on GPU " << gpu_id; InsertMemoryRegion(ptr, num_bytes, strings::StrCat("GPU:", gpu_id)); }; - for (int numa_idx = 0; numa_idx < port::NUMANumNodes(); ++numa_idx) { - GPUProcessState::singleton()->AddGPUAllocVisitor(numa_idx, - cuda_alloc_visitor); - } - VLOG(1) << "Instrumenting GPU allocator(s) for all Numas"; + GPUProcessState::singleton()->AddGPUAllocVisitor(numa_node_, + cuda_alloc_visitor); + LOG(INFO) << "Instrumenting GPU allocator for NUMA " << numa_node_; } -#endif // GOOGLE_CUDA + return Status::OK(); } void GdrMemoryManager::Run() { stopped_ = false; while (!stopped_) { - epoll_event events[32]; - int ret = epoll_wait(epfd_, events, 32, 1); - if (ret == -1) { - LOG(ERROR) << "epoll_wait: " << strerror(errno); - return; - } - for (int i = 0; i < ret; i++) { - rdma_cm_id* id = static_cast(events[i].data.ptr); - if (id == listening_.get()) { - // Accept incoming connections - if (!rdma_get_request(listening_.get(), &id)) { - if (!rdma_accept(id, nullptr)) { - LOG(INFO) << "Accepted new RDMA connection"; - if (ibv_req_notify_cq(id->recv_cq, 0)) { - LOG(ERROR) << strerror(errno) << ": ibv_req_notify_cq failed"; - EndpointDeleter(id); - continue; - } - for (int i = 0; i < 32; i++) { - if (rdma_post_recvv(id, nullptr, nullptr, 0)) { - LOG(ERROR) << strerror(errno) << ": rdma_post_recvv failed"; - EndpointDeleter(id); - continue; - } - } - int flags = fcntl(id->recv_cq_channel->fd, F_GETFL, 0); - if (fcntl(id->recv_cq_channel->fd, F_SETFL, flags 
| O_NONBLOCK)) { - LOG(ERROR) << strerror(errno) - << ": cannot set server_client to non-blocking mode"; - EndpointDeleter(id); - continue; - } - epoll_event event = {}; - event.events = EPOLLIN | EPOLLPRI; - event.data.ptr = id; - if (epoll_ctl(epfd_, EPOLL_CTL_ADD, id->recv_cq_channel->fd, - &event)) { - LOG(ERROR) << strerror(errno) - << ": cannot add server client to epoll"; - EndpointDeleter(id); - continue; - } - server_clients_.push_back({id, EndpointDeleter}); + rdma_cm_id* id = nullptr; + // Accept incoming connections + if (!rdma_get_request(listening_.get(), &id)) { + if (!rdma_accept(id, nullptr)) { + LOG(INFO) << "Accepted new RDMA connection"; + for (int i = 0; i < 1024; i++) { + if (rdma_post_recvv(id, nullptr, nullptr, 0)) { + LOG(ERROR) << strerror(errno) << ": rdma_post_recvv failed"; + EndpointDeleter(id); + continue; } } - } else { - // Polling work completions - ibv_cq* cq; - void* context; - if (!ibv_get_cq_event(id->recv_cq_channel, &cq, &context)) { - ibv_ack_cq_events(id->recv_cq, 1); - if (ibv_req_notify_cq(id->recv_cq, 0)) { - LOG(ERROR) << strerror(errno) << ": ibv_req_notify_cq failed"; - continue; + server_clients_.push_back({id, EndpointDeleter}); + } + } + // Polling server side work completions + for (const auto& client : server_clients_) { + ibv_wc wc[32]; + int ret = ibv_poll_cq(client->recv_cq, 32, wc); + if (ret < 0) { + LOG(ERROR) << "ibv_poll_cq failed"; + continue; + } + for (int i = 0; i < ret; i++) { + if (wc[i].opcode != IBV_WC_RECV_RDMA_WITH_IMM) { + LOG(ERROR) << "Received unknown operation " << wc[i].opcode; + } + if (wc[i].status != 0) { + LOG(ERROR) << ibv_wc_status_str(wc[i].status); + } + TensorKey tensor_key = ntohl(wc[i].imm_data); + + if (rdma_post_recvv(client.get(), nullptr, nullptr, 0)) { + perror("rdma_post_recvv"); + LOG(ERROR) << "rdma_post_recvv failed"; + } + + mutex_lock l(buf_mu_); + auto iter = tensor_buffers_.find(tensor_key); + if (iter == std::end(tensor_buffers_)) { + LOG(ERROR) << "Cannot find 
tensor buffer for tensor key " + << tensor_key; + } else { + const TensorBuffer* buffer = iter->second; + buffer->Unref(); + tensor_buffers_.erase(iter); + } + } + } + // Polling client side work completions + if (client_mu_.try_lock()) { + for (const auto& client : clients_) { + ibv_wc wc[32]; + int ret = ibv_poll_cq(client.second->send_cq, 32, wc); + for (int i = 0; i < ret; i++) { + Status s; + if (wc[i].status) { + s = errors::Unavailable(ibv_wc_status_str(wc[i].status)); + } else { + s = Status::OK(); } - ibv_wc wc[32]; - int ret = ibv_poll_cq(id->recv_cq, 32, wc); - if (ret < 0) { - LOG(ERROR) << "ibv_poll_cq failed"; - continue; + TensorKey key = wc[i].wr_id; + + ibv_send_wr wr = {}; + wr.opcode = IBV_WR_RDMA_WRITE_WITH_IMM; + wr.imm_data = htonl(key); + ibv_send_wr* bad_wr; + if (ibv_post_send(client.second->qp, &wr, &bad_wr)) { + LOG(ERROR) << strerror(errno) + << ": ibv_post_send failed for tensor_key " << key; } - for (int i = 0; i < ret; i++) { - if (wc[i].opcode != IBV_WC_RECV_RDMA_WITH_IMM) { - LOG(ERROR) << "Received unknown operation " << wc[i].opcode; - } - if (wc[i].status != 0) { - LOG(ERROR) << ibv_wc_status_str(wc[i].status); - } - TensorKey tensor_key = ntohl(wc[i].imm_data); - { - mutex_lock l(server_mu_); - auto iter = tensor_buffers_.find(tensor_key); - if (iter == std::end(tensor_buffers_)) { - LOG(ERROR) << "Cannot find tensor buffer for tensor key " - << tensor_key; - } else { - const TensorBuffer* buffer = iter->second; - buffer->Unref(); - tensor_buffers_.erase(iter); - } - } - if (rdma_post_recvv(id, nullptr, nullptr, 0)) { - perror("rdma_post_recvv"); - LOG(ERROR) << "rdma_post_recvv failed"; - continue; - } + + mutex_lock l(callback_mu_); + auto iter = tensor_callbacks_.find(key); + if (iter != std::end(tensor_callbacks_)) { + iter->second(s); + tensor_callbacks_.erase(iter); + } else { + LOG(WARNING) << "Cannot find client callback with tensor key " + << key; } } } + client_mu_.unlock(); } } } @@ -390,116 +367,58 @@ void 
GdrMemoryManager::TransportOptionsFromTensor( ::google::protobuf::Any* mutable_transport_options, const Tensor& tensor, Device* device, DeviceContext* device_context, bool on_host, StatusCallback done) { - auto buffer = DMAHelper::buffer(&tensor); - void* addr = buffer->data(); - size_t length = buffer->size(); - if (length == 0) { - done(errors::Unavailable("Cannot register tensor buffer of size 0")); - return; - } + ibv_mr* mr = FindMemoryRegion(&tensor); + const TensorBuffer* buffer = DMAHelper::buffer(&tensor); - ibv_mr* mr = FindMemoryRegion(addr, length); - -#if GOOGLE_CUDA - if (device->tensorflow_gpu_device_info() && !on_host) { - Allocator* alloc = GPUProcessState::singleton()->GetCUDAHostAllocator(0); - Tensor* host_copy = new Tensor(alloc, tensor.dtype(), tensor.shape()); - GPUUtil::CopyGPUTensorToCPU( - device, device_context, &tensor, host_copy, - [done, host_copy, mutable_transport_options, this](const Status& s) { - if (!s.ok()) { - done(s); - delete host_copy; - return; - } - auto buffer = DMAHelper::buffer(host_copy); - void* addr = buffer->data(); - size_t length = buffer->size(); - ibv_mr* mr = FindMemoryRegion(addr, length); - - if (mr == nullptr) { - done(errors::Unavailable("Cannot find pinned memory region")); - delete host_copy; - return; - } - - buffer->Ref(); - TensorKey tensor_key = next_key_++; - { - mutex_lock l(server_mu_); - tensor_buffers_.insert(std::make_pair(tensor_key, buffer)); - } - - uint64_t checksum = 0; - if (VLOG_IS_ON(2)) { - checksum = GPUUtil::Checksum(*host_copy); - } - - RemoteMemoryRegion remote_mr; - remote_mr.set_host(host_); - remote_mr.set_port(port_); - remote_mr.set_addr(reinterpret_cast(addr)); - remote_mr.set_rkey(mr->rkey); - remote_mr.set_tensor_key(tensor_key); - remote_mr.set_checksum(checksum); - mutable_transport_options->PackFrom(remote_mr); - - done(Status::OK()); - delete host_copy; - }); - return; - } -#endif + Tensor* copy = nullptr; if (mr == nullptr) { - Allocator* alloc = 
ProcessState::singleton()->GetCPUAllocator(numa_node_); - Tensor host_copy(alloc, tensor.dtype(), tensor.shape()); - - std::memcpy(DMAHelper::buffer(&host_copy)->data(), buffer->data(), length); - VLOG(2) << "Copying " << length << " bytes unpinned tensor buffer"; - - buffer = DMAHelper::buffer(&host_copy); - addr = buffer->data(); - length = buffer->size(); - - mr = FindMemoryRegion(addr, length); + AllocatorAttributes alloc_attrs; + alloc_attrs.set_gpu_compatible(true); + alloc_attrs.set_nic_compatible(true); + alloc_attrs.set_on_host(true); + Allocator* alloc = device->GetAllocator(alloc_attrs); + copy = new Tensor(alloc, tensor.dtype(), tensor.shape()); + + mr = FindMemoryRegion(copy); + buffer = DMAHelper::buffer(copy); if (mr == nullptr) { done(errors::Unavailable("Cannot find pinned memory region")); + delete copy; return; } - - buffer->Ref(); - } else { - buffer->Ref(); } TensorKey tensor_key = next_key_++; + buffer->Ref(); { - mutex_lock l(server_mu_); + mutex_lock l(buf_mu_); tensor_buffers_.insert(std::make_pair(tensor_key, buffer)); } - uint64_t checksum = 0; - if (VLOG_IS_ON(2)) { -#ifdef GOOGLE_CUDA - if (device->tensorflow_gpu_device_info() && !on_host) { - checksum = GPUUtil::Checksum(device, device_context, tensor); - } else { - checksum = GPUUtil::Checksum(tensor); - } -#endif - } - RemoteMemoryRegion remote_mr; remote_mr.set_host(host_); remote_mr.set_port(port_); - remote_mr.set_addr(reinterpret_cast(addr)); + remote_mr.set_addr(reinterpret_cast(buffer->data())); remote_mr.set_rkey(mr->rkey); remote_mr.set_tensor_key(tensor_key); - remote_mr.set_checksum(checksum); mutable_transport_options->PackFrom(remote_mr); - done(Status::OK()); + if (copy && device->tensorflow_gpu_device_info() && !on_host) { + device_context->CopyDeviceTensorToCPU(&tensor, "" /* tensor_name */, device, + copy, [done, copy](const Status& s) { + done(s); + delete copy; + }); + return; + } else if (copy) { + std::memcpy(buffer->data(), DMAHelper::buffer(&tensor)->data(), + 
buffer->size()); + done(Status::OK()); + delete copy; // OK to delete; we have reffed the underlying TensorBuffer + } else { + done(Status::OK()); + } } void GdrMemoryManager::TensorFromTransportOptions( @@ -512,42 +431,10 @@ void GdrMemoryManager::TensorFromTransportOptions( return; } - auto buffer = DMAHelper::buffer(tensor); - void* addr = buffer->data(); - size_t length = buffer->size(); - ibv_mr* mr = FindMemoryRegion(addr, length); - - Tensor host_copy; -#if GOOGLE_CUDA - if (mr == nullptr && !on_host) { - Allocator* alloc = - GPUProcessState::singleton()->GetCUDAHostAllocator(numa_node_); - host_copy = Tensor(alloc, tensor->dtype(), tensor->shape()); - buffer = DMAHelper::buffer(&host_copy); - addr = buffer->data(); - length = buffer->size(); - mr = FindMemoryRegion(addr, length); - } -#endif // GOOGLE_CUDA - - if (mr == nullptr) { - Allocator* alloc = ProcessState::singleton()->GetCPUAllocator(numa_node_); - host_copy = Tensor(alloc, tensor->dtype(), tensor->shape()); - - buffer = DMAHelper::buffer(&host_copy); - addr = buffer->data(); - length = buffer->size(); - - mr = FindMemoryRegion(addr, length); - if (mr == nullptr) { - done(errors::Unavailable("Cannot find pinned memory region")); - return; - } - } - - decltype(clients_)::iterator iter; - bool success; + rdma_cm_id* id = nullptr; { + decltype(clients_)::iterator iter; + bool success; mutex_lock l(client_mu_); std::tie(iter, success) = clients_.insert( std::make_pair(std::make_pair(remote_mr.host(), remote_mr.port()), @@ -560,93 +447,94 @@ void GdrMemoryManager::TensorFromTransportOptions( return; } } - } - rdma_cm_id* id = iter->second.get(); - - uint64_t start = Env::Default()->NowMicros(); - - if (rdma_post_read(id, nullptr, buffer->data(), buffer->size(), mr, 0, - remote_mr.addr(), remote_mr.rkey())) { - done(errors::Unavailable(strerror(errno), ": ", "rdma_post_read failed")); - return; + id = iter->second.get(); } - ibv_send_wr wr = {}; - wr.opcode = IBV_WR_RDMA_WRITE_WITH_IMM; - wr.imm_data = 
htonl(remote_mr.tensor_key()); - wr.send_flags = IBV_SEND_SIGNALED; - ibv_send_wr* bad_wr; - if (ibv_post_send(id->qp, &wr, &bad_wr)) { - done(errors::Unavailable(strerror(errno), ": ", "ibv_post_send failed")); - return; - } + ibv_mr* mr = FindMemoryRegion(tensor); + const TensorBuffer* buffer = DMAHelper::buffer(tensor); - ibv_wc wc = {}; - int ret; - while ((ret = ibv_poll_cq(id->send_cq, 1, &wc)) == 0) - ; - if (ret < 0 || wc.status) { - done(errors::Unavailable(ibv_wc_status_str(wc.status))); - return; - } + const Tensor* copy = nullptr; -#if GOOGLE_CUDA - if (device->tensorflow_gpu_device_info() && !on_host && - host_copy.NumElements() > 0) { - uint64_t checksum = 0; - if (VLOG_IS_ON(2)) { - checksum = GPUUtil::Checksum(host_copy); - CHECK(checksum == remote_mr.checksum()) - << "Checksum mismatch: " << checksum << "!=" << remote_mr.checksum(); + if (mr == nullptr) { + AllocatorAttributes alloc_attrs; + alloc_attrs.set_gpu_compatible(true); + alloc_attrs.set_nic_compatible(true); + alloc_attrs.set_on_host(true); + Allocator* alloc = device->GetAllocator(alloc_attrs); + copy = new Tensor(alloc, tensor->dtype(), tensor->shape()); + + mr = FindMemoryRegion(copy); + buffer = DMAHelper::buffer(copy); + if (mr == nullptr) { + done(errors::Unavailable("Cannot find pinned memory region")); + delete copy; + return; } - Tensor* ref = new Tensor; - std::swap(host_copy, *ref); - GPUUtil::CopyCPUTensorToGPU( - ref, device_context, device, tensor, - [ref, done, buffer, remote_mr, start](const Status& s) { - if (!s.ok()) { - done(s); - delete ref; - return; - } - uint64_t end = Env::Default()->NowMicros(); - - VLOG(2) << "RDMA from remote memory region " << remote_mr.rkey() - << " of size " << buffer->size() << " with tensor key " - << remote_mr.tensor_key() << " took " << (end - start) - << " micros"; - done(Status::OK()); - delete ref; - }); - return; } -#endif // GOOGLE_CUDA - if ((on_host || !device->tensorflow_gpu_device_info()) && - host_copy.NumElements() > 0) { - 
std::memcpy(DMAHelper::buffer(tensor)->data(), addr, length); - VLOG(2) << "Copying " << length << " bytes unpinned tensor buffer"; - } + uint64_t start = Env::Default()->NowMicros(); - uint64_t end = Env::Default()->NowMicros(); + TensorKey tensor_key = remote_mr.tensor_key(); - VLOG(2) << "RDMA from remote memory region " << remote_mr.rkey() - << " of size " << buffer->size() << " with tensor key " - << remote_mr.tensor_key() << " took " << (end - start) << " micros"; + StatusCallback callback = [done, copy, device, device_context, on_host, + tensor, start, tensor_key](const Status& s) { + if (!s.ok()) { + done(s); + if (copy) { + delete copy; + } + return; + } - uint64_t checksum = 0; - if (VLOG_IS_ON(2)) { -#ifdef GOOGLE_CUDA - if (device->tensorflow_gpu_device_info() && !on_host) { - checksum = GPUUtil::Checksum(device, device_context, *tensor); + VLOG(2) << "RDMA of tensor " << tensor_key << " of size " + << DMAHelper::buffer(tensor)->size() << " took " + << (Env::Default()->NowMicros() - start) << " micros"; + + if (copy && device->tensorflow_gpu_device_info() && !on_host) { + device_context->CopyCPUTensorToDevice(copy, device, tensor, + [done, copy](const Status& s) { + done(s); + delete copy; + }); + } else if (copy) { + std::memcpy(DMAHelper::buffer(tensor)->data(), + DMAHelper::buffer(copy)->data(), + DMAHelper::buffer(copy)->size()); + done(s); + delete copy; } else { - checksum = GPUUtil::Checksum(*tensor); + done(s); + } + }; + + { + mutex_lock l(callback_mu_); + if (tensor_callbacks_.find(tensor_key) == std::end(tensor_callbacks_)) { + tensor_callbacks_.insert(std::make_pair(tensor_key, std::move(callback))); + } else { + done(errors::Unavailable("Received duplicated tensor key")); + if (copy) { + delete copy; + } + return; + } + } + + if (rdma_post_read(id, reinterpret_cast(tensor_key), buffer->data(), + buffer->size(), mr, IBV_SEND_SIGNALED, remote_mr.addr(), + remote_mr.rkey())) { + done(errors::Unavailable(strerror(errno), ": ", "rdma_post_read 
failed")); + { + mutex_lock l(callback_mu_); + auto iter = tensor_callbacks_.find(tensor_key); + if (iter != std::end(tensor_callbacks_)) { + tensor_callbacks_.erase(iter); + } + } + if (copy) { + delete copy; } - CHECK(checksum == remote_mr.checksum()) - << "Checksum mismatch: " << checksum << "!=" << remote_mr.checksum(); -#endif } - done(Status::OK()); } Status GdrMemoryManager::CreateEndpoint(const string& host, const string& port, @@ -663,7 +551,7 @@ Status GdrMemoryManager::CreateEndpoint(const string& host, const string& port, ibv_qp_init_attr init_attr = {}; init_attr.qp_type = IBV_QPT_RC; init_attr.cap.max_recv_wr = 1; - init_attr.cap.max_send_wr = 32; + init_attr.cap.max_send_wr = 1024; init_attr.cap.max_recv_sge = 1; init_attr.cap.max_send_sge = 1; @@ -687,8 +575,8 @@ Status GdrMemoryManager::CreateEndpoint(const string& host, const string& port, return Status::OK(); } -ibv_mr* GdrMemoryManager::FindMemoryRegion(void* addr, size_t length) { - if (length == 0) return nullptr; +ibv_mr* GdrMemoryManager::FindMemoryRegion(const Tensor* tensor) { + const void* addr = DMAHelper::buffer(tensor)->data(); mutex_lock l(alloc_mu_); auto iter = std::upper_bound(mrs_.begin(), mrs_.end(), addr, &Comparator); if (iter == std::end(mrs_) || iter->get()->addr > addr) { diff --git a/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc b/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc index fbccbead03fc0d641db40ede661bf3677d44c45d..5f8c300155770ed03ad12a9fa5ac74456edaf024 100644 --- a/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc +++ b/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc @@ -58,11 +58,9 @@ class GdrRecvTensorCall : public BaseRecvTensorCall { resp_.InitAlloc(dst_device_, recv_args_.alloc_attrs); StatusCallback cb = [this, recv_done](const Status& s) { bool dma_ok = resp_.metadata().has_transport_options(); - if (s.ok() && tensor().TotalBytes() > 0 && (!is_dead()) && dma_ok) { + if (s.ok() && tensor().TotalBytes() > 1024 && (!is_dead()) && dma_ok) { auto transport_options = 
resp_.metadata().transport_options(); - const bool on_host = - (dst_device_->tensorflow_gpu_device_info() == nullptr) || - recv_args_.alloc_attrs.on_host(); + const bool on_host = recv_args_.alloc_attrs.on_host(); remote_memory_manager_->TensorFromTransportOptions( const_cast(&tensor()), transport_options, dst_device_, recv_args_.device_context, on_host, @@ -70,9 +68,6 @@ class GdrRecvTensorCall : public BaseRecvTensorCall { if (!s.ok()) { mutex_lock l(mu_); status_.Update(s); - LOG(ERROR) << "Cannot find pinned memory region from allocator " - << dst_device_->GetAllocator(recv_args_.alloc_attrs) - ->Name(); } recv_done(); }); diff --git a/tensorflow/contrib/gdr/gdr_server_lib.cc b/tensorflow/contrib/gdr/gdr_server_lib.cc index b3f48ec1dd9c75055f4e1ea76eb203b6ccf94718..c39cc0f9bcecc26aedfaf9707113210acf670244 100644 --- a/tensorflow/contrib/gdr/gdr_server_lib.cc +++ b/tensorflow/contrib/gdr/gdr_server_lib.cc @@ -16,11 +16,13 @@ limitations under the License. #include "tensorflow/contrib/gdr/gdr_server_lib.h" #include "grpc/support/alloc.h" +#include "tensorflow/contrib/gdr/gdr_collective_executor_mgr.h" #include "tensorflow/contrib/gdr/gdr_memory_manager.h" #include "tensorflow/contrib/gdr/gdr_rendezvous_mgr.h" #include "tensorflow/contrib/gdr/gdr_worker.h" - -#include "grpc/support/alloc.h" +#include "tensorflow/core/common_runtime/collective_rma_local.h" +#include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h" +#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h" namespace tensorflow { @@ -57,10 +59,34 @@ Status GdrServer::Init() { return std::unique_ptr( new GdrWorker(env, config, remote_memory_manager_.get())); }; - + CollectiveMgrCreationFunction collective_mgr_func = + [this](const ConfigProto& config, const WorkerEnv* env, + WorkerCacheInterface* worker_cache) { + string unused; + string default_worker_name; + DeviceNameUtils::SplitDeviceName( + env->device_mgr->ListDevices()[0]->name(), 
&default_worker_name, + &unused); + + std::unique_ptr dev_resolver( + new DeviceResolverDistributed(env->device_mgr, worker_cache, + default_worker_name)); + std::unique_ptr param_resolver( + new CollectiveParamResolverDistributed( + config, env->device_mgr, dev_resolver.get(), worker_cache, + default_worker_name)); + return new GdrCollectiveExecutorMgr( + config, env->device_mgr, std::move(dev_resolver), + std::move(param_resolver), worker_cache, default_worker_name, + remote_memory_manager_.get()); + }; TF_RETURN_IF_ERROR(remote_memory_manager_->Init()); - return GrpcServer::Init(nullptr, rendezvous_mgr_func, nullptr, worker_func); + GrpcServerOptions opts; + opts.rendezvous_mgr_func = rendezvous_mgr_func; + opts.collective_mgr_func = collective_mgr_func; + opts.worker_func = worker_func; + return GrpcServer::Init(opts); } Status GdrServer::Start() { @@ -74,9 +100,8 @@ Status GdrServer::Start() { } Status GdrServer::Stop() { - TF_RETURN_IF_ERROR(GrpcServer::Stop()); remote_memory_manager_->Stop(); - return Status::OK(); + return GrpcServer::Stop(); } Status GdrServer::Join() { diff --git a/tensorflow/contrib/gdr/gdr_worker.cc b/tensorflow/contrib/gdr/gdr_worker.cc index 867cb83f42034c8e9061e333ea671457745f92c3..1204b8ca501a8f99ea6abd6c047ab2d91350bae1 100644 --- a/tensorflow/contrib/gdr/gdr_worker.cc +++ b/tensorflow/contrib/gdr/gdr_worker.cc @@ -15,12 +15,10 @@ limitations under the License. 
#include "tensorflow/contrib/gdr/gdr_worker.h" +#include "tensorflow/core/common_runtime/buf_rendezvous.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/dma_helper.h" -#if GOOGLE_CUDA -#include "tensorflow/core/common_runtime/gpu/gpu_util.h" -#endif // GOOGLE_CUDA #include "tensorflow/core/common_runtime/process_util.h" #include "tensorflow/core/common_runtime/step_stats_collector.h" #include "tensorflow/core/distributed_runtime/graph_mgr.h" @@ -32,6 +30,7 @@ limitations under the License. #include "tensorflow/core/distributed_runtime/worker_cache.h" #include "tensorflow/core/distributed_runtime/worker_session.h" #include "tensorflow/core/framework/cancellation.h" +#include "tensorflow/core/framework/collective.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/logging.h" @@ -43,13 +42,13 @@ GdrWorker::GdrWorker(WorkerEnv* worker_env, const ConfigProto& config, RemoteMemoryManager* remote_memory_manager) : GrpcWorker(worker_env, config), remote_memory_manager_(remote_memory_manager), - recv_tensor_recent_request_ids_(100000) {} + recent_request_ids_(100000) {} void GdrWorker::GrpcRecvTensorAsync(CallOptions* opts, const RecvTensorRequest* request, ::grpc::ByteBuffer* response, StatusCallback done) { - Status s = recv_tensor_recent_request_ids_.TrackUnique( + Status s = recent_request_ids_.TrackUnique( request->request_id(), "RecvTensor (GdrWorker)", *request); if (!s.ok()) { done(s); @@ -78,7 +77,7 @@ void GdrWorker::GrpcRecvTensorAsync(CallOptions* opts, const bool dma_ok = request->dma_ok(); env_->rendezvous_mgr->RecvLocalAsync( step_id, parsed, - [this, opts, response, done, src_dev, dma_ok]( + [this, opts, response, done, src_dev, request, dma_ok]( const Status& status, const Rendezvous::Args& send_args, const Rendezvous::Args&, const Tensor& val, const bool is_dead) { 
opts->ClearCancelCallback(); @@ -89,10 +88,8 @@ void GdrWorker::GrpcRecvTensorAsync(CallOptions* opts, // 3) the tensor has the on_host allocation attribute, // i.e. it's in CPU RAM *independent of its assigned // device type*. - const bool on_host = - (src_dev->tensorflow_gpu_device_info() == nullptr) || - send_args.alloc_attrs.on_host(); - if (val.TotalBytes() > 0 && (!is_dead) && + const bool on_host = send_args.alloc_attrs.on_host(); + if (val.TotalBytes() > 1024 && (!is_dead) && DMAHelper::CanUseDMA(&val) && dma_ok) { // DMA cases. RecvTensorResponse* proto = new RecvTensorResponse; @@ -117,8 +114,7 @@ void GdrWorker::GrpcRecvTensorAsync(CallOptions* opts, } else { // Non-DMA cases. if (src_dev->tensorflow_gpu_device_info() && (!on_host)) { -#if GOOGLE_CUDA - const DeviceContext* send_dev_context = send_args.device_context; + DeviceContext* send_dev_context = send_args.device_context; AllocatorAttributes alloc_attrs; alloc_attrs.set_gpu_compatible(true); alloc_attrs.set_on_host(true); @@ -127,7 +123,8 @@ void GdrWorker::GrpcRecvTensorAsync(CallOptions* opts, CHECK(send_dev_context) << "send dev name: " << src_dev->name() << " gpu_info: " << src_dev->tensorflow_gpu_device_info(); - // "val" is on a GPU. Uses GPUUtil to fill the response proto. + // "val" is on an accelerator device. Uses the device_context to + // fill the copy on host. StatusCallback copy_ready = [response, done, copy, is_dead](const Status& s) { // The value is now ready to be returned on the wire. 
@@ -136,11 +133,8 @@ void GdrWorker::GrpcRecvTensorAsync(CallOptions* opts, delete copy; }; - GPUUtil::CopyGPUTensorToCPU(src_dev, send_dev_context, &val, copy, - copy_ready); -#else - done(errors::Internal("No GPU device in process")); -#endif // GOOGLE_CUDA + send_dev_context->CopyDeviceTensorToCPU( + &val, request->rendezvous_key(), src_dev, copy, copy_ready); } else { grpc::EncodeTensorToByteBuffer(is_dead, val, response); done(Status::OK()); @@ -153,4 +147,41 @@ void GdrWorker::GrpcRecvTensorAsync(CallOptions* opts, }); } +void GdrWorker::RecvBufAsync(CallOptions* opts, const RecvBufRequest* request, + RecvBufResponse* response, StatusCallback done) { + // This is an RDMA enabled implementation augmenting grpc. + Status s = recent_request_ids_.TrackUnique(request->request_id(), + "RecvBuf (GdrWorker)", *request); + if (!s.ok()) { + done(s); + return; + } + CollectiveExecutor::Handle ce_handle( + env_->collective_executor_mgr->FindOrCreate(request->step_id()), true); + CollectiveRemoteAccess* rma = ce_handle.get()->remote_access(); + rma->buf_rendezvous()->ConsumeBuf( + request->buf_rendezvous_key(), + [this, request, response, done](const Status& status, + BufRendezvous::Hook* hook) { + Status s = status; + if (s.ok()) { + if (!DMAHelper::CanUseDMA(hook->prod_value)) { + s = errors::Internal("Tensor value for key ", + request->buf_rendezvous_key(), + " is not of a type supported by RecvBuf"); + } + } + if (s.ok()) { + remote_memory_manager_->TransportOptionsFromTensor( + response->mutable_transport_options(), *hook->prod_value, + hook->prod_dev, hook->prod_ctx, hook->prod_attr.on_host(), + [this, response, done, hook](const Status& s) { + response->set_send_start_micros(env_->env->NowMicros()); + done(s); + BufRendezvous::DoneWithHook(hook); + }); + } + }); +} + } // namespace tensorflow diff --git a/tensorflow/contrib/gdr/gdr_worker.h b/tensorflow/contrib/gdr/gdr_worker.h index 39f11e6bde5a1ca7ae91ead02279d22d70af027b..9a85cfd4263ad86f6579eedce95969c2829ff62c 
100644 --- a/tensorflow/contrib/gdr/gdr_worker.h +++ b/tensorflow/contrib/gdr/gdr_worker.h @@ -38,9 +38,13 @@ class GdrWorker : public GrpcWorker { ::grpc::ByteBuffer* response, StatusCallback done) override; + virtual void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request, + RecvBufResponse* response, + StatusCallback done) override; + private: RemoteMemoryManager* remote_memory_manager_; // Not owned - RecentRequestIds recv_tensor_recent_request_ids_; + RecentRequestIds recent_request_ids_; }; } // namespace tensorflow diff --git a/tensorflow/contrib/hadoop/python/kernel_tests/hadoop_test.py b/tensorflow/contrib/hadoop/python/kernel_tests/hadoop_test.py index f7f1189bb93c611719186a697c40f371644f63a2..bc941ae9f23eaa5c46fcca95b9aba0ac0d87960a 100644 --- a/tensorflow/contrib/hadoop/python/kernel_tests/hadoop_test.py +++ b/tensorflow/contrib/hadoop/python/kernel_tests/hadoop_test.py @@ -21,6 +21,7 @@ from __future__ import print_function import os from tensorflow.contrib.hadoop.python.ops import hadoop_dataset_ops +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -47,7 +48,7 @@ class SequenceFileDatasetTest(test.TestCase): dataset = hadoop_dataset_ops.SequenceFileDataset(filenames).repeat( num_repeats) - iterator = dataset.make_initializable_iterator() + iterator = dataset_ops.make_initializable_iterator(dataset) init_op = iterator.initializer get_next = iterator.get_next() diff --git a/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py b/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py index bf398b838dfaaff6fdaf33a6cd7086ef13e43a3e..71eac729a8a81c2f59f9ed5d7f42fb7b1c3e1b5c 100644 --- a/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py +++ b/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py @@ -20,15 +20,19 @@ from __future__ import print_function from 
tensorflow.contrib.hadoop.python.ops import gen_dataset_ops from tensorflow.contrib.hadoop.python.ops import hadoop_op_loader # pylint: disable=unused-import from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import nest +from tensorflow.python.data.util import structure from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape +from tensorflow.python.util import deprecation class SequenceFileDataset(dataset_ops.DatasetSource): """A Sequence File Dataset that reads the sequence file.""" + @deprecation.deprecated( + None, + "tf.contrib.hadoop will be removed in 2.0, the support for Apache Hadoop " + "will continue to be provided through the tensorflow/io GitHub project.") def __init__(self, filenames): """Create a `SequenceFileDataset`. @@ -40,36 +44,25 @@ class SequenceFileDataset(dataset_ops.DatasetSource): For example: ```python + tf.enable_eager_execution() + dataset = tf.contrib.hadoop.SequenceFileDataset("/foo/bar.seq") - iterator = dataset.make_one_shot_iterator() - next_element = iterator.get_next() # Prints the (key, value) pairs inside a hadoop sequence file. - while True: - try: - print(sess.run(next_element)) - except tf.errors.OutOfRangeError: - break + for key, value in dataset: + print(key, value) ``` Args: filenames: A `tf.string` tensor containing one or more filenames. 
""" - super(SequenceFileDataset, self).__init__() self._filenames = ops.convert_to_tensor( filenames, dtype=dtypes.string, name="filenames") - - def _as_variant_tensor(self): - return gen_dataset_ops.sequence_file_dataset( - self._filenames, nest.flatten(self.output_types)) - - @property - def output_classes(self): - return ops.Tensor, ops.Tensor - - @property - def output_shapes(self): - return (tensor_shape.TensorShape([]), tensor_shape.TensorShape([])) + variant_tensor = gen_dataset_ops.sequence_file_dataset( + self._filenames, self._element_structure._flat_types) # pylint: disable=protected-access + super(SequenceFileDataset, self).__init__(variant_tensor) @property - def output_types(self): - return dtypes.string, dtypes.string + def _element_structure(self): + return structure.NestedStructure( + (structure.TensorStructure(dtypes.string, []), + structure.TensorStructure(dtypes.string, []))) diff --git a/tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD b/tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD index 0081fb61770075a2c36e92f65e01126f657edeb4..d319aa7986d81cf9ac2d1dc2e15b053a0aa0c31b 100644 --- a/tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD +++ b/tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD @@ -16,9 +16,22 @@ tf_cc_binary( srcs = ["hvx_ops_support_checker_main.cc"], visibility = ["//visibility:public"], deps = [ + "//tensorflow/core:array_ops_op_lib", + "//tensorflow/core:candidate_sampling_ops_op_lib", + "//tensorflow/core:control_flow_ops_op_lib", "//tensorflow/core:framework_internal", + "//tensorflow/core:functional_ops_op_lib", "//tensorflow/core:lib", + "//tensorflow/core:list_ops_op_lib", + "//tensorflow/core:manip_ops_op_lib", + "//tensorflow/core:math_ops_op_lib", + "//tensorflow/core:nn_ops_op_lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/core:random_ops_op_lib", + "//tensorflow/core:remote_fused_graph_ops_op_lib", + "//tensorflow/core:string_ops_op_lib", + "//tensorflow/core:training_ops_op_lib", + 
"//tensorflow/core:user_ops_op_lib", "//tensorflow/core/kernels:remote_fused_graph_execute_utils", "//tensorflow/core/kernels/hexagon:graph_transferer", "//tensorflow/tools/graph_transforms:file_utils", diff --git a/tensorflow/contrib/ignite/README.md b/tensorflow/contrib/ignite/README.md index c7db0b77e25668fb8a42d204776044420f403e44..c1f6cac4942436d32f9867d4b5557c6b9e376c69 100644 --- a/tensorflow/contrib/ignite/README.md +++ b/tensorflow/contrib/ignite/README.md @@ -30,7 +30,8 @@ system based on Apache Ignite. ## Features -Ignite Dataset provides features that that you can use in a wide range of cases. The most important and interesting features are described below. +Ignite Dataset provides features that you can use in a wide range of cases. The +most important and interesting features are described below. ### Distributed In-Memory Datasource [Apache Ignite](https://ignite.apache.org/) is a distributed in-memory database, caching, and processing platform that provides fast data access. It allows you to avoid limitations of hard drive and store and operate with as much data as you need in distributed cluster. 
You can utilize @@ -54,14 +55,12 @@ jdbc:ignite:thin://localhost/> INSERT INTO KITTEN_CACHE VALUES (3, 'LITTLE BALL ```python >>> import tensorflow as tf >>> from tensorflow.contrib.ignite import IgniteDataset ->>> +>>> tf.enable_eager_execution() +>>> >>> dataset = IgniteDataset(cache_name="SQL_PUBLIC_KITTEN_CACHE") ->>> iterator = dataset.make_one_shot_iterator() ->>> next_obj = iterator.get_next() >>> ->>> with tf.Session() as sess: ->>> for _ in range(3): ->>> print(sess.run(next_obj)) +>>> for element in dataset: +>>> print(element) {'key': 1, 'val': {'NAME': b'WARM KITTY'}} {'key': 2, 'val': {'NAME': b'SOFT KITTY'}} @@ -74,23 +73,22 @@ jdbc:ignite:thin://localhost/> INSERT INTO KITTEN_CACHE VALUES (3, 'LITTLE BALL ```python >>> import tensorflow as tf >>> from tensorflow.contrib.ignite import IgniteDataset ->>> +>>> tf.enable_eager_execution() +>>> >>> dataset = IgniteDataset(cache_name="IMAGES") ->>> iterator = dataset.make_one_shot_iterator() ->>> next_obj = iterator.get_next() >>> ->>> with tf.Session() as sess: ->>> print(sess.run(next_obj)) +>>> for element in dataset.take(1): +>>> print(element) { - 'key': 'kitten.png', + 'key': 'kitten.png', 'val': { 'metadata': { 'file_name': b'kitten.png', 'label': b'little ball of fur', - width: 800, + width: 800, height: 600 - }, + }, 'pixels': [0, 0, 0, 0, ..., 0] } } @@ -100,13 +98,12 @@ jdbc:ignite:thin://localhost/> INSERT INTO KITTEN_CACHE VALUES (3, 'LITTLE BALL ```python >>> import tensorflow as tf >>> from tensorflow.contrib.ignite import IgniteDataset ->>> +>>> tf.enable_eager_execution() +>>> >>> dataset = IgniteDataset(cache_name="IMAGES").map(lambda obj: obj['val']['pixels']) ->>> iterator = dataset.make_one_shot_iterator() ->>> next_obj = iterator.get_next() >>> ->>> with tf.Session() as sess: ->>> print(sess.run(next_obj)) +>>> for element in dataset: +>>> print(element) [0, 0, 0, 0, ..., 0] ``` @@ -121,23 +118,31 @@ Using this ability we can calculate gradients on the nodes the data is stored on 
Apache Ignite uses horizontal partitioning to store data in distributed cluster. When we create Apache Ignite cache (or table in terms of SQL), we can specify the number of partitions the data will be partitioned on. For example, if an Apache Ignite cluster consists of 10 machines and we create cache with 10 partitions, then every machine will maintain approximately one data partition. -Ignite Dataset allows using these two aspects of distributed neural network training (using TensorFlow) and Apache Ignite partitioning. Ignite Dataset is a computation graph operation that can be performed on a remote worker. The remote worker can override Ignite Dataset parameters (such as `host`, `port` or `part`) by setting correstondent environment variables for worker process (such as `IGNITE_DATASET_HOST`, `IGNITE_DATASET_PORT` or `IGNITE_DATASET_PART`). Using this overriding approach, we can assign a specific partition to every worker so that one worker handles one partition and, at the same time, transparently work with single dataset. +Ignite Dataset allows using these two aspects of distributed neural network +training (using TensorFlow) and Apache Ignite partitioning. Ignite Dataset is a +computation graph operation that can be performed on a remote worker. The remote +worker can override Ignite Dataset parameters (such as `host`, `port` or `part`) +by setting correspondent environment variables for worker process (such as +`IGNITE_DATASET_HOST`, `IGNITE_DATASET_PORT` or `IGNITE_DATASET_PART`). Using +this overriding approach, we can assign a specific partition to every worker so +that one worker handles one partition and, at the same time, transparently work +with single dataset. ```python >>> import tensorflow as tf >>> from tensorflow.contrib.ignite import IgniteDataset ->>> +>>> >>> dataset = IgniteDataset("IMAGES") >>> >>> # Compute gradients locally on every worker node. 
->>> gradients = [] +>>> gradients = [] >>> for i in range(5): >>> with tf.device("/job:WORKER/task:%d" % i): ->>> device_iterator = dataset.make_one_shot_iterator() +>>> device_iterator = tf.compat.v1.data.make_one_shot_iterator(dataset) >>> device_next_obj = device_iterator.get_next() >>> gradient = compute_gradient(device_next_obj) ->>> gradients.append(gradient) ->>> +>>> gradients.append(gradient) +>>> >>> # Aggregate them on master node. >>> result_gradient = tf.reduce_sum(gradients) >>> @@ -145,7 +150,7 @@ Ignite Dataset allows using these two aspects of distributed neural network trai >>> print(sess.run(result_gradient)) ``` -High-level TensorFlow API for [distributed training](https://www.tensorflow.org/api_docs/python/tf/contrib/distribute/DistributionStrategy) is supported as well. +High-level TensorFlow API for [distributed training](https://www.tensorflow.org/api_docs/python/tf/contrib/distribute/DistributionStrategy) is supported as well. ### Distributed File System @@ -154,23 +159,31 @@ system called [IGFS](https://ignite.apache.org/features/igfs.html). IGFS delivers a similar functionality to Hadoop HDFS, but only in-memory. In fact, in addition to its own APIs, IGFS implements Hadoop FileSystem API and can be transparently plugged into Hadoop or Spark deployments. This contrib package -contains an integration between IGFS and TensorFlow. The integration is based -on [custom filesystem plugin](https://www.tensorflow.org/extend/add_filesys) -from TensorFlow side and +contains an integration between IGFS and TensorFlow. The integration is based on +[custom filesystem plugin](https://www.tensorflow.org/extend/add_filesys) from +TensorFlow side and [IGFS Native API](https://ignite.apache.org/features/igfs.html) from Apache -Ignite side. It has numerous uses, for example: * Checkpoints of state can be -saved to IGFS for reliability and fault-tolerance. 
* Training processes -communicate with TensorBoard by writing event files to a directory, which -TensorBoard watches. IGFS allows this communication to work even when -TensorBoard runs in a different process or machine. +Ignite side. It has numerous uses, for example: + +* Checkpoints of state can be saved to IGFS for reliability and + fault-tolerance. +* Training processes communicate with TensorBoard by writing event files to a + directory, which TensorBoard watches. IGFS allows this communication to work + even when TensorBoard runs in a different process or machine. ### SSL Connection -Apache Ignite allows to protect data transfer channels by [SSL](https://en.wikipedia.org/wiki/Transport_Layer_Security) and authentification. Ignite Dataset supports both SSL connection with and without authntication. For more information, please refer to the [Apache Ignite SSL/TLS](https://apacheignite.readme.io/docs/ssltls) documentation. +Apache Ignite allows you to protect data transfer channels by +[SSL](https://en.wikipedia.org/wiki/Transport_Layer_Security) and +authentication. Ignite Dataset supports both SSL connection with and without +authentication. For more information, please refer to the +[Apache Ignite SSL/TLS](https://apacheignite.readme.io/docs/ssltls) +documentation. ```python >>> import tensorflow as tf >>> from tensorflow.contrib.ignite import IgniteDataset +>>> tf.enable_eager_execution() >>> >>> dataset = IgniteDataset(cache_name="IMAGES", certfile="client.pem", @@ -191,7 +204,7 @@ Following examples will help you to easily start working with this module. The simplest way to try Ignite Dataset is to run a [Docker](https://www.docker.com/) container with Apache Ignite and loaded -[MNIST](http://yann.lecun.com/exdb/mnist/) data and after start interruct with +[MNIST](http://yann.lecun.com/exdb/mnist/) data and after start interact with it using Ignite Dataset.
Such container is available on Docker Hub: [dmitrievanthony/ignite-with-mnist](https://hub.docker.com/r/dmitrievanthony/ignite-with-mnist/). You need to start this container on your machine: @@ -202,13 +215,13 @@ docker run -it -p 10800:10800 dmitrievanthony/ignite-with-mnist After that you will be able to work with it following way: -![ignite-dataset-mnist](https://s3.amazonaws.com/helloworld23423423ew23/ignite-dataset-mnist.png "Ignite Dataset Mnist") +![ignite-dataset-mnist](https://s3.amazonaws.com/helloworld23423423ew23/ignite-dataset-mnist-2.png "Ignite Dataset Mnist") ### IGFS The simplest way to try IGFS with TensorFlow is to run [Docker](https://www.docker.com/) container with Apache Ignite and enabled IGFS -and then interruct with it using TensorFlow +and then interact with it using TensorFlow [tf.gfile](https://www.tensorflow.org/api_docs/python/tf/gfile). Such container is available on Docker Hub: [dmitrievanthony/ignite-with-igfs](https://hub.docker.com/r/dmitrievanthony/ignite-with-igfs/).
diff --git a/tensorflow/contrib/ignite/python/ops/ignite_dataset_ops.py b/tensorflow/contrib/ignite/python/ops/ignite_dataset_ops.py index 936b29a4f50794380d48efed99e267c6b4c44dc6..3ffceef8070e0fc3b3cebae2522f89fe98ce4413 100644 --- a/tensorflow/contrib/ignite/python/ops/ignite_dataset_ops.py +++ b/tensorflow/contrib/ignite/python/ops/ignite_dataset_ops.py @@ -27,17 +27,16 @@ import six from tensorflow.contrib.ignite.python.ops import gen_dataset_ops from tensorflow.contrib.ignite.python.ops import ignite_op_loader # pylint: disable=unused-import from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import structure from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +from tensorflow.python.util import deprecation @six.add_metaclass(abc.ABCMeta) class Readable(object): - """Readable abstract class that exposes methods to do reading-related - - operations. - """ + """Abstract class that exposes methods to do reading-related operations.""" @abc.abstractmethod def __init__(self): @@ -227,10 +226,7 @@ types = { class TypeTreeNode(object): - """TypeTreeNode class exposes methods to format object tree structure - - data. - """ + """TypeTreeNode class exposes methods to format object tree structure data.""" def __init__(self, name, type_id, fields=None, permutation=None): """Constructs a new instance of TypeTreeNode. @@ -692,18 +688,22 @@ class IgniteClient(TcpClient): class IgniteDataset(dataset_ops.DatasetSource): - """Apache Ignite is a memory-centric distributed database, caching, and - - processing platform for transactional, analytical, and streaming workloads, - delivering in-memory speeds at petabyte scale. This contrib package - contains an integration between Apache Ignite and TensorFlow. The - integration is based on tf.data from TensorFlow side and Binary Client - Protocol from Apache Ignite side. 
It allows to use Apache Ignite as a - datasource for neural network training, inference and all other + """Apache Ignite is a memory-centric distributed database. + + It acts as a caching and processing platform for transactional, analytical, + and streaming workloads, delivering in-memory speeds at petabyte scale. + This contrib package contains an integration between Apache Ignite and + TensorFlow. The integration is based on tf.data from TensorFlow side and + Binary Client Protocol from Apache Ignite side. It allows to use Apache + Ignite as a datasource for neural network training, inference and all other computations supported by TensorFlow. Ignite Dataset is based on Apache Ignite Binary Client Protocol. """ + @deprecation.deprecated( + None, + "tf.contrib.ignite will be removed in 2.0, the support for Apache Ignite " + "will continue to be provided through the tensorflow/io GitHub project.") def __init__(self, cache_name, host="localhost", @@ -735,8 +735,6 @@ class IgniteDataset(dataset_ops.DatasetSource): cert_password: Password to be used if the private key is encrypted and a password is necessary. 
""" - super(IgniteDataset, self).__init__() - with IgniteClient(host, port, username, password, certfile, keyfile, cert_password) as client: client.handshake() @@ -756,6 +754,11 @@ class IgniteDataset(dataset_ops.DatasetSource): self.cache_type.to_permutation(), dtype=dtypes.int32, name="permutation") + self._structure = structure.convert_legacy_structure( + self.cache_type.to_output_types(), self.cache_type.to_output_shapes(), + self.cache_type.to_output_classes()) + + super(IgniteDataset, self).__init__(self._as_variant_tensor()) def _as_variant_tensor(self): return gen_dataset_ops.ignite_dataset(self.cache_name, self.host, self.port, @@ -763,13 +766,5 @@ class IgniteDataset(dataset_ops.DatasetSource): self.schema, self.permutation) @property - def output_classes(self): - return self.cache_type.to_output_classes() - - @property - def output_shapes(self): - return self.cache_type.to_output_shapes() - - @property - def output_types(self): - return self.cache_type.to_output_types() + def _element_structure(self): + return self._structure diff --git a/tensorflow/contrib/ignite/python/tests/ignite_dataset_test.py b/tensorflow/contrib/ignite/python/tests/ignite_dataset_test.py index ef29b5f14a4b2fea2400ec4d56a7ad2cf44cf2cb..89b74fbfdc38c9f42795d5c778889210baf6387f 100644 --- a/tensorflow/contrib/ignite/python/tests/ignite_dataset_test.py +++ b/tensorflow/contrib/ignite/python/tests/ignite_dataset_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import os +from tensorflow import compat from tensorflow.contrib.ignite import IgniteDataset from tensorflow.python.client import session from tensorflow.python.framework import dtypes @@ -65,7 +66,7 @@ class IgniteDatasetTest(test.TestCase): self.assertEqual(dtypes.string, dataset.output_types["val"]["NAME"]) self.assertEqual(dtypes.int64, dataset.output_types["val"]["VAL"]) - it = dataset.make_one_shot_iterator() + it = compat.v1.data.make_one_shot_iterator(dataset) ne = it.get_next() with session.Session() as 
sess: diff --git a/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py b/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py index 4997c31a7fc7f4243d03b22fc9c01fb13a2a25a4..ba5cdfebf92c07e496ed588848d5859ff6a5bff2 100644 --- a/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py +++ b/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py @@ -281,6 +281,13 @@ class ImageOpsTest(test_util.TensorFlowTestCase): value.eval(), np.array([[4, 4], [4, 4]]).astype(dtype.as_numpy_dtype())) + @test_util.run_in_graph_and_eager_modes + def test_transform_eager(self): + image = constant_op.constant([[1., 2.], [3., 4.]]) + value = image_ops.transform(image, [1] * 8) + with self.test_session(use_gpu=True): + self.assertAllEqual(self.evaluate(value), np.array([[4, 4], [4, 4]])) + class BipartiteMatchTest(test_util.TensorFlowTestCase): diff --git a/tensorflow/contrib/image/python/ops/image_ops.py b/tensorflow/contrib/image/python/ops/image_ops.py index d4fb99a017faebe30384d739f22f4ff5fa986bc4..b25a6f7b5742917a032946fe03a0dab20e7dc1ad 100644 --- a/tensorflow/contrib/image/python/ops/image_ops.py +++ b/tensorflow/contrib/image/python/ops/image_ops.py @@ -17,6 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.eager import context from tensorflow.contrib.image.ops import gen_image_ops from tensorflow.contrib.util import loader from tensorflow.python.framework import common_shapes @@ -271,8 +272,11 @@ def transform(images, raise TypeError("Images should have rank between 2 and 4.") if output_shape is None: - output_shape = tensor_util.constant_value( - array_ops.shape(images)[1:3]) or array_ops.shape(images)[1:3] + output_shape = array_ops.shape(images)[1:3] + if not context.executing_eagerly(): + output_shape_value = tensor_util.constant_value(output_shape) + if output_shape_value is not None: + output_shape = output_shape_value output_shape = 
ops.convert_to_tensor( output_shape, dtypes.int32, name="output_shape") diff --git a/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.sh b/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.sh old mode 100644 new mode 100755 diff --git a/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py b/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py index 7129f09e8b42e48a9c768fd4a66cde3d4da9d31d..5591c3b0cc8c8bf196bb4821c018cbf155cba4ce 100644 --- a/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py +++ b/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py @@ -20,15 +20,20 @@ from __future__ import print_function from tensorflow.contrib.kafka.python.ops import gen_dataset_ops from tensorflow.contrib.kafka.python.ops import kafka_op_loader # pylint: disable=unused-import from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import structure from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape +from tensorflow.python.util import deprecation class KafkaDataset(dataset_ops.DatasetSource): """A Kafka Dataset that consumes the message. """ + @deprecation.deprecated( + None, + "tf.contrib.kafka will be removed in 2.0, the support for Apache Kafka " + "will continue to be provided through the tensorflow/io GitHub project.") def __init__(self, topics, servers="localhost", @@ -47,7 +52,6 @@ class KafkaDataset(dataset_ops.DatasetSource): timeout: The timeout value for the Kafka Consumer to wait (in millisecond). 
""" - super(KafkaDataset, self).__init__() self._topics = ops.convert_to_tensor( topics, dtype=dtypes.string, name="topics") self._servers = ops.convert_to_tensor( @@ -58,18 +62,12 @@ class KafkaDataset(dataset_ops.DatasetSource): self._timeout = ops.convert_to_tensor( timeout, dtype=dtypes.int64, name="timeout") + super(KafkaDataset, self).__init__(self._as_variant_tensor()) + def _as_variant_tensor(self): return gen_dataset_ops.kafka_dataset(self._topics, self._servers, self._group, self._eof, self._timeout) @property - def output_classes(self): - return ops.Tensor - - @property - def output_shapes(self): - return tensor_shape.scalar() - - @property - def output_types(self): - return dtypes.string + def _element_structure(self): + return structure.TensorStructure(dtypes.string, []) diff --git a/tensorflow/contrib/kernel_methods/python/losses.py b/tensorflow/contrib/kernel_methods/python/losses.py index 4ef0a66a52429233c6e6f70667a451466493629c..294a7d69a704b3c06ab9e30489af116929ab6c2a 100644 --- a/tensorflow/contrib/kernel_methods/python/losses.py +++ b/tensorflow/contrib/kernel_methods/python/losses.py @@ -34,7 +34,7 @@ def sparse_multiclass_hinge_loss( scope=None, loss_collection=ops.GraphKeys.LOSSES, reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS): - """Adds Ops for computing the multiclass hinge loss. + r"""Adds Ops for computing the multiclass hinge loss. The implementation is based on the following paper: On the Algorithmic Implementation of Multiclass Kernel-based Vector Machines diff --git a/tensorflow/contrib/kfac/README.md b/tensorflow/contrib/kfac/README.md index 42b91d031375b8edb7e4f364ac91ffb74ef1f54b..19daffea6c7e4486499388314d0aaaa611e94218 100644 --- a/tensorflow/contrib/kfac/README.md +++ b/tensorflow/contrib/kfac/README.md @@ -1,3 +1,3 @@ # K-FAC: Kronecker-Factored Approximate Curvature -## KFAC moved to third_party/tensorflow_kfac. +## KFAC moved to https://github.com/tensorflow/kfac. 
diff --git a/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py b/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py index 75806dbbeb1819bb0a6965bbc384e02df9895210..9479afb180df7bb4a08d6aafa4fc3bf63489d9f3 100644 --- a/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py +++ b/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py @@ -20,9 +20,10 @@ from __future__ import print_function from tensorflow.contrib.kinesis.python.ops import gen_dataset_ops from tensorflow.contrib.kinesis.python.ops import kinesis_op_loader # pylint: disable=unused-import from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import structure from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape +from tensorflow.python.util import deprecation class KinesisDataset(dataset_ops.DatasetSource): @@ -34,15 +35,12 @@ class KinesisDataset(dataset_ops.DatasetSource): For example, we can construct and use the KinesisDataset as follows: ```python + tf.enable_eager_execution() + dataset = tf.contrib.kinesis.KinesisDataset( "kinesis_stream_name", read_indefinitely=False) - next = dataset.make_one_shot_iterator().get_next() - with tf.Session() as sess: - while True: - try: - print(sess.run(nxt)) - except tf.errors.OutOfRangeError: - break + for element in dataset: + print(element) ``` Since Kinesis is a data streaming service, data may not be available @@ -53,6 +51,10 @@ class KinesisDataset(dataset_ops.DatasetSource): is returned immediately instead. """ + @deprecation.deprecated( + None, + "tf.contrib.kinesis will be removed in 2.0, the support for Kinesis " + "will continue to be provided through the tensorflow/io GitHub project.") def __init__(self, stream, shard="", @@ -69,7 +71,6 @@ class KinesisDataset(dataset_ops.DatasetSource): interval: The interval for the Kinesis Client to wait before it tries to get records again (in millisecond). 
""" - super(KinesisDataset, self).__init__() self._stream = ops.convert_to_tensor( stream, dtype=dtypes.string, name="stream") self._shard = ops.convert_to_tensor( @@ -78,19 +79,12 @@ class KinesisDataset(dataset_ops.DatasetSource): read_indefinitely, dtype=dtypes.bool, name="read_indefinitely") self._interval = ops.convert_to_tensor( interval, dtype=dtypes.int64, name="interval") + super(KinesisDataset, self).__init__(self._as_variant_tensor()) def _as_variant_tensor(self): return gen_dataset_ops.kinesis_dataset( self._stream, self._shard, self._read_indefinitely, self._interval) @property - def output_classes(self): - return ops.Tensor - - @property - def output_shapes(self): - return tensor_shape.scalar() - - @property - def output_types(self): - return dtypes.string + def _element_structure(self): + return structure.TensorStructure(dtypes.string, []) diff --git a/tensorflow/contrib/labeled_tensor/BUILD b/tensorflow/contrib/labeled_tensor/BUILD index 588f15b867c1fedbadd5a5d945d870a356549468..7e19ae7c13df421ec5bb9cb0e07dff0d00fb9548 100644 --- a/tensorflow/contrib/labeled_tensor/BUILD +++ b/tensorflow/contrib/labeled_tensor/BUILD @@ -155,7 +155,7 @@ py_library( ":core", "//tensorflow/python:array_ops", "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:functional_ops", + "//tensorflow/python:map_fn", "//tensorflow/python:math_ops", "//tensorflow/python:numerics", "//tensorflow/python:random_ops", diff --git a/tensorflow/contrib/labeled_tensor/python/ops/ops.py b/tensorflow/contrib/labeled_tensor/python/ops/ops.py index 2ede5daee74223e812cc29e9708b1989b698fb4e..a65f045cc886f4d4f351423858d92412baa3a622 100644 --- a/tensorflow/contrib/labeled_tensor/python/ops/ops.py +++ b/tensorflow/contrib/labeled_tensor/python/ops/ops.py @@ -29,6 +29,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import functional_ops +from 
tensorflow.python.ops import map_fn as map_fn_lib from tensorflow.python.ops import math_ops from tensorflow.python.ops import numerics from tensorflow.python.ops import random_ops @@ -629,7 +630,7 @@ def map_fn(fn, labeled_tensor, name=None): # TODO(ericmc): Fix this upstream. if labeled_tensor.dtype == dtypes.string: - # We must construct the full graph here, because functional_ops.map_fn + # We must construct the full graph here, because map_fn_lib.map_fn # doesn't work for string-valued tensors. # Constructing the full graph may be slow. map_lts = [fn(t) for t in unpack_lts] @@ -652,7 +653,7 @@ def map_fn(fn, labeled_tensor, name=None): tensor_lt = core.LabeledTensor(tensor, original_axes) return fn(tensor_lt).tensor - map_op = functional_ops.map_fn( + map_op = map_fn_lib.map_fn( tf_fn, labeled_tensor.tensor, dtype=first_map_lt.dtype) map_lt = core.LabeledTensor(map_op, final_axes) diff --git a/tensorflow/contrib/layers/BUILD b/tensorflow/contrib/layers/BUILD index 9ca6f8df5dbe3c236c4cd85095176ce69ad9deaa..69d5496f8aebb9b89c5d79f80a1a439f556093d7 100644 --- a/tensorflow/contrib/layers/BUILD +++ b/tensorflow/contrib/layers/BUILD @@ -81,6 +81,7 @@ tf_custom_op_py_library( visibility = [ "//learning/brain:__subpackages__", "//tensorflow:__subpackages__", + "//tensorflow_model_optimization:__subpackages__", "//video/youtube/personalization:__subpackages__", ], deps = [ diff --git a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py index 7e6eafaa0d6f60cfc28a4c422abac0b6d5a991fb..00e41026d0038409ace178e6affd2c1cdc812122 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py +++ b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py @@ -1757,7 +1757,7 @@ class WeightedSumTest(test.TestCase): logits_core = fc_core.linear_model(features, [movies]) with self.cached_session() as sess: - variables_lib.initialize_all_variables().run() + 
variables_lib.global_variables_initializer().run() lookup_ops.tables_initializer().run() weights = column_to_variable[movies][0] diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py index 0a4d2c6d4cb5cad7da93cea89478bc0fca2ac4d6..1c0088186c030437454c0f764decab9e5a276adc 100644 --- a/tensorflow/contrib/layers/python/layers/layers_test.py +++ b/tensorflow/contrib/layers/python/layers/layers_test.py @@ -1356,7 +1356,7 @@ class DropoutTest(test.TestCase): with self.cached_session(): images = np.random.uniform(size=(5, height, width, 3)) output = _layers.dropout(images) - self.assertEqual(output.op.name, 'Dropout/dropout_1/mul') + self.assertEqual(output.op.name, 'Dropout/dropout_1/mul_1') output.get_shape().assert_is_compatible_with( ops.convert_to_tensor(images).get_shape()) @@ -1459,13 +1459,6 @@ class DropoutTest(test.TestCase): class FlattenTest(test.TestCase): - def testInvalidRank(self): - with ops.Graph().as_default() as g, self.session(g): - inputs = array_ops.placeholder(dtype=dtypes.float32) - inputs.set_shape(tensor_shape.TensorShape((5,))) - with self.assertRaisesRegexp(ValueError, 'incompatible with the layer'): - _layers.flatten(inputs) - def testUnknownLastDim(self): with ops.Graph().as_default() as g, self.session(g): inputs = array_ops.placeholder(dtype=dtypes.float32) @@ -1502,6 +1495,12 @@ class FlattenTest(test.TestCase): images.get_shape().num_elements()) self.assertEqual(output.get_shape()[0], images.get_shape()[0]) + def testFlatten0D(self): + with self.cached_session(): + scalars = random_ops.random_uniform((5,), seed=1, name='scalars') + output = _layers.flatten(scalars) + self.assertEqual(output.shape, (5, 1)) + def testFlattenBatchSize(self): height, width = 3, 3 with self.cached_session() as sess: diff --git a/tensorflow/contrib/layers/python/layers/normalization.py b/tensorflow/contrib/layers/python/layers/normalization.py index 
11033a2e9cb646c2e7cd2f45de1f751d88c6921a..76b03ff514821d3459f84c5f46a64d1134e0d4de 100644 --- a/tensorflow/contrib/layers/python/layers/normalization.py +++ b/tensorflow/contrib/layers/python/layers/normalization.py @@ -186,7 +186,7 @@ def group_norm(inputs, Args: inputs: A Tensor with at least 2 dimensions one which is channels. All - shape dimensions must be fully defined. + shape dimensions except for batch must be fully defined. groups: Integer. Divide the channels into this number of groups over which normalization statistics are computed. This number must be commensurate with the number of channels in `inputs`. @@ -249,13 +249,21 @@ def group_norm(inputs, """ # TODO(shlens): Support partially defined shapes for the inputs. inputs = ops.convert_to_tensor(inputs) - original_shape = inputs.shape if inputs.shape.ndims is None: raise ValueError('Inputs %s has undefined rank.' % inputs.name) if channels_axis > (inputs.shape.ndims - 1): raise ValueError('Axis is out of bounds.') + # Use dynamic shape for not fully defined dimensions in the inputs. + dyanmic_shape = array_ops.shape(inputs) + input_shape_list = [] + for i, dim in enumerate(inputs.shape): + if dim.value is None: + input_shape_list.append(dyanmic_shape[i]) + else: + input_shape_list.append(dim) + # Standardize the channels_axis to be positive and identify # of channels. if channels_axis < 0: channels_axis = inputs.shape.ndims + channels_axis @@ -289,8 +297,8 @@ def group_norm(inputs, # Determine axes before channels. Some examples of common image formats: # 'NCHW': before = [N], after = [HW] # 'NHWC': before = [NHW], after = [] - axes_before_channels = inputs.shape.as_list()[:channels_axis] - axes_after_channels = inputs.shape.as_list()[channels_axis+1:] + axes_before_channels = input_shape_list[:channels_axis] + axes_after_channels = input_shape_list[channels_axis+1:] # Manually broadcast the parameters to conform to the number of groups. 
params_shape_broadcast = ([1] * len(axes_before_channels) + @@ -369,7 +377,7 @@ def group_norm(inputs, outputs = inputs * gain + offset # Collapse the groups into the channel dimension. - outputs = array_ops.reshape(outputs, original_shape) + outputs = array_ops.reshape(outputs, input_shape_list) if activation_fn is not None: outputs = activation_fn(outputs) diff --git a/tensorflow/contrib/layers/python/layers/normalization_test.py b/tensorflow/contrib/layers/python/layers/normalization_test.py index c8d3c91b10dbe3b959e91182f9924b78352d370d..9a85084b239837ade87d8c778393ef8e885f5bdd 100644 --- a/tensorflow/contrib/layers/python/layers/normalization_test.py +++ b/tensorflow/contrib/layers/python/layers/normalization_test.py @@ -221,6 +221,15 @@ class GroupNormTest(test.TestCase): normalization.group_norm(inputs, channels_axis=-1, reduction_axes=[-3, -2]) + def testParamsShapeNotFullyDefinedBatchAxis(self): + height, width, groups = 3, 3, 4 + inputs = array_ops.placeholder(dtypes.float32, + shape=(None, height, width, 2*groups)) + output = normalization.group_norm(inputs, channels_axis=-1, + reduction_axes=[-3, -2], groups=groups) + self.assertListEqual([None, height, width, 2 * groups], + output.shape.as_list()) + def testCreateOp(self): height, width, groups = 3, 3, 4 images = random_ops.random_uniform((5, height, width, 2*groups), seed=1) diff --git a/tensorflow/contrib/layers/python/layers/target_column.py b/tensorflow/contrib/layers/python/layers/target_column.py index 8a6b4f68a8b33d497ddb16614a7e3cdf32f2c422..5234869718b427d7e275b76ae12021a096241a56 100644 --- a/tensorflow/contrib/layers/python/layers/target_column.py +++ b/tensorflow/contrib/layers/python/layers/target_column.py @@ -399,7 +399,7 @@ def _mean_squared_loss(logits, target): target = array_ops.expand_dims(target, axis=1) logits.get_shape().assert_is_compatible_with(target.get_shape()) - return math_ops.square(logits - math_ops.to_float(target)) + return math_ops.squared_difference(logits, 
math_ops.to_float(target)) def _log_loss_with_two_classes(logits, target): diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD index 238504f6d60aeb1a7ff25deab4a86881285e8c03..4749371248ee89a033912132986d7f76c85dbaa6 100644 --- a/tensorflow/contrib/learn/BUILD +++ b/tensorflow/contrib/learn/BUILD @@ -274,6 +274,7 @@ py_test( name = "estimator_test", size = "medium", srcs = ["python/learn/estimators/estimator_test.py"], + shard_count = 2, srcs_version = "PY2AND3", tags = [ "manual", @@ -356,9 +357,9 @@ py_test( py_test( name = "dnn_linear_combined_test", - size = "large", + size = "medium", srcs = ["python/learn/estimators/dnn_linear_combined_test.py"], - shard_count = 4, + shard_count = 8, srcs_version = "PY2AND3", tags = ["no_oss"], # flaky b/70524820 deps = [ diff --git a/tensorflow/contrib/learn/README.md b/tensorflow/contrib/learn/README.md index b0bff915a993c9a01e2e6d9ef9f71c14d2f29a73..b2d3a6273abba7e3a893f30bbdd4f8b2662bd54a 100644 --- a/tensorflow/contrib/learn/README.md +++ b/tensorflow/contrib/learn/README.md @@ -111,18 +111,17 @@ Some arguments are renamed, please refer to documentation. In addition: Switch to `tf.estimator.train_and_evaluate`. Some differences: -* Most of the constructor arguments, like `train_input_fn`, `eval_input_fn`, - should be wrapped into `tf.estimator.TrainSpec` and `tf.estimator.EvalSpec`. -* Remove the `experiment_fn`. Instead, create the `Estimator`, - `train_spec` and `eval_spec`, then call `tf.estimator.train_and_evaluate` - directly. -* Inside `tf.estimator.EvalSpec`, the `exporter` field is the replacement - for `export_strategy`. To be precise, `tf.estimator.LatestExporter` is the - replacement for `tf.contrib.learn.make_export_strategy`. If you want to export - only at the end of training use `tf.estimator.FinalExporter`. 
-* If the `TF_CONFIG` environment variable is constructed manually, please read - the `train_and_evaluate` documentation for the new requirementds (in - particular, the chief node and evaluator node). +* Most of the constructor arguments, like `train_input_fn`, `eval_input_fn`, + should be wrapped into `tf.estimator.TrainSpec` and `tf.estimator.EvalSpec`. +* Remove the `experiment_fn`. Instead, create the `Estimator`, `train_spec` + and `eval_spec`, then call `tf.estimator.train_and_evaluate` directly. +* Inside `tf.estimator.EvalSpec`, the `exporter` field is the replacement for + `export_strategy`. To be precise, `tf.estimator.LatestExporter` is the + replacement for `tf.contrib.learn.make_export_strategy`. If you want to + export only at the end of training use `tf.estimator.FinalExporter`. +* If the `TF_CONFIG` environment variable is constructed manually, please read + the `train_and_evaluate` documentation for the new requirements (in + particular, the chief node and evaluator node). 
## Others Classes and Functions diff --git a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py index 28c4964527bb034c8c6b1642366c6c82c1a72201..c3e9e3af9427037a4e7be6b86417cd081c42ef67 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py @@ -37,8 +37,8 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import random_seed from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops -from tensorflow.python.ops import functional_ops from tensorflow.python.ops import lookup_ops +from tensorflow.python.ops import map_fn from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import rnn_cell @@ -524,7 +524,7 @@ class DynamicRNNEstimatorLearningTest(test.TestCase): def input_fn(): starts = random_ops.random_uniform( [batch_size], maxval=(2 * np.pi), seed=seed) - sin_curves = functional_ops.map_fn( + sin_curves = map_fn.map_fn( _sin_fn, (starts,), dtype=dtypes.float32) inputs = array_ops.expand_dims( array_ops.slice(sin_curves, [0, 0], [batch_size, sequence_length]), diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py index c1b97d8b49613ea49d9813954da3b7a63d3ba04c..4bb14a6e63b159fa4d09c9ef20947d4b125de657 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head.py @@ -567,7 +567,8 @@ def _mean_squared_loss(labels, logits, weights=None): if len(logits.get_shape()) == 1: logits = array_ops.expand_dims(logits, axis=1) logits.get_shape().assert_is_compatible_with(labels.get_shape()) - loss = math_ops.square(logits - math_ops.to_float(labels), name=name) + loss = 
math_ops.squared_difference( + logits, math_ops.to_float(labels), name=name) return _compute_weighted_loss(loss, weights) diff --git a/tensorflow/contrib/learn/python/learn/learn_io/generator_io_test.py b/tensorflow/contrib/learn/python/learn/learn_io/generator_io_test.py index 5e90d1fa20535de3b5e25bc7ff8c3862cea5514c..318046733bf75a6d661d26f478118c8e944afe15 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/generator_io_test.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/generator_io_test.py @@ -174,7 +174,7 @@ class GeneratorIoTest(test.TestCase): return np.arange(32, 36) with self.cached_session(): - with self.assertRaisesRegexp(TypeError, 'x\(\) must be generator'): + with self.assertRaisesRegexp(TypeError, r'x\(\) must be generator'): failing_input_fn = generator_io.generator_input_fn( generator, batch_size=2, shuffle=False, num_epochs=1) failing_input_fn() @@ -185,7 +185,7 @@ class GeneratorIoTest(test.TestCase): yield np.arange(32, 36) with self.cached_session(): - with self.assertRaisesRegexp(TypeError, 'x\(\) must yield dict'): + with self.assertRaisesRegexp(TypeError, r'x\(\) must yield dict'): failing_input_fn = generator_io.generator_input_fn( generator, batch_size=2, shuffle=False, num_epochs=1) failing_input_fn() diff --git a/tensorflow/contrib/learn/python/learn/utils/gc_test.py b/tensorflow/contrib/learn/python/learn/utils/gc_test.py index e7d091e18a8f186f89f5217442c24fb106c5cdab..af93e517f51ed33a8968982945ac1f65ec915ab1 100644 --- a/tensorflow/contrib/learn/python/learn/utils/gc_test.py +++ b/tensorflow/contrib/learn/python/learn/utils/gc_test.py @@ -36,10 +36,10 @@ def _create_parser(base_dir): # Modify the path object for RegEx match for Windows Paths if os.name == "nt": match = re.match( - "^" + compat.as_str_any(base_dir).replace("\\", "/") + "/(\\d+)$", + r"^" + compat.as_str_any(base_dir).replace("\\", "/") + r"/(\d+)$", compat.as_str_any(path.path).replace("\\", "/")) else: - match = re.match("^" + 
compat.as_str_any(base_dir) + "/(\\d+)$", + match = re.match(r"^" + compat.as_str_any(base_dir) + r"/(\d+)$", compat.as_str_any(path.path)) if not match: return None diff --git a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py index 8466dc36d13e223aed4f1dfe8e39a6f91c99fa55..d49834dc860a8b4341ddd3720fde52281f7474f7 100644 --- a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py +++ b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for SdcaModel.""" +"""Tests for SdcaModel (deprecated). + +This module and all its submodules are deprecated. To UPDATE or USE linear +optimizers, please check its latest version in core: +tensorflow_estimator/python/estimator/canned/linear_optimizer/. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py index f3f1dcd98db5ae24af154d1f0851a0688d2bc611..c056a12fa5307a7e9ac4cf30e1386ddfd5cd7d75 100644 --- a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py +++ b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py @@ -12,7 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Proximal stochastic dual coordinate ascent optimizer for linear models.""" +# pylint: disable=line-too-long +"""Proximal stochastic dual coordinate ascent optimizer for linear models (deprecated). + +This module and all its submodules are deprecated. 
To UPDATE or USE linear +optimizers, please check its latest version in core: +tensorflow_estimator/python/estimator/canned/linear_optimizer/. +""" +# pylint: enable=line-too-long from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -40,6 +47,7 @@ from tensorflow.python.ops import variables as var_ops from tensorflow.python.ops.nn import log_poisson_loss from tensorflow.python.ops.nn import sigmoid_cross_entropy_with_logits from tensorflow.python.summary import summary +from tensorflow.python.util import deprecation __all__ = ['SdcaModel'] @@ -48,7 +56,7 @@ __all__ = ['SdcaModel'] class SdcaModel(object): """Stochastic dual coordinate ascent solver for linear models. - Loss functions supported: + Loss functions supported: * Binary logistic loss * Squared loss @@ -109,6 +117,10 @@ class SdcaModel(object): ``` """ + @deprecation.deprecated( + None, 'This class is deprecated. To UPDATE or USE linear optimizers, ' + 'please check its latest version in core: ' + 'tensorflow_estimator/python/estimator/canned/linear_optimizer/.') def __init__(self, examples, variables, options): """Create a new sdca optimizer.""" diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py index a001555e8f257c88a52fdb40d4181f5cd9c92e84..a28394964a12013c43d85701b5a0ab5c559afd62 100644 --- a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py +++ b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Sharded mutable dense hash table.""" +"""Sharded mutable dense hash table (deprecated). + +This module and all its submodules are deprecated. 
To UPDATE or USE linear +optimizers, please check its latest version in core: +tensorflow_estimator/python/estimator/canned/linear_optimizer/. +""" from __future__ import absolute_import from __future__ import division @@ -28,6 +33,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import math_ops +from tensorflow.python.util import deprecation # TODO(rohanj): This should subclass Checkpointable and implement @@ -45,6 +51,10 @@ class ShardedMutableDenseHashTable(object): # TODO(andreasst): consider moving this to lookup module + @deprecation.deprecated( + None, 'This class is deprecated. To UPDATE or USE linear optimizers, ' + 'please check its latest version in core: ' + 'tensorflow_estimator/python/estimator/canned/linear_optimizer/.') def __init__(self, key_dtype, value_dtype, diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py index 2b56d0fa3a8b8564b7c73a62bd99cc900d6f5c54..2d1457f9e4cc576da696be191e718814dd9ff4e5 100644 --- a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py +++ b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for sharded_mutable_dense_hashtable.py.""" +"""Tests for sharded_mutable_dense_hashtable.py (deprecated). + +This module and all its submodules are deprecated. To UPDATE or USE linear +optimizers, please check its latest version in core: +tensorflow_estimator/python/estimator/canned/linear_optimizer/. 
+""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column.py b/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column.py index 003795233ff2b28e33fc10388ef25efb63c43bb0..64730f8eed1ff9bfcd4a980dceb28abb98e39f73 100644 --- a/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column.py +++ b/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Sparse feature column.""" +"""Sparse feature column (deprecated). + +This module and all its submodules are deprecated. To UPDATE or USE linear +optimizers, please check its latest version in core: +tensorflow_estimator/python/estimator/canned/linear_optimizer/. +""" from __future__ import absolute_import from __future__ import division @@ -21,6 +26,7 @@ from __future__ import print_function from tensorflow.python.framework import dtypes from tensorflow.python.framework.ops import internal_convert_to_tensor from tensorflow.python.framework.ops import name_scope +from tensorflow.python.util import deprecation class SparseFeatureColumn(object): @@ -68,6 +74,10 @@ class SparseFeatureColumn(object): @@feature_values """ + @deprecation.deprecated( + None, 'This class is deprecated. To UPDATE or USE linear optimizers, ' + 'please check its latest version in core: ' + 'tensorflow_estimator/python/estimator/canned/linear_optimizer/.') def __init__(self, example_indices, feature_indices, feature_values): """Creates a `SparseFeatureColumn` representation. 
diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column_test.py b/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column_test.py index 51c4f68543da2f563481cc2d35b556796616cf9d..0ae780e1a100c7dadde7196803f2ae0d4bcb2334 100644 --- a/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column_test.py +++ b/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column_test.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for sparse_feature_column.py.""" +"""Tests for sparse_feature_column.py (deprecated). + +This module and all its submodules are deprecated. To UPDATE or USE linear +optimizers, please check its latest version in core: +tensorflow_estimator/python/estimator/canned/linear_optimizer/. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/lookup/lookup_ops.py b/tensorflow/contrib/lookup/lookup_ops.py index e52fb5ab1431e086f99b4033a6216636a83bad79..3d21fb68a1452c97f7eb85491fc850d9e846266a 100644 --- a/tensorflow/contrib/lookup/lookup_ops.py +++ b/tensorflow/contrib/lookup/lookup_ops.py @@ -18,8 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import functools - from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -28,7 +26,6 @@ from tensorflow.python.ops import lookup_ops # pylint: disable=unused-import from tensorflow.python.ops.lookup_ops import FastHashSpec from tensorflow.python.ops.lookup_ops import HasherSpec -from tensorflow.python.ops.lookup_ops import HashTable from tensorflow.python.ops.lookup_ops import IdTableWithHashBuckets from tensorflow.python.ops.lookup_ops import index_table_from_file from 
tensorflow.python.ops.lookup_ops import index_to_string_table_from_file @@ -42,7 +39,6 @@ from tensorflow.python.ops.lookup_ops import TextFileIndex from tensorflow.python.ops.lookup_ops import TextFileInitializer from tensorflow.python.ops.lookup_ops import TextFileStringTableInitializer # pylint: enable=unused-import -from tensorflow.python.training.saver import BaseSaverBuilder from tensorflow.python.util.deprecation import deprecated @@ -91,7 +87,7 @@ def index_table_from_tensor(mapping, The bucket ID range is `[mapping size, mapping size + num_oov_buckets - 1]`. The underlying table must be initialized by calling - `tf.tables_initializer.run()` or `table.initializer.run()` once. + `session.run(tf.tables_initializer)` or `session.run(table.init)` once. Elements in `mapping` cannot have duplicates, otherwise when executing the table initializer op, it will throw a `FailedPreconditionError`. @@ -158,7 +154,7 @@ def string_to_index(tensor, mapping, default_value=-1, name=None): will throw a FailedPreconditionError. The underlying table must be initialized by calling - `tf.tables_initializer.run()` once. + `session.run(tf.tables_initializer)` once. For example: @@ -202,7 +198,7 @@ def index_to_string_table_from_tensor(mapping, default_value="UNK", name=None): (an out-of-vocabulary entry) is assigned the `default_value` The underlying table must be initialized by calling - `tf.tables_initializer.run()` or `table.initializer.run()` once. + `session.run(tf.tables_initializer)` or `session.run(table.init)` once. Elements in `mapping` cannot have duplicates, otherwise when executing the table initializer op, it will throw a `FailedPreconditionError`. @@ -257,7 +253,7 @@ def index_to_string(tensor, mapping, default_value="UNK", name=None): (an out-of-vocabulary entry) is assigned the `default_value` The underlying table must be initialized by calling - `tf.tables_initializer.run()` once. + `session.run(tf.tables_initializer)` once. 
For example: @@ -288,353 +284,52 @@ def index_to_string(tensor, mapping, default_value="UNK", name=None): return table.lookup(tensor) -class MutableHashTable(LookupInterface): - """A generic mutable hash table implementation. - - Data can be inserted by calling the insert method and removed by calling the - remove method. It does not support initialization via the init method. +class HashTable(InitializableLookupTableBase): + """A generic hash table implementation. Example usage: ```python - table = tf.contrib.lookup.MutableHashTable(key_dtype=tf.string, - value_dtype=tf.int64, - default_value=-1) - sess.run(table.insert(keys, values)) - out = table.lookup(query_keys) + table = tf.HashTable( + tf.KeyValueTensorInitializer(keys, values), -1) + out = table.lookup(input_tensor) + table.init.run() print(out.eval()) ``` """ - def __init__(self, - key_dtype, - value_dtype, - default_value, - shared_name=None, - name="MutableHashTable", - checkpoint=True): - """Creates an empty `MutableHashTable` object. + def __init__(self, initializer, default_value, shared_name=None, name=None): + """Creates a non-initialized `HashTable` object. - Creates a table, the type of its keys and values are specified by key_dtype - and value_dtype, respectively. + Creates a table, the type of its keys and values are specified by the + initializer. + Before using the table you will have to initialize it. After initialization + the table will be immutable. Args: - key_dtype: the type of the key tensors. - value_dtype: the type of the value tensors. + initializer: The table initializer to use. See `HashTable` kernel for + supported key and value types. default_value: The value to use if a key is missing in the table. - shared_name: If non-empty, this table will be shared under - the given name across multiple sessions. + shared_name: If non-empty, this table will be shared under the given name + across multiple sessions. name: A name for the operation (optional). 
- checkpoint: if True, the contents of the table are saved to and restored - from checkpoints. If `shared_name` is empty for a checkpointed table, it - is shared using the table node name. Returns: - A `MutableHashTable` object. - - Raises: - ValueError: If checkpoint is True and no name was specified. + A `HashTable` object. """ - self._default_value = ops.convert_to_tensor(default_value, - dtype=value_dtype) - self._value_shape = self._default_value.get_shape() - self._checkpoint = checkpoint - self._key_dtype = key_dtype - self._value_dtype = value_dtype - self._name = name - - if context.executing_eagerly() and shared_name is None: - # TODO(allenl): This will leak memory due to kernel caching by the - # shared_name attribute value (but is better than the alternative of - # sharing everything by default when executing eagerly; hopefully creating - # tables in a loop is uncommon). - shared_name = "table_%d" % (ops.uid(),) + self._initializer = initializer + self._default_value = default_value self._shared_name = shared_name - super(MutableHashTable, self).__init__(key_dtype, value_dtype) - - self._resource_handle = self.create_resource() - if checkpoint: - saveable = MutableHashTable._Saveable(self, name) - if not context.executing_eagerly(): - ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) - - def create_resource(self): - # The table must be shared if checkpointing is requested for multi-worker - # training to work correctly. Use the node name if no shared_name has been - # explicitly specified. 
- use_node_name_sharing = self._checkpoint and self._shared_name is None - if self._default_value.get_shape().ndims == 0: - table_ref = gen_lookup_ops.mutable_hash_table_v2( - shared_name=self._shared_name, - use_node_name_sharing=use_node_name_sharing, - key_dtype=self._key_dtype, - value_dtype=self._value_dtype, - name=self._name) - else: - table_ref = gen_lookup_ops.mutable_hash_table_of_tensors_v2( - shared_name=self._shared_name, - use_node_name_sharing=use_node_name_sharing, - key_dtype=self._key_dtype, - value_dtype=self._value_dtype, - value_shape=self._default_value.get_shape(), - name=self._name) - - if context.executing_eagerly(): - self._table_name = None - else: - self._table_name = table_ref.op.name.split("/")[-1] - return table_ref - - @property - def name(self): - return self._table_name - - def size(self, name=None): - """Compute the number of elements in this table. - - Args: - name: A name for the operation (optional). - - Returns: - A scalar tensor containing the number of elements in this table. - """ - with ops.name_scope(name, "%s_Size" % self.name, - [self.resource_handle]) as name: - with ops.colocate_with(self.resource_handle): - return gen_lookup_ops.lookup_table_size_v2( - self.resource_handle, name=name) - - def remove(self, keys, name=None): - """Removes `keys` and its associated values from the table. - - If a key is not present in the table, it is silently ignored. - - Args: - keys: Keys to remove. Can be a tensor of any shape. Must match the table's - key type. - name: A name for the operation (optional). - - Returns: - The created Operation. - - Raises: - TypeError: when `keys` do not match the table data types. - """ - if keys.dtype != self._key_dtype: - raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." 
% - (self._key_dtype, keys.dtype)) - - with ops.name_scope( - name, "%s_lookup_table_remove" % self.name, - (self.resource_handle, keys, self._default_value)) as name: - # pylint: disable=protected-access - op = gen_lookup_ops.lookup_table_remove_v2( - self.resource_handle, keys, name=name) - - return op - - def lookup(self, keys, name=None): - """Looks up `keys` in a table, outputs the corresponding values. - - The `default_value` is used for keys not present in the table. - - Args: - keys: Keys to look up. Can be a tensor of any shape. Must match the - table's key_dtype. - name: A name for the operation (optional). - - Returns: - A tensor containing the values in the same shape as `keys` using the - table's value type. - - Raises: - TypeError: when `keys` do not match the table data types. - """ - with ops.name_scope( - name, "%s_lookup_table_find" % self.name, - (self.resource_handle, keys, self._default_value)) as name: - keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys") - with ops.colocate_with(self.resource_handle): - values = gen_lookup_ops.lookup_table_find_v2( - self.resource_handle, keys, self._default_value, name=name) - return values - - def insert(self, keys, values, name=None): - """Associates `keys` with `values`. - - Args: - keys: Keys to insert. Can be a tensor of any shape. Must match the - table's key type. - values: Values to be associated with keys. Must be a tensor of the same - shape as `keys` and match the table's value type. - name: A name for the operation (optional). - - Returns: - The created Operation. - - Raises: - TypeError: when `keys` or `values` doesn't match the table data - types. 
- """ - with ops.name_scope(name, "%s_lookup_table_insert" % self.name, - [self.resource_handle, keys, values]) as name: - keys = ops.convert_to_tensor(keys, self._key_dtype, name="keys") - values = ops.convert_to_tensor(values, self._value_dtype, name="values") - with ops.colocate_with(self.resource_handle): - # pylint: disable=protected-access - op = gen_lookup_ops.lookup_table_insert_v2( - self.resource_handle, keys, values, name=name) - return op - - def export(self, name=None): - """Returns tensors of all keys and values in the table. - - Args: - name: A name for the operation (optional). - - Returns: - A pair of tensors with the first tensor containing all keys and the - second tensors containing all values in the table. - """ - with ops.name_scope(name, "%s_lookup_table_export_values" % self.name, - [self.resource_handle]) as name: - with ops.colocate_with(self.resource_handle): - exported_keys, exported_values = gen_lookup_ops.lookup_table_export_v2( - self.resource_handle, self._key_dtype, self._value_dtype, name=name) - return exported_keys, exported_values - - def _gather_saveables_for_checkpoint(self): - """For object-based checkpointing.""" - return {"table": functools.partial(MutableHashTable._Saveable, table=self)} - - class _Saveable(BaseSaverBuilder.SaveableObject): - """SaveableObject implementation for MutableHashTable.""" - - def __init__(self, table, name): - tensors = table.export() - specs = [ - BaseSaverBuilder.SaveSpec(tensors[0], "", name + "-keys"), - BaseSaverBuilder.SaveSpec(tensors[1], "", name + "-values") - ] - # pylint: disable=protected-access - super(MutableHashTable._Saveable, self).__init__(table, specs, name) - - def restore(self, restored_tensors, restored_shapes): - del restored_shapes # unused - # pylint: disable=protected-access - with ops.colocate_with(self.op.resource_handle): - return gen_lookup_ops.lookup_table_import_v2( - self.op.resource_handle, restored_tensors[0], restored_tensors[1]) - - -class 
MutableDenseHashTable(LookupInterface): - """A generic mutable hash table implementation using tensors as backing store. - - Data can be inserted by calling the insert method and removed by calling the - remove method. It does not support initialization via the init method. - - It uses "open addressing" with quadratic reprobing to resolve collisions. - Compared to `MutableHashTable` the insert, remove and lookup operations in a - `MutableDenseHashTable` are typically faster, but memory usage can be higher. - However, `MutableDenseHashTable` does not require additional memory for - temporary tensors created during checkpointing and restore operations. - - Example usage: - - ```python - table = tf.contrib.lookup.MutableDenseHashTable(key_dtype=tf.int64, - value_dtype=tf.int64, - default_value=-1, - empty_key=0, - deleted_key=-1) - - sess.run(table.insert(keys, values)) - out = table.lookup(query_keys) - print(out.eval()) - ``` - """ - - # TODO(andreasst): consider extracting common code with MutableHashTable into - # a common superclass. - def __init__(self, - key_dtype, - value_dtype, - default_value, - empty_key, - deleted_key, - initial_num_buckets=None, - shared_name=None, - name="MutableDenseHashTable", - checkpoint=True): - """Creates an empty `MutableDenseHashTable` object. - - Creates a table, the type of its keys and values are specified by key_dtype - and value_dtype, respectively. - - Args: - key_dtype: the type of the key tensors. - value_dtype: the type of the value tensors. - default_value: The value to use if a key is missing in the table. - empty_key: the key to use to represent empty buckets internally. Must not - be used in insert, remove or lookup operations. - initial_num_buckets: the initial number of buckets. - shared_name: If non-empty, this table will be shared under - the given name across multiple sessions. - name: A name for the operation (optional). 
- checkpoint: if True, the contents of the table are saved to and restored - from checkpoints. If `shared_name` is empty for a checkpointed table, it - is shared using the table node name. - deleted_key: the key to use to represent deleted buckets internally. Must - not be used in insert, remove or lookup operations and be different from - the empty_key. - - Returns: - A `MutableDenseHashTable` object. - - Raises: - ValueError: If checkpoint is True and no name was specified. - """ - self._default_value = ops.convert_to_tensor( - default_value, dtype=value_dtype, name="default_value") - self._key_dtype = key_dtype - self._value_dtype = value_dtype - self._initial_num_buckets = initial_num_buckets + self._name = name or "hash_table" + self._table_name = None + super(HashTable, self).__init__(default_value, initializer) self._value_shape = self._default_value.get_shape() - self._checkpoint = checkpoint - self._name = name - - self._empty_key = ops.convert_to_tensor( - empty_key, dtype=key_dtype, name="empty_key") - self._deleted_key = ops.convert_to_tensor( - deleted_key, dtype=key_dtype, name="deleted_key") - if context.executing_eagerly() and shared_name is None: - # TODO(allenl): This will leak memory due to kernel caching by the - # shared_name attribute value (but is better than the alternative of - # sharing everything by default when executing eagerly; hopefully creating - # tables in a loop is uncommon). - shared_name = "table_%d" % (ops.uid(),) - self._shared_name = shared_name - super(MutableDenseHashTable, self).__init__(key_dtype, value_dtype) - - self._resource_handle = self.create_resource() - if checkpoint: - saveable = MutableDenseHashTable._Saveable(self, name) - if not context.executing_eagerly(): - ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) def create_resource(self): - # The table must be shared if checkpointing is requested for multi-worker - # training to work correctly. 
Use the node name if no shared_name has been - # explicitly specified. - use_node_name_sharing = self._checkpoint and self._shared_name is None - table_ref = gen_lookup_ops.mutable_dense_hash_table_v2( - empty_key=self._empty_key, - deleted_key=self._deleted_key, + table_ref = gen_lookup_ops.hash_table_v2( shared_name=self._shared_name, - use_node_name_sharing=use_node_name_sharing, - value_dtype=self._value_dtype, - value_shape=self._value_shape, - initial_num_buckets=self._initial_num_buckets, + key_dtype=self._initializer.key_dtype, + value_dtype=self._initializer.value_dtype, name=self._name) if context.executing_eagerly(): self._table_name = None @@ -646,103 +341,6 @@ class MutableDenseHashTable(LookupInterface): def name(self): return self._table_name - def size(self, name=None): - """Compute the number of elements in this table. - - Args: - name: A name for the operation (optional). - - Returns: - A scalar tensor containing the number of elements in this table. - """ - with ops.name_scope(name, "%s_Size" % self.name, - [self.resource_handle]) as name: - with ops.colocate_with(self.resource_handle): - return gen_lookup_ops.lookup_table_size_v2( - self.resource_handle, name=name) - - def lookup(self, keys, name=None): - """Looks up `keys` in a table, outputs the corresponding values. - - The `default_value` is used for keys not present in the table. - - Args: - keys: Keys to look up. Can be a tensor of any shape. Must match the - table's key_dtype. - name: A name for the operation (optional). - - Returns: - A tensor containing the values in the same shape as `keys` using the - table's value type. - - Raises: - TypeError: when `keys` do not match the table data types. 
- """ - with ops.name_scope(name, "%s_lookup_table_find" % self.name, - [self.resource_handle, keys]) as name: - keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys") - with ops.colocate_with(self.resource_handle): - values = gen_lookup_ops.lookup_table_find_v2( - self.resource_handle, keys, self._default_value, name=name) - - return values - - def insert(self, keys, values, name=None): - """Associates `keys` with `values`. - - Args: - keys: Keys to insert. Can be a tensor of any shape. Must match the - table's key type. - values: Values to be associated with keys. Must be a tensor of the same - shape as `keys` and match the table's value type. - name: A name for the operation (optional). - - Returns: - The created Operation. - - Raises: - TypeError: when `keys` or `values` doesn't match the table data - types. - """ - with ops.name_scope(name, "%s_lookup_table_insert" % self.name, - [self.resource_handle, keys, values]) as name: - keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys") - values = ops.convert_to_tensor( - values, dtype=self._value_dtype, name="values") - with ops.colocate_with(self.resource_handle): - op = gen_lookup_ops.lookup_table_insert_v2( - self.resource_handle, keys, values, name=name) - return op - - def remove(self, keys, name=None): - """Removes `keys` and its associated values from the table. - - If a key is not present in the table, it is silently ignored. - - Args: - keys: Keys to remove. Can be a tensor of any shape. Must match the table's - key type. - name: A name for the operation (optional). - - Returns: - The created Operation. - - Raises: - TypeError: when `keys` do not match the table data types. - """ - if keys.dtype != self._key_dtype: - raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." 
% - (self._key_dtype, keys.dtype)) - - with ops.name_scope( - name, "%s_lookup_table_remove" % self.name, - (self.resource_handle, keys, self._default_value)) as name: - # pylint: disable=protected-access - op = gen_lookup_ops.lookup_table_remove_v2( - self.resource_handle, keys, name=name) - - return op - def export(self, name=None): """Returns tensors of all keys and values in the table. @@ -753,34 +351,15 @@ class MutableDenseHashTable(LookupInterface): A pair of tensors with the first tensor containing all keys and the second tensors containing all values in the table. """ - with ops.name_scope(name, "%s_lookup_table_export_values" % self.name, + with ops.name_scope(name, "%s_Export" % self.name, [self.resource_handle]) as name: - with ops.colocate_with(self.resource_handle): - exported_keys, exported_values = gen_lookup_ops.lookup_table_export_v2( - self.resource_handle, self._key_dtype, self._value_dtype, name=name) + exported_keys, exported_values = gen_lookup_ops.lookup_table_export_v2( + self.resource_handle, self._key_dtype, self._value_dtype, name=name) + exported_values.set_shape(exported_keys.get_shape().concatenate( + self._value_shape)) return exported_keys, exported_values - def _gather_saveables_for_checkpoint(self): - """For object-based checkpointing.""" - return {"table": functools.partial( - MutableDenseHashTable._Saveable, table=self)} - - class _Saveable(BaseSaverBuilder.SaveableObject): - """SaveableObject implementation for MutableDenseHashTable.""" - - def __init__(self, table, name): - tensors = table.export() - specs = [ - BaseSaverBuilder.SaveSpec(tensors[0], "", name + "-keys"), - BaseSaverBuilder.SaveSpec(tensors[1], "", name + "-values") - ] - # pylint: disable=protected-access - super(MutableDenseHashTable._Saveable, self).__init__(table, specs, name) - - def restore(self, restored_tensors, restored_shapes): - del restored_shapes # unused - # pylint: disable=protected-access - with ops.colocate_with(self.op.resource_handle): - 
return gen_lookup_ops.lookup_table_import_v2( - self.op.resource_handle, restored_tensors[0], restored_tensors[1]) + +MutableHashTable = lookup_ops.MutableHashTable +MutableDenseHashTable = lookup_ops.MutableDenseHashTable diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py index 5e99ef460518fa761b12533e5dc07dc252f1d582..591eabc66c49f301cf73cd912ebbef70cc9e1e3f 100644 --- a/tensorflow/contrib/lookup/lookup_ops_test.py +++ b/tensorflow/contrib/lookup/lookup_ops_test.py @@ -18,13 +18,10 @@ from __future__ import division from __future__ import print_function import os -import tempfile import numpy as np -import six from tensorflow.contrib import lookup from tensorflow.python.client import session -from tensorflow.python.data.experimental.ops import counter from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -36,9 +33,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test -from tensorflow.python.training import saver from tensorflow.python.training import server_lib -from tensorflow.python.training.checkpointable import util as checkpointable class HashTableOpTest(test.TestCase): @@ -298,1240 +293,6 @@ class HashTableOpTest(test.TestCase): self.assertAllEqual([b"brain", b"salad", b"n/a"], result) -class MutableHashTableOpTest(test.TestCase): - - def testMutableHashTable(self): - with self.cached_session(): - default_val = -1 - keys = constant_op.constant(["brain", "salad", "surgery", "tarkus"]) - values = constant_op.constant([0, 1, 2, 3], dtypes.int64) - table = lookup.MutableHashTable(dtypes.string, dtypes.int64, - default_val) - self.assertAllEqual(0, table.size().eval()) - - table.insert(keys, values).run() - self.assertAllEqual(4, table.size().eval()) - - remove_string = 
constant_op.constant(["tarkus", "tank"]) - table.remove(remove_string).run() - self.assertAllEqual(3, table.size().eval()) - - input_string = constant_op.constant(["brain", "salad", "tank"]) - output = table.lookup(input_string) - self.assertAllEqual([3], output.get_shape()) - - result = output.eval() - self.assertAllEqual([0, 1, -1], result) - - exported_keys, exported_values = table.export() - self.assertAllEqual([None], exported_keys.get_shape().as_list()) - self.assertAllEqual([None], exported_values.get_shape().as_list()) - - # exported data is in the order of the internal map, i.e. undefined - sorted_keys = np.sort(exported_keys.eval()) - sorted_values = np.sort(exported_values.eval()) - self.assertAllEqual([b"brain", b"salad", b"surgery"], sorted_keys) - self.assertAllEqual([0, 1, 2], sorted_values) - - def testSaveRestore(self): - save_dir = os.path.join(self.get_temp_dir(), "save_restore") - save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash") - - with self.session(graph=ops.Graph()) as sess: - v0 = variables.Variable(10.0, name="v0") - v1 = variables.Variable(20.0, name="v1") - - default_val = -1 - keys = constant_op.constant(["b", "c", "d"], dtypes.string) - values = constant_op.constant([0, 1, 2], dtypes.int64) - table = lookup.MutableHashTable( - dtypes.string, dtypes.int64, default_val, name="t1", checkpoint=True) - - save = saver.Saver() - variables.global_variables_initializer().run() - - # Check that the parameter nodes have been initialized. 
- self.assertEqual(10.0, v0.eval()) - self.assertEqual(20.0, v1.eval()) - - self.assertAllEqual(0, table.size().eval()) - table.insert(keys, values).run() - self.assertAllEqual(3, table.size().eval()) - - val = save.save(sess, save_path) - self.assertTrue(isinstance(val, six.string_types)) - self.assertEqual(save_path, val) - - with self.session(graph=ops.Graph()) as sess: - v0 = variables.Variable(-1.0, name="v0") - v1 = variables.Variable(-1.0, name="v1") - default_val = -1 - table = lookup.MutableHashTable( - dtypes.string, dtypes.int64, default_val, name="t1", checkpoint=True) - table.insert( - constant_op.constant(["a", "c"], dtypes.string), - constant_op.constant([12, 24], dtypes.int64)).run() - self.assertAllEqual(2, table.size().eval()) - - save = saver.Saver() - - # Restore the saved values in the parameter nodes. - save.restore(sess, save_path) - # Check that the parameter nodes have been restored. - self.assertEqual(10.0, v0.eval()) - self.assertEqual(20.0, v1.eval()) - - self.assertAllEqual(3, table.size().eval()) - - input_string = constant_op.constant(["a", "b", "c", "d", "e"], - dtypes.string) - output = table.lookup(input_string) - self.assertAllEqual([-1, 0, 1, 2, -1], output.eval()) - - @test_util.run_in_graph_and_eager_modes - def testObjectSaveRestore(self): - save_dir = os.path.join(self.get_temp_dir(), "save_restore") - save_prefix = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash") - - v0 = variables.Variable(10.0, name="v0") - v1 = variables.Variable(20.0, name="v1") - - default_val = -1 - keys = constant_op.constant(["b", "c", "d"], dtypes.string) - values = constant_op.constant([0, 1, 2], dtypes.int64) - table = lookup.MutableHashTable( - dtypes.string, dtypes.int64, default_val, name="t1", checkpoint=True) - - checkpoint = checkpointable.Checkpoint(table=table, v0=v0, v1=v1) - self.evaluate([v0.initializer, v1.initializer]) - - # Check that the parameter nodes have been initialized. 
- self.assertEqual(10.0, self.evaluate(v0)) - self.assertEqual(20.0, self.evaluate(v1)) - - self.assertAllEqual(0, self.evaluate(table.size())) - self.evaluate(table.insert(keys, values)) - self.assertAllEqual(3, self.evaluate(table.size())) - - save_path = checkpoint.save(save_prefix) - del table, checkpoint, v0, v1 - - v0 = variables.Variable(-1.0, name="v0") - v1 = variables.Variable(-1.0, name="v1") - default_val = -1 - table = lookup.MutableHashTable( - dtypes.string, dtypes.int64, default_val, name="t1", checkpoint=True) - self.evaluate(table.insert( - constant_op.constant(["a", "c"], dtypes.string), - constant_op.constant([12, 24], dtypes.int64))) - self.assertAllEqual(2, self.evaluate(table.size())) - - checkpoint = checkpointable.Checkpoint(table=table, v0=v0, v1=v1) - - # Restore the saved values in the parameter nodes. - checkpoint.restore(save_path).run_restore_ops() - # Check that the parameter nodes have been restored. - self.assertEqual(10.0, self.evaluate(v0)) - self.assertEqual(20.0, self.evaluate(v1)) - - self.assertAllEqual(3, self.evaluate(table.size())) - - input_string = constant_op.constant(["a", "b", "c", "d", "e"], - dtypes.string) - output = table.lookup(input_string) - self.assertAllEqual([-1, 0, 1, 2, -1], self.evaluate(output)) - - def testSharing(self): - # Start a server to store the table state - server = server_lib.Server( - { - "local0": ["localhost:0"] - }, protocol="grpc", start=True) - # Create two sessions sharing the same state - session1 = session.Session(server.target) - session2 = session.Session(server.target) - - table = lookup.MutableHashTable( - dtypes.int64, dtypes.string, "-", name="t1") - - # Populate the table in the first session - with session1: - self.assertAllEqual(0, table.size().eval()) - - keys = constant_op.constant([11, 12], dtypes.int64) - values = constant_op.constant(["a", "b"]) - table.insert(keys, values).run() - self.assertAllEqual(2, table.size().eval()) - - output = 
table.lookup(constant_op.constant([11, 12, 13], dtypes.int64)) - self.assertAllEqual([b"a", b"b", b"-"], output.eval()) - - # Verify that we can access the shared data from the second session - with session2: - self.assertAllEqual(2, table.size().eval()) - - output = table.lookup(constant_op.constant([10, 11, 12], dtypes.int64)) - self.assertAllEqual([b"-", b"a", b"b"], output.eval()) - - def testMutableHashTableOfTensors(self): - with self.cached_session(): - default_val = constant_op.constant([-1, -1], dtypes.int64) - keys = constant_op.constant(["brain", "salad", "surgery", "tarkus"]) - values = constant_op.constant([[0, 1], [2, 3], [4, 5], [6, 7]], - dtypes.int64) - table = lookup.MutableHashTable(dtypes.string, dtypes.int64, - default_val) - self.assertAllEqual(0, table.size().eval()) - - table.insert(keys, values).run() - self.assertAllEqual(4, table.size().eval()) - - remove_string = constant_op.constant(["tarkus", "tank"]) - table.remove(remove_string).run() - self.assertAllEqual(3, table.size().eval()) - - input_string = constant_op.constant(["brain", "salad", "tank"]) - output = table.lookup(input_string) - self.assertAllEqual([3, 2], output.get_shape()) - - result = output.eval() - self.assertAllEqual([[0, 1], [2, 3], [-1, -1]], result) - - exported_keys, exported_values = table.export() - self.assertAllEqual([None], exported_keys.get_shape().as_list(), - msg="Saw shape %s" % exported_keys.shape) - self.assertAllEqual([None, 2], exported_values.get_shape().as_list(), - msg="Saw shape %s" % exported_values.shape) - # exported data is in the order of the internal map, i.e. 
undefined - sorted_keys = np.sort(exported_keys.eval()) - sorted_values = np.sort(exported_values.eval()) - self.assertAllEqual([b"brain", b"salad", b"surgery"], sorted_keys) - self.assertAllEqual([[4, 5], [2, 3], [0, 1]], sorted_values) - - def testMutableHashTableExportInsert(self): - with self.cached_session(): - default_val = constant_op.constant([-1, -1], dtypes.int64) - keys = constant_op.constant(["brain", "salad", "surgery"]) - values = constant_op.constant([[0, 1], [2, 3], [4, 5]], dtypes.int64) - table1 = lookup.MutableHashTable(dtypes.string, dtypes.int64, - default_val) - self.assertAllEqual(0, table1.size().eval()) - table1.insert(keys, values).run() - self.assertAllEqual(3, table1.size().eval()) - - input_string = constant_op.constant(["brain", "salad", "tank"]) - expected_output = [[0, 1], [2, 3], [-1, -1]] - output1 = table1.lookup(input_string) - self.assertAllEqual(expected_output, output1.eval()) - - exported_keys, exported_values = table1.export() - self.assertAllEqual(3, exported_keys.eval().size) - self.assertAllEqual(6, exported_values.eval().size) - - # Populate a second table from the exported data - table2 = lookup.MutableHashTable(dtypes.string, dtypes.int64, - default_val) - self.assertAllEqual(0, table2.size().eval()) - table2.insert(exported_keys, exported_values).run() - self.assertAllEqual(3, table2.size().eval()) - - # Verify lookup result is still the same - output2 = table2.lookup(input_string) - self.assertAllEqual(expected_output, output2.eval()) - - def testMutableHashTableOfTensorsInvalidShape(self): - with self.cached_session(): - default_val = constant_op.constant([-1, -1], dtypes.int64) - keys = constant_op.constant(["brain", "salad", "surgery"]) - table = lookup.MutableHashTable(dtypes.string, dtypes.int64, - default_val) - - # Shape [6] instead of [3, 2] - values = constant_op.constant([0, 1, 2, 3, 4, 5], dtypes.int64) - with self.assertRaisesOpError("Expected shape"): - table.insert(keys, values).run() - - # Shape [2,3] 
instead of [3, 2] - values = constant_op.constant([[0, 1, 2], [3, 4, 5]], dtypes.int64) - with self.assertRaisesOpError("Expected shape"): - table.insert(keys, values).run() - - # Shape [2, 2] instead of [3, 2] - values = constant_op.constant([[0, 1], [2, 3]], dtypes.int64) - with self.assertRaisesOpError("Expected shape"): - table.insert(keys, values).run() - - # Shape [3, 1] instead of [3, 2] - values = constant_op.constant([[0], [2], [4]], dtypes.int64) - with self.assertRaisesOpError("Expected shape"): - table.insert(keys, values).run() - - # Valid Insert - values = constant_op.constant([[0, 1], [2, 3], [4, 5]], dtypes.int64) - table.insert(keys, values).run() - self.assertAllEqual(3, table.size().eval()) - - def testMutableHashTableInvalidDefaultValue(self): - with self.cached_session(): - default_val = constant_op.constant([[-1, -1]], dtypes.int64) - table = lookup.MutableHashTable(dtypes.string, dtypes.int64, - default_val) - with self.assertRaisesOpError("Default value must be a vector"): - self.assertAllEqual(0, table.size().eval()) - - def testMutableHashTableDuplicateInsert(self): - with self.cached_session(): - default_val = -1 - keys = constant_op.constant(["brain", "salad", "surgery", "brain"]) - values = constant_op.constant([0, 1, 2, 3], dtypes.int64) - table = lookup.MutableHashTable(dtypes.string, dtypes.int64, - default_val) - self.assertAllEqual(0, table.size().eval()) - - table.insert(keys, values).run() - self.assertAllEqual(3, table.size().eval()) - - input_string = constant_op.constant(["brain", "salad", "tank"]) - output = table.lookup(input_string) - - result = output.eval() - self.assertAllEqual([3, 1, -1], result) - - def testMutableHashTableFindHighRank(self): - with self.cached_session(): - default_val = -1 - keys = constant_op.constant(["brain", "salad", "surgery"]) - values = constant_op.constant([0, 1, 2], dtypes.int64) - table = lookup.MutableHashTable(dtypes.string, dtypes.int64, - default_val) - - table.insert(keys, values).run() 
- self.assertAllEqual(3, table.size().eval()) - - input_string = constant_op.constant( - [["brain", "salad"], ["tank", "tarkus"]]) - output = table.lookup(input_string) - self.assertAllEqual([2, 2], output.get_shape()) - - result = output.eval() - self.assertAllEqual([[0, 1], [-1, -1]], result) - - def testMutableHashTableInsertHighRank(self): - with self.cached_session(): - default_val = -1 - keys = constant_op.constant([["brain", "salad"], ["surgery", "tank"]]) - values = constant_op.constant([[0, 1], [2, 3]], dtypes.int64) - table = lookup.MutableHashTable(dtypes.string, dtypes.int64, - default_val) - - table.insert(keys, values).run() - self.assertAllEqual(4, table.size().eval()) - - input_string = constant_op.constant(["brain", "salad", "tank", "tarkus"]) - output = table.lookup(input_string) - - result = output.eval() - self.assertAllEqual([0, 1, 3, -1], result) - - def testMutableHashTableRemoveHighRank(self): - with self.test_session(): - default_val = -1 - keys = constant_op.constant([["brain", "salad"], ["surgery", "tank"]]) - values = constant_op.constant([[0, 1], [2, 3]], dtypes.int64) - table = lookup.MutableHashTable(dtypes.string, dtypes.int64, default_val) - - table.insert(keys, values).run() - self.assertAllEqual(4, table.size().eval()) - - remove_string = constant_op.constant(["salad", "tarkus"]) - table.remove(remove_string).run() - self.assertAllEqual(3, table.size().eval()) - - input_string = constant_op.constant(["brain", "salad", "tank", "tarkus"]) - output = table.lookup(input_string) - - result = output.eval() - self.assertAllEqual([0, -1, 3, -1], result) - - def testMutableHashTableOfTensorsFindHighRank(self): - with self.cached_session(): - default_val = constant_op.constant([-1, -1, -1], dtypes.int64) - keys = constant_op.constant(["brain", "salad", "surgery"]) - values = constant_op.constant([[0, 1, 2], [2, 3, 4], [4, 5, 6]], - dtypes.int64) - table = lookup.MutableHashTable(dtypes.string, dtypes.int64, - default_val) - - 
table.insert(keys, values).run() - self.assertAllEqual(3, table.size().eval()) - - input_string = constant_op.constant( - [["brain", "salad"], ["tank", "tarkus"]]) - output = table.lookup(input_string) - self.assertAllEqual([2, 2, 3], output.get_shape()) - - result = output.eval() - self.assertAllEqual( - [[[0, 1, 2], [2, 3, 4]], [[-1, -1, -1], [-1, -1, -1]]], result) - - def testMutableHashTableOfTensorsRemoveHighRank(self): - with self.test_session(): - default_val = constant_op.constant([-1, -1, -1], dtypes.int64) - keys = constant_op.constant(["brain", "salad", "surgery"]) - values = constant_op.constant([[0, 1, 2], [2, 3, 4], [4, 5, 6]], - dtypes.int64) - table = lookup.MutableHashTable(dtypes.string, dtypes.int64, default_val) - - table.insert(keys, values).run() - self.assertAllEqual(3, table.size().eval()) - - remove_string = constant_op.constant([["brain", "tank"]]) - table.remove(remove_string).run() - self.assertAllEqual(2, table.size().eval()) - - input_string = constant_op.constant([["brain", "salad"], - ["surgery", "tank"]]) - output = table.lookup(input_string) - self.assertAllEqual([2, 2, 3], output.get_shape()) - - result = output.eval() - self.assertAllEqual( - [[[-1, -1, -1], [2, 3, 4]], [[4, 5, 6], [-1, -1, -1]]], result) - - def testMultipleMutableHashTables(self): - with self.cached_session() as sess: - default_val = -1 - keys = constant_op.constant(["brain", "salad", "surgery"]) - values = constant_op.constant([0, 1, 2], dtypes.int64) - - table1 = lookup.MutableHashTable(dtypes.string, dtypes.int64, - default_val) - table2 = lookup.MutableHashTable(dtypes.string, dtypes.int64, - default_val) - table3 = lookup.MutableHashTable(dtypes.string, dtypes.int64, - default_val) - table1.insert(keys, values).run() - table2.insert(keys, values).run() - table3.insert(keys, values).run() - - self.assertAllEqual(3, table1.size().eval()) - self.assertAllEqual(3, table2.size().eval()) - self.assertAllEqual(3, table3.size().eval()) - - input_string = 
constant_op.constant(["brain", "salad", "tank"]) - output1 = table1.lookup(input_string) - output2 = table2.lookup(input_string) - output3 = table3.lookup(input_string) - - out1, out2, out3 = sess.run([output1, output2, output3]) - self.assertAllEqual([0, 1, -1], out1) - self.assertAllEqual([0, 1, -1], out2) - self.assertAllEqual([0, 1, -1], out3) - - def testMutableHashTableWithTensorDefault(self): - with self.cached_session(): - default_val = constant_op.constant(-1, dtypes.int64) - keys = constant_op.constant(["brain", "salad", "surgery"]) - values = constant_op.constant([0, 1, 2], dtypes.int64) - table = lookup.MutableHashTable(dtypes.string, dtypes.int64, - default_val) - - table.insert(keys, values).run() - self.assertAllEqual(3, table.size().eval()) - - input_string = constant_op.constant(["brain", "salad", "tank"]) - output = table.lookup(input_string) - - result = output.eval() - self.assertAllEqual([0, 1, -1], result) - - def testSignatureMismatch(self): - with self.cached_session(): - default_val = -1 - keys = constant_op.constant(["brain", "salad", "surgery"]) - values = constant_op.constant([0, 1, 2], dtypes.int64) - table = lookup.MutableHashTable(dtypes.string, dtypes.int64, - default_val) - - # insert with keys of the wrong type - with self.assertRaises(ValueError): - table.insert(constant_op.constant([4, 5, 6]), values).run() - - # insert with values of the wrong type - with self.assertRaises(ValueError): - table.insert(keys, constant_op.constant(["a", "b", "c"])).run() - - self.assertAllEqual(0, table.size().eval()) - - table.insert(keys, values).run() - self.assertAllEqual(3, table.size().eval()) - - input_string_ref = variables.Variable("brain") - input_int64_ref = variables.Variable(-1, dtype=dtypes.int64) - variables.global_variables_initializer().run() - - # Ref types do not produce an insert signature mismatch. 
- table.insert(input_string_ref, input_int64_ref).run() - self.assertAllEqual(3, table.size().eval()) - - # Ref types do not produce a lookup signature mismatch. - self.assertEqual(-1, table.lookup(input_string_ref).eval()) - - # lookup with keys of the wrong type - input_string = constant_op.constant([1, 2, 3], dtypes.int64) - with self.assertRaises(ValueError): - table.lookup(input_string).eval() - - # default value of the wrong type - with self.assertRaises(TypeError): - lookup.MutableHashTable(dtypes.string, dtypes.int64, "UNK") - - def testMutableHashTableStringFloat(self): - with self.cached_session(): - default_val = -1.5 - keys = constant_op.constant(["brain", "salad", "surgery"]) - values = constant_op.constant([0, 1.1, 2.2], dtypes.float32) - table = lookup.MutableHashTable(dtypes.string, dtypes.float32, - default_val) - self.assertAllEqual(0, table.size().eval()) - - table.insert(keys, values).run() - self.assertAllEqual(3, table.size().eval()) - - input_string = constant_op.constant(["brain", "salad", "tank"]) - output = table.lookup(input_string) - - result = output.eval() - self.assertAllClose([0, 1.1, default_val], result) - - def testMutableHashTableIntFloat(self): - with self.cached_session(): - default_val = -1.0 - keys = constant_op.constant([3, 7, 0], dtypes.int64) - values = constant_op.constant([7.5, -1.2, 9.9], dtypes.float32) - table = lookup.MutableHashTable(dtypes.int64, dtypes.float32, - default_val) - self.assertAllEqual(0, table.size().eval()) - - table.insert(keys, values).run() - self.assertAllEqual(3, table.size().eval()) - - input_string = constant_op.constant([7, 0, 11], dtypes.int64) - output = table.lookup(input_string) - - result = output.eval() - self.assertAllClose([-1.2, 9.9, default_val], result) - - def testMutableHashTableInt64String(self): - with self.cached_session(): - default_val = "n/a" - keys = constant_op.constant([0, 1, 2], dtypes.int64) - values = constant_op.constant(["brain", "salad", "surgery"]) - table = 
lookup.MutableHashTable(dtypes.int64, dtypes.string, - default_val) - self.assertAllEqual(0, table.size().eval()) - - table.insert(keys, values).run() - self.assertAllEqual(3, table.size().eval()) - - input_string = constant_op.constant([0, 1, 3], dtypes.int64) - output = table.lookup(input_string) - - result = output.eval() - self.assertAllEqual((b"brain", b"salad", b"n/a"), result) - - -class MutableDenseHashTableOpTest(test.TestCase): - - def testBasic(self): - with self.cached_session(): - - keys = constant_op.constant([11, 12, 13, 14], dtypes.int64) - values = constant_op.constant([0, 1, 2, 3], dtypes.int64) - table = lookup.MutableDenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=-1, - empty_key=0, - deleted_key=-1) - self.assertAllEqual(0, table.size().eval()) - - table.insert(keys, values).run() - self.assertAllEqual(4, table.size().eval()) - - remove_string = constant_op.constant([12, 15], dtypes.int64) - table.remove(remove_string).run() - self.assertAllEqual(3, table.size().eval()) - - input_string = constant_op.constant([11, 12, 15], dtypes.int64) - output = table.lookup(input_string) - self.assertAllEqual([3], output.get_shape()) - - result = output.eval() - self.assertAllEqual([0, -1, -1], result) - - def testBasicBool(self): - with self.cached_session(): - - keys = constant_op.constant([11, 12, 13, 14], dtypes.int64) - values = constant_op.constant([True, True, True, True], dtypes.bool) - table = lookup.MutableDenseHashTable( - dtypes.int64, - dtypes.bool, - default_value=False, - empty_key=0, - deleted_key=-1) - self.assertAllEqual(0, table.size().eval()) - - table.insert(keys, values).run() - self.assertAllEqual(4, table.size().eval()) - - remove_string = constant_op.constant([11, 15], dtypes.int64) - table.remove(remove_string).run() - self.assertAllEqual(3, table.size().eval()) - - input_string = constant_op.constant([11, 12, 15], dtypes.int64) - output = table.lookup(input_string) - self.assertAllEqual([3], output.get_shape()) - - 
result = output.eval() - self.assertAllEqual([False, True, False], result) - - def testSameEmptyAndDeletedKey(self): - with self.cached_session(): - with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, - "deleted_key"): - table = lookup.MutableDenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=-1, - empty_key=42, - deleted_key=42) - self.assertAllEqual(0, table.size().eval()) - - def testLookupUnknownShape(self): - with self.cached_session(): - keys = constant_op.constant([11, 12, 13], dtypes.int64) - values = constant_op.constant([0, 1, 2], dtypes.int64) - table = lookup.MutableDenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=-1, - empty_key=0, - deleted_key=-1) - - table.insert(keys, values).run() - self.assertAllEqual(3, table.size().eval()) - - placeholder_keys = array_ops.placeholder(dtypes.int64) - output = table.lookup(placeholder_keys) - self.assertAllEqual(None, output.get_shape()) - result = output.eval({placeholder_keys: [11, 12, 15]}) - self.assertAllEqual([0, 1, -1], result) - - def testMapStringToFloat(self): - with self.cached_session(): - - keys = constant_op.constant(["a", "b", "c", "d"], dtypes.string) - values = constant_op.constant([0.0, 1.1, 2.2, 3.3], dtypes.float32) - default_value = constant_op.constant(-1.5, dtypes.float32) - table = lookup.MutableDenseHashTable( - dtypes.string, - dtypes.float32, - default_value=default_value, - empty_key="", - deleted_key="$") - self.assertAllEqual(0, table.size().eval()) - - table.insert(keys, values).run() - self.assertAllEqual(4, table.size().eval()) - - remove_string = constant_op.constant(["b", "e"]) - table.remove(remove_string).run() - self.assertAllEqual(3, table.size().eval()) - - input_string = constant_op.constant(["a", "b", "d", "e"], dtypes.string) - output = table.lookup(input_string) - self.assertAllEqual([4], output.get_shape()) - - result = output.eval() - self.assertAllClose([0, -1.5, 3.3, -1.5], result) - - def testMapInt64ToFloat(self): - for 
float_dtype in [dtypes.float32, dtypes.float64]: - with self.cached_session(): - - keys = constant_op.constant([11, 12, 13, 14], dtypes.int64) - values = constant_op.constant([0.0, 1.1, 2.2, 3.3], float_dtype) - default_value = constant_op.constant(-1.5, float_dtype) - table = lookup.MutableDenseHashTable( - dtypes.int64, - float_dtype, - default_value=default_value, - empty_key=0, - deleted_key=-1) - self.assertAllEqual(0, table.size().eval()) - - table.insert(keys, values).run() - self.assertAllEqual(4, table.size().eval()) - - remove_string = constant_op.constant([12, 15], dtypes.int64) - table.remove(remove_string).run() - self.assertAllEqual(3, table.size().eval()) - - input_string = constant_op.constant([11, 12, 14, 15], dtypes.int64) - output = table.lookup(input_string) - self.assertAllEqual([4], output.get_shape()) - - result = output.eval() - self.assertAllClose([0, -1.5, 3.3, -1.5], result) - - def testVectorValues(self): - with self.cached_session(): - keys = constant_op.constant([11, 12, 13], dtypes.int64) - values = constant_op.constant([[0, 1, 2, 3], [3, 4, 5, 6], [6, 7, 8, 9]], - dtypes.int64) - default_value = constant_op.constant([-1, -2, -3, -4], dtypes.int64) - table = lookup.MutableDenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=default_value, - empty_key=0, - deleted_key=-1, - initial_num_buckets=4) - self.assertAllEqual(0, table.size().eval()) - - table.insert(keys, values).run() - self.assertAllEqual(3, table.size().eval()) - self.assertAllEqual(4, len(table.export()[0].eval())) - - table.insert( - constant_op.constant([14], dtypes.int64), - constant_op.constant([[2, 3, 4, 5]], dtypes.int64)).run() - self.assertAllEqual(4, table.size().eval()) - self.assertAllEqual(8, len(table.export()[0].eval())) - - remove_string = constant_op.constant([12, 16], dtypes.int64) - table.remove(remove_string).run() - self.assertAllEqual(3, table.size().eval()) - self.assertAllEqual(8, len(table.export()[0].eval())) - - input_string = 
constant_op.constant([11, 12, 14, 15], dtypes.int64) - output = table.lookup(input_string) - self.assertAllEqual([4, 4], - output.shape, - msg="Saw shape: %s" % output.shape) - - result = output.eval() - self.assertAllEqual( - [[0, 1, 2, 3], [-1, -2, -3, -4], [2, 3, 4, 5], [-1, -2, -3, -4]], - result) - - def testVectorKeys(self): - with self.cached_session(): - keys = constant_op.constant([[0, 1], [1, 2], [1, 3]], dtypes.int64) - values = constant_op.constant([10, 11, 12], dtypes.int64) - empty_key = constant_op.constant([0, 3], dtypes.int64) - deleted_key = constant_op.constant([-1, -1], dtypes.int64) - default_value = constant_op.constant(-1, dtypes.int64) - table = lookup.MutableDenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=default_value, - empty_key=empty_key, - deleted_key=deleted_key, - initial_num_buckets=8) - self.assertAllEqual(0, table.size().eval()) - - table.insert(keys, values).run() - self.assertAllEqual(3, table.size().eval()) - - table.insert( - constant_op.constant([[0, 0]], dtypes.int64), - constant_op.constant([13], dtypes.int64)).run() - self.assertAllEqual(4, table.size().eval()) - self.assertAllEqual(8, len(table.export()[0].eval())) - - remove_string = constant_op.constant([[1, 2], [7, 8]], dtypes.int64) - table.remove(remove_string).run() - self.assertAllEqual(3, table.size().eval()) - self.assertAllEqual(8, len(table.export()[0].eval())) - - input_string = constant_op.constant([[0, 1], [1, 2], [1, 3], [0, 2]], - dtypes.int64) - output = table.lookup(input_string) - self.assertAllEqual([4], output.get_shape()) - - result = output.eval() - self.assertAllEqual([10, -1, 12, -1], result) - - def testResize(self): - with self.cached_session(): - keys = constant_op.constant([11, 12, 13], dtypes.int64) - values = constant_op.constant([0, 1, 2], dtypes.int64) - table = lookup.MutableDenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=-1, - empty_key=0, - deleted_key=-1, - initial_num_buckets=4) - self.assertAllEqual(0, 
table.size().eval()) - - table.insert(keys, values).run() - self.assertAllEqual(3, table.size().eval()) - self.assertAllEqual(4, len(table.export()[0].eval())) - - keys2 = constant_op.constant([12, 99], dtypes.int64) - table.remove(keys2).run() - self.assertAllEqual(2, table.size().eval()) - self.assertAllEqual(4, len(table.export()[0].eval())) - - keys3 = constant_op.constant([13, 14, 15, 16, 17], dtypes.int64) - values3 = constant_op.constant([3, 4, 5, 6, 7], dtypes.int64) - - table.insert(keys3, values3).run() - self.assertAllEqual(6, table.size().eval()) - self.assertAllEqual(16, len(table.export()[0].eval())) - - keys4 = constant_op.constant([10, 11, 12, 13, 14, 15, 16, 17, 18], - dtypes.int64) - output = table.lookup(keys4) - self.assertAllEqual([-1, 0, -1, 3, 4, 5, 6, 7, -1], output.eval()) - - def testExport(self): - with self.cached_session(): - - keys = constant_op.constant([11, 12, 13, 14], dtypes.int64) - values = constant_op.constant([1, 2, 3, 4], dtypes.int64) - table = lookup.MutableDenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=-1, - empty_key=100, - deleted_key=200, - initial_num_buckets=8) - self.assertAllEqual(0, table.size().eval()) - - table.insert(keys, values).run() - self.assertAllEqual(4, table.size().eval()) - - keys2 = constant_op.constant([12, 15], dtypes.int64) - table.remove(keys2).run() - self.assertAllEqual(3, table.size().eval()) - - exported_keys, exported_values = table.export() - self.assertAllEqual([None], exported_keys.get_shape().as_list()) - self.assertAllEqual([None], exported_values.get_shape().as_list()) - - np_keys = exported_keys.eval() - np_values = exported_values.eval() - - self.assertAllEqual(8, len(np_keys)) - self.assertAllEqual(8, len(np_values)) - - # pair up keys and values, drop extra added dimension - pairs = np.dstack((np_keys.flatten(), np_values.flatten()))[0] - # sort by key - pairs = pairs[pairs[:, 0].argsort()] - self.assertAllEqual([[11, 1], [13, 3], [14, 4], [100, 0], [100, 0], - [100, 
0], [100, 0], [200, 2]], pairs) - - def testSaveRestore(self): - save_dir = os.path.join(self.get_temp_dir(), "save_restore") - save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash") - - with self.session(graph=ops.Graph()) as sess: - default_value = -1 - empty_key = 0 - deleted_key = -1 - keys = constant_op.constant([11, 12, 13, 14], dtypes.int64) - values = constant_op.constant([0, 1, 2, 3], dtypes.int64) - table = lookup.MutableDenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=default_value, - empty_key=empty_key, - deleted_key=deleted_key, - name="t1", - checkpoint=True, - initial_num_buckets=32) - - save = saver.Saver() - - self.assertAllEqual(0, table.size().eval()) - table.insert(keys, values).run() - self.assertAllEqual(4, table.size().eval()) - self.assertAllEqual(32, len(table.export()[0].eval())) - - keys2 = constant_op.constant([12, 15], dtypes.int64) - table.remove(keys2).run() - self.assertAllEqual(3, table.size().eval()) - self.assertAllEqual(32, len(table.export()[0].eval())) - - val = save.save(sess, save_path) - self.assertTrue(isinstance(val, six.string_types)) - self.assertEqual(save_path, val) - - with self.session(graph=ops.Graph()) as sess: - table = lookup.MutableDenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=default_value, - empty_key=empty_key, - deleted_key=deleted_key, - name="t1", - checkpoint=True, - initial_num_buckets=64) - table.insert( - constant_op.constant([11, 14], dtypes.int64), - constant_op.constant([12, 24], dtypes.int64)).run() - self.assertAllEqual(2, table.size().eval()) - self.assertAllEqual(64, len(table.export()[0].eval())) - - save = saver.Saver() - - # Restore the saved values in the parameter nodes. 
- save.restore(sess, save_path) - - self.assertAllEqual(3, table.size().eval()) - self.assertAllEqual(32, len(table.export()[0].eval())) - - input_string = constant_op.constant([10, 11, 12, 13, 14], dtypes.int64) - output = table.lookup(input_string) - self.assertAllEqual([-1, 0, -1, 2, 3], output.eval()) - - @test_util.run_in_graph_and_eager_modes - def testObjectSaveRestore(self): - save_dir = os.path.join(self.get_temp_dir(), "save_restore") - save_prefix = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash") - - default_value = -1 - empty_key = 0 - deleted_key = -1 - keys = constant_op.constant([11, 12, 13], dtypes.int64) - values = constant_op.constant([0, 1, 2], dtypes.int64) - save_table = lookup.MutableDenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=default_value, - empty_key=empty_key, - deleted_key=deleted_key, - name="t1", - checkpoint=True, - initial_num_buckets=32) - - save_checkpoint = checkpointable.Checkpoint(table=save_table) - - self.assertAllEqual(0, self.evaluate(save_table.size())) - self.evaluate(save_table.insert(keys, values)) - self.assertAllEqual(3, self.evaluate(save_table.size())) - self.assertAllEqual(32, len(self.evaluate(save_table.export()[0]))) - - save_path = save_checkpoint.save(save_prefix) - del save_table, save_checkpoint - - load_table = lookup.MutableDenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=default_value, - empty_key=empty_key, - deleted_key=deleted_key, - name="t1", - checkpoint=True, - initial_num_buckets=64) - self.evaluate(load_table.insert( - constant_op.constant([11, 14], dtypes.int64), - constant_op.constant([12, 24], dtypes.int64))) - self.assertAllEqual(2, self.evaluate(load_table.size())) - self.assertAllEqual(64, len(self.evaluate(load_table.export()[0]))) - - restore_checkpoint = checkpointable.Checkpoint(table=load_table) - - # Restore the saved values in the parameter nodes. 
- restore_checkpoint.restore(save_path).run_restore_ops() - - self.assertAllEqual(3, self.evaluate(load_table.size())) - self.assertAllEqual(32, len(self.evaluate(load_table.export()[0]))) - - input_string = constant_op.constant([10, 11, 12, 13, 14], dtypes.int64) - output = load_table.lookup(input_string) - self.assertAllEqual([-1, 0, 1, 2, -1], self.evaluate(output)) - - def testVectorSaveRestore(self): - save_dir = os.path.join(self.get_temp_dir(), "vector_save_restore") - save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash") - - with self.session(graph=ops.Graph()) as sess: - empty_key = constant_op.constant([11, 13], dtypes.int64) - deleted_key = constant_op.constant([-2, -3], dtypes.int64) - default_value = constant_op.constant([-1, -2], dtypes.int64) - keys = constant_op.constant([[11, 12], [11, 14], [12, 13], [13, 14]], - dtypes.int64) - values = constant_op.constant([[0, 1], [2, 3], [2, 4], [4, 5]], - dtypes.int64) - table = lookup.MutableDenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=default_value, - empty_key=empty_key, - deleted_key=deleted_key, - name="t1", - checkpoint=True, - initial_num_buckets=32) - - save = saver.Saver() - - self.assertAllEqual(0, table.size().eval()) - table.insert(keys, values).run() - self.assertAllEqual(4, table.size().eval()) - self.assertAllEqual(32, len(table.export()[0].eval())) - - keys2 = constant_op.constant([[12, 13], [16, 17]], dtypes.int64) - table.remove(keys2).run() - self.assertAllEqual(3, table.size().eval()) - self.assertAllEqual(32, len(table.export()[0].eval())) - - val = save.save(sess, save_path) - self.assertTrue(isinstance(val, six.string_types)) - self.assertEqual(save_path, val) - - with self.session(graph=ops.Graph()) as sess: - empty_key = constant_op.constant([11, 13], dtypes.int64) - deleted_key = constant_op.constant([-2, -3], dtypes.int64) - default_value = constant_op.constant([-1, -2], dtypes.int64) - table = lookup.MutableDenseHashTable( - dtypes.int64, - 
dtypes.int64, - default_value=default_value, - empty_key=empty_key, - deleted_key=deleted_key, - name="t1", - checkpoint=True, - initial_num_buckets=64) - table.insert( - constant_op.constant([[11, 12], [13, 15]], dtypes.int64), - constant_op.constant([[21, 22], [23, 24]], dtypes.int64)).run() - self.assertAllEqual(2, table.size().eval()) - self.assertAllEqual(64, len(table.export()[0].eval())) - - save = saver.Saver() - - # Restore the saved values in the parameter nodes. - save.restore(sess, save_path) - - self.assertAllEqual(3, table.size().eval()) - self.assertAllEqual(32, len(table.export()[0].eval())) - - input_string = constant_op.constant( - [[11, 12], [11, 14], [11, 15], [13, 14], [13, 15]], dtypes.int64) - output = table.lookup(input_string) - self.assertAllEqual([[0, 1], [2, 3], [-1, -2], [4, 5], [-1, -2]], - output.eval()) - - def testVectorScalarSaveRestore(self): - save_dir = os.path.join(self.get_temp_dir(), "vector_scalar_save_restore") - save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash") - - with self.session(graph=ops.Graph()) as sess: - empty_key = constant_op.constant([11, 13], dtypes.int64) - deleted_key = constant_op.constant([-1, -1], dtypes.int64) - default_value = constant_op.constant(-1, dtypes.int64) - keys = constant_op.constant([[11, 12], [11, 14], [12, 13], [13, 14]], - dtypes.int64) - values = constant_op.constant([0, 1, 2, 3], dtypes.int64) - table = lookup.MutableDenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=default_value, - empty_key=empty_key, - deleted_key=deleted_key, - name="t2", - checkpoint=True, - initial_num_buckets=32) - - save = saver.Saver() - - self.assertAllEqual(0, table.size().eval()) - table.insert(keys, values).run() - self.assertAllEqual(4, table.size().eval()) - self.assertAllEqual(32, len(table.export()[0].eval())) - - keys2 = constant_op.constant([[12, 13], [15, 16]], dtypes.int64) - table.remove(keys2).run() - self.assertAllEqual(3, table.size().eval()) - 
self.assertAllEqual(32, len(table.export()[0].eval())) - - val = save.save(sess, save_path) - self.assertTrue(isinstance(val, six.string_types)) - self.assertEqual(save_path, val) - - with self.session(graph=ops.Graph()) as sess: - empty_key = constant_op.constant([11, 13], dtypes.int64) - deleted_key = constant_op.constant([-1, -1], dtypes.int64) - default_value = constant_op.constant(-1, dtypes.int64) - table = lookup.MutableDenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=default_value, - empty_key=empty_key, - deleted_key=deleted_key, - name="t2", - checkpoint=True, - initial_num_buckets=64) - table.insert( - constant_op.constant([[11, 12], [13, 15]], dtypes.int64), - constant_op.constant([3, 4], dtypes.int64)).run() - self.assertAllEqual(2, table.size().eval()) - self.assertAllEqual(64, len(table.export()[0].eval())) - - save = saver.Saver() - - # Restore the saved values in the parameter nodes. - save.restore(sess, save_path) - - self.assertAllEqual(3, table.size().eval()) - self.assertAllEqual(32, len(table.export()[0].eval())) - - input_string = constant_op.constant( - [[11, 12], [11, 14], [11, 15], [13, 14], [13, 15]], dtypes.int64) - output = table.lookup(input_string) - self.assertAllEqual([0, 1, -1, 3, -1], output.eval()) - - def testReprobe(self): - with self.cached_session(): - # Insert 6 keys into a table with 8 buckets. 
- # The values are chosen to make sure collisions occur when using GCC STL - keys = constant_op.constant([11, 12, 13, 19, 20, 21], dtypes.int64) - values = constant_op.constant([51, 52, 53, 54, 55, 56], dtypes.int64) - table = lookup.MutableDenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=-1, - empty_key=0, - deleted_key=-1, - initial_num_buckets=8) - self.assertAllEqual(0, table.size().eval()) - - table.insert(keys, values).run() - self.assertAllEqual(6, table.size().eval()) - - input_string = constant_op.constant([10, 11, 12, 13, 14, 19, 20, 21, 22], - dtypes.int64) - output = table.lookup(input_string) - self.assertAllEqual([9], output.get_shape()) - - result = output.eval() - self.assertAllEqual([-1, 51, 52, 53, -1, 54, 55, 56, -1], result) - - def testCustomEmptyKey(self): - with self.cached_session(): - keys = constant_op.constant([11, 0, 13], dtypes.int64) - values = constant_op.constant([0, 1, 2], dtypes.int64) - table = lookup.MutableDenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=-1, - empty_key=12, - deleted_key=-1) - self.assertAllEqual(0, table.size().eval()) - - table.insert(keys, values).run() - self.assertAllEqual(3, table.size().eval()) - - input_string = constant_op.constant([11, 0, 15], dtypes.int64) - output = table.lookup(input_string) - self.assertAllEqual([3], output.get_shape()) - - result = output.eval() - self.assertAllEqual([0, 1, -1], result) - - def testErrors(self): - with self.cached_session(): - table = lookup.MutableDenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=-1, - empty_key=0, - deleted_key=-1) - - # Inserting the empty key returns an error - keys1 = constant_op.constant([11, 0], dtypes.int64) - values1 = constant_op.constant([0, 1], dtypes.int64) - with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, - "empty_key"): - table.insert(keys1, values1).run() - - # Looking up the empty key returns an error - with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, - 
"empty_key"): - table.lookup(keys1).eval() - - # Inserting the deleted key returns an error - keys2 = constant_op.constant([11, -1], dtypes.int64) - values2 = constant_op.constant([0, 1], dtypes.int64) - with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, - "deleted_key"): - table.insert(keys2, values2).run() - - # Looking up the empty key returns an error - with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, - "deleted_key"): - table.lookup(keys2).eval() - - # Arbitrary tensors of keys are not supported - keys = constant_op.constant([[11, 0], [12, 1]], dtypes.int64) - values = constant_op.constant([[11, 0], [12, 1]], dtypes.int64) - with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, - "Expected key shape"): - table.lookup(keys).eval() - with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, - "Expected key shape"): - table.insert(keys, values).run() - - table2 = lookup.MutableDenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=-1, - empty_key=17, - deleted_key=-1, - initial_num_buckets=12) - with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, - "Number of buckets must be"): - self.assertAllEqual(0, table2.size().eval()) - - with self.assertRaisesRegexp( - errors_impl.InvalidArgumentError, - "Empty and deleted keys must have same shape"): - table3 = lookup.MutableDenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=-1, - empty_key=42, - deleted_key=[1, 2]) - self.assertAllEqual(0, table3.size().eval()) - - with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, - "Empty and deleted keys cannot be equal"): - table4 = lookup.MutableDenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=-1, - empty_key=42, - deleted_key=42) - self.assertAllEqual(0, table4.size().eval()) - - with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, - "Empty and deleted keys cannot be equal"): - table5 = lookup.MutableDenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=-1, - 
empty_key=[1, 2, 3], - deleted_key=[1, 2, 3]) - self.assertAllEqual(0, table5.size().eval()) - - class IndexTableFromFile(test.TestCase): def _createVocabFile(self, basename, values=("brain", "salad", "surgery")): @@ -2720,64 +1481,5 @@ class IdTableWithHashBucketsTest(test.TestCase): hasher_spec=lookup.StrongHashSpec([None, 2])) -class MutableHashTableBenchmark(test.Benchmark): - - def _create_table(self): - return lookup.MutableHashTable(dtypes.int64, dtypes.float32, 0.0) - - def benchmark_single_repeated_scalar_insert_scalar(self): - table = self._create_table() - value = variables.Variable(1.0) - insert = table.insert(0, value) - size = table.size() - with session.Session() as sess: - sess.run(value.initializer) - self.run_op_benchmark(sess, insert, burn_iters=10, min_iters=10000) - assert sess.run(size) == 1 - - def benchmark_many_repeated_scalar_insert_scalar(self): - table = self._create_table() - c = counter.Counter().make_one_shot_iterator().get_next() - value = variables.Variable(1.0) - insert = table.insert(c, value) - size = table.size() - with session.Session() as sess: - sess.run(value.initializer) - self.run_op_benchmark(sess, insert, burn_iters=10, min_iters=10000) - assert sess.run(size) >= 10000 - - def benchmark_single_repeated_batch_32_insert_scalar(self): - table = self._create_table() - value = variables.Variable([1.0] * 32) - insert = table.insert(list(range(32)), value) - size = table.size() - with session.Session() as sess: - sess.run(value.initializer) - self.run_op_benchmark(sess, insert, burn_iters=10, min_iters=1000) - assert sess.run(size) == 32 - - def benchmark_many_repeated_batch_32_insert_scalar(self): - table = self._create_table() - c = counter.Counter().make_one_shot_iterator().get_next() - value = variables.Variable([1.0] * 32) - insert = table.insert(32 * c + list(range(32)), value) - size = table.size() - with session.Session() as sess: - sess.run(value.initializer) - self.run_op_benchmark(sess, insert, burn_iters=10, 
min_iters=1000) - assert sess.run(size) >= 1000*32 - - -class MutableDenseHashTableBenchmark(MutableHashTableBenchmark): - - def _create_table(self): - return lookup.MutableDenseHashTable( - dtypes.int64, - dtypes.float32, - default_value=0.0, - empty_key=-1, - deleted_key=-2) - - if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/losses/python/losses/loss_ops.py b/tensorflow/contrib/losses/python/losses/loss_ops.py index 709a042bbcefb89125f7e4cd14a0d7ecd2b53281..5ebdd0b8b50063c99e6b747c594eb99c306b4efb 100644 --- a/tensorflow/contrib/losses/python/losses/loss_ops.py +++ b/tensorflow/contrib/losses/python/losses/loss_ops.py @@ -511,7 +511,7 @@ def mean_squared_error(predictions, labels=None, weights=1.0, scope=None): predictions.get_shape().assert_is_compatible_with(labels.get_shape()) predictions = math_ops.to_float(predictions) labels = math_ops.to_float(labels) - losses = math_ops.square(math_ops.subtract(predictions, labels)) + losses = math_ops.squared_difference(predictions, labels) return compute_weighted_loss(losses, weights, scope=scope) diff --git a/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops.py b/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops.py index de76acb51ffe985162a66c617b266f47c5216b19..f3b0e77740ff1d940fcd6d00b3482e90f6ebf952 100644 --- a/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops.py +++ b/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops.py @@ -105,7 +105,8 @@ def contrastive_loss(labels, embeddings_anchor, embeddings_positive, # Get per pair distances distances = math_ops.sqrt( math_ops.reduce_sum( - math_ops.square(embeddings_anchor - embeddings_positive), 1)) + math_ops.squared_difference(embeddings_anchor, embeddings_positive), + 1)) # Add contrastive loss for the siamese network. # label here is {0,1} for neg, pos. 
diff --git a/tensorflow/contrib/makefile/download_dependencies.sh b/tensorflow/contrib/makefile/download_dependencies.sh index b396c527673902d61072dc9cf7d2766476be8369..af3c541dc214c30e9e59fdcca995ffc53b028df4 100755 --- a/tensorflow/contrib/makefile/download_dependencies.sh +++ b/tensorflow/contrib/makefile/download_dependencies.sh @@ -30,11 +30,13 @@ EIGEN_URL="$(grep -o 'http.*bitbucket.org/eigen/eigen/get/.*tar\.gz' "${BZL_FILE GEMMLOWP_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/gemmlowp/.*zip' "${BZL_FILE_PATH}" | head -n1)" GOOGLETEST_URL="https://github.com/google/googletest/archive/release-1.8.0.tar.gz" NSYNC_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/nsync/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)" -# Note: The Protobuf source in `tensorflow/workspace.bzl` in TensorFlow -# 1.10 branch does not work. `make distclean` fails and blocks the build -# process. For now we're hardcoding to the version which is used by -# TensorFlow 1.9. -PROTOBUF_URL="https://mirror.bazel.build/github.com/google/protobuf/archive/v3.6.0.tar.gz" + +# Note: The protobuf repo needs to be cloned due to its submodules. +# These variables contain the GitHub repo and the sha, from `tensorflow/workspace.bzl`, +# from which to clone it from and checkout to. +readonly PROTOBUF_REPO="https://github.com/protocolbuffers/protobuf.git" +readonly PROTOBUF_TAG="$(grep -o 'https://github.com/protocolbuffers/protobuf/archive/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1 | awk '{print substr($0, index($0, "archive") + 8, index($0, "tar") - index($0, "archive") - 9) }')" + # TODO (yongtang): Replace the following with 'https://mirror.bazel.build/github.com/google/re2/.*tar\.gz' once # the archive has been propagated in mirror.bazel.build. 
RE2_URL="$(grep -o 'https://github.com/google/re2/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)" @@ -91,11 +93,34 @@ download_and_extract() { find "${dir}" -type f -name '*BUILD' -delete } +function clone_repository() { + local repo_url="${1}" + local destination_directory="${2}" + local commit_sha="${3}" + + if [[ -d "${destination_directory}" ]]; then + rm -rf "${destination_directory}" + fi + + git clone "${repo_url}" "${destination_directory}" + + pushd "$(pwd)" 1>/dev/null + + cd "${destination_directory}" + + if [[ -n "${commit_sha}" ]]; then + git checkout "${PROTOBUF_TAG}" + fi + + git submodule update --init + + popd 1>/dev/null +} + download_and_extract "${EIGEN_URL}" "${DOWNLOADS_DIR}/eigen" download_and_extract "${GEMMLOWP_URL}" "${DOWNLOADS_DIR}/gemmlowp" download_and_extract "${GOOGLETEST_URL}" "${DOWNLOADS_DIR}/googletest" download_and_extract "${NSYNC_URL}" "${DOWNLOADS_DIR}/nsync" -download_and_extract "${PROTOBUF_URL}" "${DOWNLOADS_DIR}/protobuf" download_and_extract "${RE2_URL}" "${DOWNLOADS_DIR}/re2" download_and_extract "${FFT2D_URL}" "${DOWNLOADS_DIR}/fft2d" download_and_extract "${DOUBLE_CONVERSION_URL}" "${DOWNLOADS_DIR}/double_conversion" @@ -106,6 +131,8 @@ download_and_extract "${CUB_URL}" "${DOWNLOADS_DIR}/cub/external/cub_archive" download_and_extract "${FARMHASH_URL}" "${DOWNLOADS_DIR}/farmhash" download_and_extract "${FLATBUFFERS_URL}" "${DOWNLOADS_DIR}/flatbuffers" +clone_repository "${PROTOBUF_REPO}" "${DOWNLOADS_DIR}/protobuf" "${PROTOBUF_TAG}" + replace_by_sed 's#static uint32x4_t p4ui_CONJ_XOR = vld1q_u32( conj_XOR_DATA );#static uint32x4_t p4ui_CONJ_XOR; // = vld1q_u32( conj_XOR_DATA ); - Removed by script#' \ "${DOWNLOADS_DIR}/eigen/Eigen/src/Core/arch/NEON/Complex.h" replace_by_sed 's#static uint32x2_t p2ui_CONJ_XOR = vld1_u32( conj_XOR_DATA );#static uint32x2_t p2ui_CONJ_XOR;// = vld1_u32( conj_XOR_DATA ); - Removed by scripts#' \ @@ -115,5 +142,6 @@ replace_by_sed 's#static uint64x2_t p2ul_CONJ_XOR = vld1q_u64( p2ul_conj_XOR_DAT 
# TODO(satok): Remove this once protobuf/autogen.sh is fixed. replace_by_sed 's#https://googlemock.googlecode.com/files/gmock-1.7.0.zip#http://download.tensorflow.org/deps/gmock-1.7.0.zip#' \ "${DOWNLOADS_DIR}/protobuf/autogen.sh" +cat "third_party/eigen3/gebp_neon.patch" | patch "${DOWNLOADS_DIR}/eigen/Eigen/src/Core/products/GeneralBlockPanelKernel.h" echo "download_dependencies.sh completed successfully." >&2 diff --git a/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt b/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt index 87c73ec1ca610cac6d63468887bc350bada5910b..8330c45cc16ffa536107e25699379bb5d9e8993b 100644 --- a/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt +++ b/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt @@ -36,6 +36,7 @@ tensorflow/core/protobuf/queue_runner.pb.cc tensorflow/core/protobuf/rewriter_config.pb.cc tensorflow/core/protobuf/saver.pb.cc tensorflow/core/protobuf/tensorflow_server.pb.cc +tensorflow/core/protobuf/verifier_config.pb.cc tensorflow/core/util/event.pb.cc tensorflow/core/util/memmapped_file_system.pb.cc tensorflow/core/util/saved_tensor_slice.pb.cc diff --git a/tensorflow/contrib/makefile/proto_text_pb_h_files.txt b/tensorflow/contrib/makefile/proto_text_pb_h_files.txt index 4120ea52ec5255b1efce7a6ce6890fc79c1e4831..7257ac8feedfb8ed18c4d691cd85766e70a48ae8 100644 --- a/tensorflow/contrib/makefile/proto_text_pb_h_files.txt +++ b/tensorflow/contrib/makefile/proto_text_pb_h_files.txt @@ -37,6 +37,7 @@ tensorflow/core/protobuf/rewriter_config.pb.h tensorflow/core/protobuf/saver.pb.h tensorflow/core/protobuf/tensor_bundle.pb.h tensorflow/core/protobuf/tensorflow_server.pb.h +tensorflow/core/protobuf/verifier_config.pb.h tensorflow/core/util/event.pb.h tensorflow/core/util/memmapped_file_system.pb.h tensorflow/core/util/saved_tensor_slice.pb.h diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt index 
655c7eefcb978d40c8bc16a23685e03ed71bfb63..2cd7d6d519a55423a96526b541845392d9ec6bc2 100644 --- a/tensorflow/contrib/makefile/tf_op_files.txt +++ b/tensorflow/contrib/makefile/tf_op_files.txt @@ -119,6 +119,7 @@ tensorflow/core/kernels/fake_quant_ops.cc tensorflow/core/kernels/fifo_queue.cc tensorflow/core/kernels/fifo_queue_op.cc tensorflow/core/kernels/fill_functor.cc +tensorflow/core/kernels/fft_ops.cc tensorflow/core/kernels/function_ops.cc tensorflow/core/kernels/fused_batch_norm_op.cc tensorflow/core/kernels/gather_functor.cc diff --git a/tensorflow/contrib/makefile/tf_pb_text_files.txt b/tensorflow/contrib/makefile/tf_pb_text_files.txt index f94d70db9046cec43073ab1406762aea1f28c8e3..13e3b6422d1989b0d499d8d20901d919554c630e 100644 --- a/tensorflow/contrib/makefile/tf_pb_text_files.txt +++ b/tensorflow/contrib/makefile/tf_pb_text_files.txt @@ -29,5 +29,6 @@ tensorflow/core/protobuf/debug.pb_text.cc tensorflow/core/protobuf/rewriter_config.pb_text.cc tensorflow/core/protobuf/saver.pb_text.cc tensorflow/core/protobuf/tensor_bundle.pb_text.cc +tensorflow/core/protobuf/verifier_config.pb_text.cc tensorflow/core/util/memmapped_file_system.pb_text.cc tensorflow/core/util/saved_tensor_slice.pb_text.cc diff --git a/tensorflow/contrib/makefile/tf_proto_files.txt b/tensorflow/contrib/makefile/tf_proto_files.txt index 2712e906d719e72dacb60e213205ad68895f905f..24d86d313b76343ed9450a33cf185d9c426696bb 100644 --- a/tensorflow/contrib/makefile/tf_proto_files.txt +++ b/tensorflow/contrib/makefile/tf_proto_files.txt @@ -43,6 +43,7 @@ tensorflow/core/protobuf/rewriter_config.proto tensorflow/core/protobuf/saver.proto tensorflow/core/protobuf/tensor_bundle.proto tensorflow/core/protobuf/tensorflow_server.proto +tensorflow/core/protobuf/verifier_config.proto tensorflow/core/util/event.proto tensorflow/core/util/memmapped_file_system.proto tensorflow/core/util/saved_tensor_slice.proto diff --git a/tensorflow/contrib/metrics/python/metrics/classification.py 
b/tensorflow/contrib/metrics/python/metrics/classification.py index 062deb74b165329d8e72efa73b9d81f4174f8831..9aabc4bec3053871e3ff6cd3a88fd76d293f48cc 100644 --- a/tensorflow/contrib/metrics/python/metrics/classification.py +++ b/tensorflow/contrib/metrics/python/metrics/classification.py @@ -18,13 +18,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.distribute import distribution_strategy_context from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import metrics_impl from tensorflow.python.ops import variable_scope -from tensorflow.python.training import distribution_strategy_context # TODO(nsilberman): move into metrics/python/ops/ diff --git a/tensorflow/contrib/metrics/python/metrics/classification_test.py b/tensorflow/contrib/metrics/python/metrics/classification_test.py index d6a670f97b32a29129cb9ea0cd71c5a2b7597a47..e789d2cb9dfbac7b1e145be48b3f707af3fd4e18 100644 --- a/tensorflow/contrib/metrics/python/metrics/classification_test.py +++ b/tensorflow/contrib/metrics/python/metrics/classification_test.py @@ -291,12 +291,11 @@ class F1ScoreTest(test.TestCase): labels = labels.astype(np.float32) predictions = predictions.astype(np.float32) - tf_predictions, tf_labels = (dataset_ops.Dataset - .from_tensor_slices((predictions, labels)) - .repeat() - .batch(batch_size) - .make_one_shot_iterator() - .get_next()) + tf_predictions, tf_labels = dataset_ops.make_one_shot_iterator( + dataset_ops.Dataset + .from_tensor_slices((predictions, labels)) + .repeat() + .batch(batch_size)).get_next() f1, f1_op = classification.f1_score(tf_labels, tf_predictions, num_thresholds=3) diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py index 
7b432f8bd20989c6d95310bcaca88d44ce3e0d1f..ece246b7c28569a551f7733daf16ee1507f9c95d 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py @@ -1356,9 +1356,8 @@ def _compute_placement_auc(labels, predictions, weights, alpha, weights_0 * math_ops.square(1. - placement_values_0 - auc_0)) / (total_0 - 1. + _EPSILON)) var_1 = ( - math_ops.reduce_sum( - weights_1 * math_ops.square(placement_values_1 - auc_1)) / - (total_1 - 1. + _EPSILON)) + math_ops.reduce_sum(weights_1 * math_ops.squared_difference( + placement_values_1, auc_1)) / (total_1 - 1. + _EPSILON)) auc_std_err = math_ops.sqrt( (var_0 / (total_0 + _EPSILON)) + (var_1 / (total_1 + _EPSILON))) diff --git a/tensorflow/contrib/mixed_precision/python/loss_scale_manager_test.py b/tensorflow/contrib/mixed_precision/python/loss_scale_manager_test.py index 1b0383d24c0c472b4875d15c3650e37dfd2439e1..c922d0cd11fda3c51a51ceccf69798df7ce75f26 100644 --- a/tensorflow/contrib/mixed_precision/python/loss_scale_manager_test.py +++ b/tensorflow/contrib/mixed_precision/python/loss_scale_manager_test.py @@ -29,7 +29,7 @@ from tensorflow.python.platform import test def _GetExampleIter(inputs): dataset = dataset_ops.Dataset.from_tensor_slices(inputs) - return dataset.make_one_shot_iterator() + return dataset_ops.make_one_shot_iterator(dataset) class FixedLossScaleManagerTest(test.TestCase): diff --git a/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer_test.py b/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer_test.py index 9009df0eefec13146090ba5fc2096e71ba6eb89d..33f9a43e803ea845a25bba284e41e5a0e6228dad 100644 --- a/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer_test.py +++ b/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer_test.py @@ -132,7 +132,7 @@ class LossScaleOptimizerTest(test.TestCase): x = variable_scope.get_variable("x", initializer=1., dtype=dtypes.float32) dataset = 
dataset_ops.Dataset.from_tensor_slices([np.nan, np.inf, 0.1]) - itr = dataset.make_one_shot_iterator() + itr = dataset_ops.make_one_shot_iterator(dataset) lr = 1 opt = gd.GradientDescentOptimizer(lr) @@ -182,7 +182,7 @@ class LossScaleOptimizerTest(test.TestCase): x = variable_scope.get_variable("x", initializer=1., dtype=dtypes.float32) dataset = dataset_ops.Dataset.from_tensor_slices([np.nan, np.inf, 0.1]) - itr = dataset.make_one_shot_iterator() + itr = dataset_ops.make_one_shot_iterator(dataset) lr = 1 init_loss_scale = 8 diff --git a/tensorflow/contrib/model_pruning/README.md b/tensorflow/contrib/model_pruning/README.md index 45a60d79482787df4564ae3360f8252af93c7a26..710a262f33872ada8d090d796f80dc06c2a27f84 100644 --- a/tensorflow/contrib/model_pruning/README.md +++ b/tensorflow/contrib/model_pruning/README.md @@ -53,7 +53,6 @@ The pruning library allows for specification of the following hyper parameters: | weight_sparsity_map | list of strings | [""] | list of weight variable name (or layer name):target sparsity pairs. Eg. [conv1:0.9,conv2/kernel:0.8]. For layers/weights not in this list, sparsity as specified by the target_sparsity hyperparameter is used. | | threshold_decay | float | 0.0 | The decay factor to use for exponential decay of the thresholds | | pruning_frequency | integer | 10 | How often should the masks be updated? (in # of global_steps) | -| nbins | integer | 256 | Number of bins to use for histogram computation. Note: When running on TPUs, a large (>1024) value for `nbins` may adversely affect the training time. 
| | block_height|integer | 1 | Number of rows in a block for block sparse matrices| | block_width |integer | 1 | Number of cols in a block for block sparse matrices| | block_pooling_function| string | AVG | The function to use to pool weight values in a block: average (AVG) or max (MAX)| diff --git a/tensorflow/contrib/model_pruning/python/pruning.py b/tensorflow/contrib/model_pruning/python/pruning.py index f6b4373edd0544555dd16a373802d2feb5d674b1..9966f7cf798d206fffbaeb4d16b6500a90d113e4 100644 --- a/tensorflow/contrib/model_pruning/python/pruning.py +++ b/tensorflow/contrib/model_pruning/python/pruning.py @@ -214,7 +214,7 @@ def get_pruning_hparams(): target_sparsity=0.5, sparsity_function_begin_step=0, sparsity_function_end_step=100, - sparsity_function_exponent=3, + sparsity_function_exponent=3.0, use_tpu=False) @@ -397,28 +397,26 @@ class Pruning(object): raise ValueError('Sparsity variable undefined') sparsity = self._get_sparsity(weights.op.name) - with ops.name_scope(weights.op.name + '_pruning_ops'): abs_weights = math_ops.abs(weights) - max_value = math_ops.reduce_max(abs_weights) - cdf_fn = pruning_utils.compute_cdf_from_histogram - if self._spec.use_tpu: - cdf_fn = pruning_utils.compute_cdf - - norm_cdf = cdf_fn(abs_weights, [0.0, max_value], nbins=self._spec.nbins) - current_threshold = math_ops.multiply( - math_ops.div( - math_ops.reduce_sum( - math_ops.cast( - math_ops.less(norm_cdf, sparsity), dtypes.float32)), - float(self._spec.nbins)), max_value) - + k = math_ops.cast( + math_ops.round( + math_ops.cast(array_ops.size(abs_weights), dtypes.float32) * + (1 - sparsity)), dtypes.int32) + # Sort the entire array + values, _ = nn_ops.top_k( + array_ops.reshape(abs_weights, [-1]), k=array_ops.size(abs_weights)) + # Grab the (k-1) th value + current_threshold = array_ops.gather(values, k - 1) smoothed_threshold = math_ops.add_n([ math_ops.multiply(current_threshold, 1 - self._spec.threshold_decay), math_ops.multiply(threshold, self._spec.threshold_decay) 
]) + new_mask = math_ops.cast( - math_ops.greater(abs_weights, smoothed_threshold), dtypes.float32) + math_ops.greater_equal(abs_weights, smoothed_threshold), + dtypes.float32) + return smoothed_threshold, new_mask def _maybe_update_block_mask(self, weights, threshold): diff --git a/tensorflow/contrib/model_pruning/python/pruning_test.py b/tensorflow/contrib/model_pruning/python/pruning_test.py index 1b6da5ce2b4ebb3ea3b204c4ed12bed8db951447..835614d8822147dadb029107ae0e917cc955eef0 100644 --- a/tensorflow/contrib/model_pruning/python/pruning_test.py +++ b/tensorflow/contrib/model_pruning/python/pruning_test.py @@ -102,7 +102,7 @@ class PruningTest(test.TestCase): weights = variables.VariableV1( math_ops.linspace(1.0, 100.0, 100), name="weights") masked_weights = pruning.apply_mask(weights) - sparsity = variables.VariableV1(0.5, name="sparsity") + sparsity = variables.VariableV1(0.95, name="sparsity") p = pruning.Pruning(sparsity=sparsity) p._spec.threshold_decay = 0.0 mask_update_op = p.mask_update_op() @@ -111,7 +111,7 @@ class PruningTest(test.TestCase): self.assertAllEqual(np.count_nonzero(masked_weights_val), 100) session.run(mask_update_op) masked_weights_val = masked_weights.eval() - self.assertAllEqual(np.count_nonzero(masked_weights_val), 50) + self.assertAllEqual(np.count_nonzero(masked_weights_val), 5) def _blockMasking(self, hparams, weights, expected_mask): diff --git a/tensorflow/contrib/model_pruning/python/pruning_utils.py b/tensorflow/contrib/model_pruning/python/pruning_utils.py index 14fc51229ab53a77e8089040e8a8576babd0fafd..8f2ba036469bd02328a831a3d1de2ffbd10f5004 100644 --- a/tensorflow/contrib/model_pruning/python/pruning_utils.py +++ b/tensorflow/contrib/model_pruning/python/pruning_utils.py @@ -25,16 +25,12 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops -from tensorflow.python.ops import clip_ops from 
tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import init_ops -from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope -_NBINS = 256 - def weight_mask_variable(var, scope): """Create a mask for the weights. @@ -165,128 +161,6 @@ def expand_tensor(tensor, block_dims): return expanded_tensor -def _histogram(values, value_range, nbins=100, dtype=dtypes.int32, name=None): - """Return histogram of values. - - Given the tensor `values`, this operation returns a rank 1 histogram counting - the number of entries in `values` that fell into every bin. The bins are - equal width and determined by the arguments `value_range` and `nbins`. - - Args: - values: Numeric `Tensor`. - value_range: Shape [2] `Tensor` of same `dtype` as `values`. - values <= value_range[0] will be mapped to hist[0], - values >= value_range[1] will be mapped to hist[-1]. - nbins: Scalar `int32 Tensor`. Number of histogram bins. - dtype: dtype for returned histogram. - name: A name for this operation (defaults to 'histogram'). - - Returns: - A 1-D `Tensor` holding histogram of values. - - """ - with ops.name_scope(name, 'histogram', [values, value_range, nbins]) as scope: - values = ops.convert_to_tensor(values, name='values') - values = array_ops.reshape(values, [-1]) - nbins_float = np.float32(nbins) - - # Map tensor values that fall within value_range to [0, 1]. - scaled_values = math_ops.truediv( - values - value_range[0], - value_range[1] - value_range[0], - name='scaled_values') - - # map tensor values within the open interval value_range to {0,.., nbins-1}, - # values outside the open interval will be zero or less, or nbins or more. - indices = math_ops.floor(nbins_float * scaled_values, name='indices') - - # Clip edge cases (e.g. value = value_range[1]) or "outliers." 
- indices = math_ops.cast( - clip_ops.clip_by_value(indices, 0, nbins_float - 1), dtypes.int32) - - return math_ops.unsorted_segment_sum( - array_ops.ones_like(indices, dtype=dtype), indices, nbins, name=scope) - - -def compute_cdf_from_histogram(values, value_range, **kwargs): - """Returns the normalized cumulative distribution of the given values tensor. - - Computes the histogram and uses tf.cumsum to arrive at cdf - - Args: - values: Numeric `Tensor`. - value_range: Shape [2] `Tensor` of same `dtype` as `values`. - **kwargs: keyword arguments: nbins, name - - Returns: - A 1-D `Tensor` holding normalized cdf of values. - - """ - nbins = kwargs.get('nbins', _NBINS) - name = kwargs.get('name', None) - with ops.name_scope(name, 'cdf', [values, value_range, nbins]): - histogram = _histogram( - values, value_range, dtype=dtypes.float32, nbins=nbins) - cdf = math_ops.cumsum(histogram) - return math_ops.div(cdf, math_ops.reduce_max(cdf)) - - -def compute_cdf(values, value_range, **kwargs): - """Returns the normalized cumulative distribution of the given values tensor. - - Uses tf.while_loop to directly compute the cdf of the values. - - Args: - values: Numeric `Tensor`. - value_range: Shape [2] `Tensor` of same `dtype` as `values` - **kwargs: keyword arguments: nbins, name - - Returns: - A 1-D `Tensor` holding normalized cdf of values. - - """ - nbins = kwargs.get('nbins', _NBINS) - name = kwargs.get('name', None) - with ops.name_scope(name, 'cdf', [values, value_range, nbins]): - values = ops.convert_to_tensor(values, name='values') - nbins_float = np.float32(nbins) - - # Map tensor values that fall within value_range to [0, 1]. - scaled_values = math_ops.truediv( - values - value_range[0], - value_range[1] - value_range[0], - name='scaled_values') - - # map tensor values within the open interval value_range to {0,.., nbins-1}, - # values outside the open interval will be zero or less, or nbins or more. 
- indices = math_ops.floor(nbins_float * scaled_values, name='indices') - - # Clip edge cases (e.g. value = value_range[1]) or "outliers." - indices = math_ops.cast( - clip_ops.clip_by_value(indices, 0, nbins_float - 1), dtypes.int32) - - cdf = array_ops.zeros(nbins) - i = constant_op.constant(0) - - def loop_cond(loop_count, _): - return math_ops.less(loop_count, nbins) - - def loop_body(loop_count, cdf): - temp = math_ops.reduce_sum( - math_ops.cast( - math_ops.less_equal(indices, loop_count), dtypes.float32)) - cdf = math_ops.add( - cdf, - array_ops.one_hot( - loop_count, depth=nbins, on_value=temp, off_value=0.0)) - return [loop_count + 1, cdf] - - _, cdf = control_flow_ops.while_loop( - loop_cond, loop_body, [i, cdf], maximum_iterations=nbins) - - return math_ops.div(cdf, math_ops.reduce_max(cdf)) - - def factorized_pool(input_tensor, window_shape, pooling_type, diff --git a/tensorflow/contrib/model_pruning/python/pruning_utils_test.py b/tensorflow/contrib/model_pruning/python/pruning_utils_test.py index d6f2bfcb6c2e2beda912eb538d8a4a0a17b486b3..b85bc413155d53cd6d53e98dae0ad626531f61eb 100644 --- a/tensorflow/contrib/model_pruning/python/pruning_utils_test.py +++ b/tensorflow/contrib/model_pruning/python/pruning_utils_test.py @@ -19,13 +19,9 @@ from __future__ import division from __future__ import print_function from absl.testing import parameterized -import numpy as np from tensorflow.contrib.model_pruning.python import pruning_utils -from tensorflow.python.framework import constant_op from tensorflow.python.ops import array_ops -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import variable_scope @@ -33,57 +29,6 @@ from tensorflow.python.ops import variables from tensorflow.python.platform import test -class PruningUtilsTest(test.TestCase): - - def _compare_cdf(self, values): - abs_values = 
math_ops.abs(values) - max_value = math_ops.reduce_max(abs_values) - with self.cached_session(): - variables.global_variables_initializer().run() - cdf_from_histogram = pruning_utils.compute_cdf_from_histogram( - abs_values, [0.0, max_value], nbins=pruning_utils._NBINS) - cdf = pruning_utils.compute_cdf(abs_values, [0.0, max_value]) - self.assertAllEqual(cdf.eval(), cdf_from_histogram.eval()) - - def testHistogram(self): - width = 10 - height = 10 - nbins = 100 - expected_histogram = np.full(nbins, 1.0) - init = init_ops.constant_initializer(np.linspace(0.0, 1.0, width * height)) - weights = variable_scope.get_variable( - "weights", [width, height], initializer=init) - histogram = pruning_utils._histogram( - weights, [0, 1.0], nbins, dtype=np.float32) - with self.cached_session(): - variables.global_variables_initializer().run() - computed_histogram = histogram.eval() - self.assertAllEqual(expected_histogram, computed_histogram) - - def testCDF(self): - nbins = 5 - weights = constant_op.constant([-1, 0, 1, 1.5, 2, 3, 4, 5, 10, 100]) - abs_weights = math_ops.abs(weights) - norm_cdf = pruning_utils.compute_cdf_from_histogram( - abs_weights, [0.0, 5.0], nbins=nbins) - expected_cdf = np.array([0.1, 0.4, 0.5, 0.6, 1.0], dtype=np.float32) - with self.cached_session() as sess: - variables.global_variables_initializer().run() - norm_cdf_val = sess.run(norm_cdf) - self.assertAllEqual(len(norm_cdf_val), nbins) - self.assertAllEqual(expected_cdf, norm_cdf_val) - - def testCDFEquivalence2D(self): - width = 100 - height = 100 - weights = variable_scope.get_variable("weights", shape=[width, height]) - self._compare_cdf(weights) - - def testCDFEquivalence4D(self): - weights = variable_scope.get_variable("weights", shape=[5, 5, 128, 128]) - self._compare_cdf(weights) - - @parameterized.named_parameters( ("Input_32x32_block_1x1", [32, 32], [1, 1]), # block size 6x6 diff --git a/tensorflow/contrib/mpi/mpi_server_lib.cc b/tensorflow/contrib/mpi/mpi_server_lib.cc index 
a31fa9ce0b3110d875689d74a41ca9f9cc85f532..e44e10af0814ba8d6d964dfc34a0470ce45c0b40 100644 --- a/tensorflow/contrib/mpi/mpi_server_lib.cc +++ b/tensorflow/contrib/mpi/mpi_server_lib.cc @@ -54,7 +54,10 @@ MPIServer::~MPIServer() { Status MPIServer::Init(ServiceInitFunction service_func, RendezvousMgrCreationFunction rendezvous_mgr_func) { - Status s = GrpcServer::Init(service_func, rendezvous_mgr_func); + GrpcServerOptions opts; + opts.service_func = service_func; + opts.rendezvous_mgr_func = rendezvous_mgr_func; + Status s = GrpcServer::Init(opts); return s; } diff --git a/tensorflow/contrib/mpi_collectives/BUILD b/tensorflow/contrib/mpi_collectives/BUILD index ecac06354d2ce796f2a6021cdf2370d7c30ccab7..a7be92a35e0d62a61f7923ac61bb2c1267d039c6 100644 --- a/tensorflow/contrib/mpi_collectives/BUILD +++ b/tensorflow/contrib/mpi_collectives/BUILD @@ -52,7 +52,6 @@ tf_custom_op_library( deps = [ ":mpi_defines", ":mpi_message_proto_cc", - "//tensorflow/stream_executor:stream_executor_headers_lib", "//third_party/mpi", ], ) diff --git a/tensorflow/contrib/opt/BUILD b/tensorflow/contrib/opt/BUILD index f4ac70eb1a720c2acc3ef942f269228156749cba..12320d9e456ae93cbf95639a0c9e0c7f414f3518 100644 --- a/tensorflow/contrib/opt/BUILD +++ b/tensorflow/contrib/opt/BUILD @@ -14,6 +14,7 @@ py_library( name = "opt_py", srcs = [ "__init__.py", + "python/training/adam_gs_optimizer.py", "python/training/adamax.py", "python/training/addsign.py", "python/training/agn_optimizer.py", @@ -22,6 +23,7 @@ py_library( "python/training/external_optimizer.py", "python/training/ggt.py", "python/training/lars_optimizer.py", + "python/training/lazy_adam_gs_optimizer.py", "python/training/lazy_adam_optimizer.py", "python/training/matrix_functions.py", "python/training/model_average_optimizer.py", @@ -60,6 +62,21 @@ py_library( ], ) +py_test( + name = "adam_gs_optimizer_test", + srcs = ["python/training/adam_gs_optimizer_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":opt_py", + 
"//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//tensorflow/python:training", + "//third_party/py/numpy", + ], +) + py_test( name = "adamax_test", srcs = ["python/training/adamax_test.py"], @@ -148,6 +165,25 @@ py_test( ], ) +py_test( + name = "lazy_adam_gs_optimizer_test", + srcs = ["python/training/lazy_adam_gs_optimizer_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":opt_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:variables", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + ], +) + py_test( name = "lazy_adam_optimizer_test", srcs = ["python/training/lazy_adam_optimizer_test.py"], @@ -283,6 +319,9 @@ tf_py_test( "//tensorflow/python:framework_for_generated_wrappers", "//third_party/py/numpy", ], + tags = [ + "oss_serial", + ], ) tf_py_test( diff --git a/tensorflow/contrib/opt/__init__.py b/tensorflow/contrib/opt/__init__.py index c7ea68efa9a13a471bba3f41d0600855793b20a2..e8fc52342ceabb47da97ca0f3c8a01e419a221a1 100644 --- a/tensorflow/contrib/opt/__init__.py +++ b/tensorflow/contrib/opt/__init__.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function # pylint: disable=wildcard-import +from tensorflow.contrib.opt.python.training.adam_gs_optimizer import * from tensorflow.contrib.opt.python.training.adamax import * from tensorflow.contrib.opt.python.training.addsign import * from tensorflow.contrib.opt.python.training.agn_optimizer import * @@ -28,6 +29,7 @@ from tensorflow.contrib.opt.python.training.external_optimizer import * from tensorflow.contrib.opt.python.training.lars_optimizer import * from 
tensorflow.contrib.opt.python.training.ggt import * from tensorflow.contrib.opt.python.training.lazy_adam_optimizer import * +from tensorflow.contrib.opt.python.training.lazy_adam_gs_optimizer import * from tensorflow.contrib.opt.python.training.model_average_optimizer import * from tensorflow.contrib.opt.python.training.moving_average_optimizer import * from tensorflow.contrib.opt.python.training.multitask_optimizer_wrapper import * @@ -44,12 +46,14 @@ from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ 'AdaMaxOptimizer', + 'AdamGSOptimizer', 'PowerSignOptimizer', 'AddSignOptimizer', 'DelayCompensatedGradientDescentOptimizer', 'DropStaleGradientOptimizer', 'ExternalOptimizerInterface', 'LARSOptimizer', + 'LazyAdamGSOptimizer', 'LazyAdamOptimizer', 'NadamOptimizer', 'MovingAverageOptimizer', diff --git a/tensorflow/contrib/opt/python/training/adam_gs_optimizer.py b/tensorflow/contrib/opt/python/training/adam_gs_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..0b149ed17533adff3bd7cd8fd8ff94d171f72911 --- /dev/null +++ b/tensorflow/contrib/opt/python/training/adam_gs_optimizer.py @@ -0,0 +1,223 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Adam rewrite to use global step for computing beta1 & beta2 accumulation.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.training import optimizer +from tensorflow.python.training import training_ops + + +# Not exported with tf_export: a contrib optimizer must not claim the core +# tf.train.AdamOptimizer API name. Use tf.contrib.opt.AdamGSOptimizer instead. +class AdamGSOptimizer(optimizer.Optimizer): + """Optimizer that implements the Adam algorithm. + + See [Kingma et al., 2014](http://arxiv.org/abs/1412.6980) + ([pdf](http://arxiv.org/pdf/1412.6980.pdf)). + """ + + def __init__(self, + global_step=0, + learning_rate=0.001, + beta1=0.9, + beta2=0.999, + epsilon=1e-8, + use_locking=False, + name="Adam"): + r"""Construct a new Adam optimizer. + + Branched from tf.train.AdamOptimizer. The only difference is to pass + global step for computing beta1 and beta2 accumulators, instead of having + optimizer keep its own independent beta1 and beta2 accumulators as non-slot + variables.
+ + Initialization: + + $$m_0 := 0 \text{(Initialize initial 1st moment vector)}$$ + $$v_0 := 0 \text{(Initialize initial 2nd moment vector)}$$ + $$t := 0 \text{(Initialize timestep)}$$ + + The update rule for `variable` with gradient `g` uses an optimization + described at the end of Section 2 of the paper: + + $$t := t + 1$$ + $$lr_t := \text{learning\_rate} * \sqrt{1 - beta_2^t} / (1 - beta_1^t)$$ + + $$m_t := beta_1 * m_{t-1} + (1 - beta_1) * g$$ + $$v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g$$ + $$variable := variable - lr_t * m_t / (\sqrt{v_t} + \epsilon)$$ + + The default value of 1e-8 for epsilon might not be a good default in + general. For example, when training an Inception network on ImageNet a + current good choice is 1.0 or 0.1. Note that since AdamOptimizer uses the + formulation just before Section 2.1 of the Kingma and Ba paper rather than + the formulation in Algorithm 1, the "epsilon" referred to here is "epsilon + hat" in the paper. + + The sparse implementation of this algorithm (used when the gradient is an + IndexedSlices object, typically because of `tf.gather` or an embedding + lookup in the forward pass) does apply momentum to variable slices even if + they were not used in the forward pass (meaning they have a gradient equal + to zero). Momentum decay (beta1) is also applied to the entire momentum + accumulator. This means that the sparse behavior is equivalent to the dense + behavior (in contrast to some momentum implementations which ignore momentum + unless a variable slice was actually used). + + Args: + global_step: tensorflow variable indicating the step. + learning_rate: A Tensor or a floating point value. The learning rate. + beta1: A float value or a constant float tensor. The exponential decay + rate for the 1st moment estimates. + beta2: A float value or a constant float tensor. The exponential decay + rate for the 2nd moment estimates. + epsilon: A small constant for numerical stability.
This epsilon is + "epsilon hat" in the Kingma and Ba paper (in the formula just before + Section 2.1), not the epsilon in Algorithm 1 of the paper. + use_locking: If True use locks for update operations. + name: Optional name for the operations created when applying gradients. + Defaults to "Adam". @compatibility(eager) When eager execution is + enabled, `learning_rate`, `beta1`, `beta2`, and `epsilon` can each be a + callable that takes no arguments and returns the actual value to use. + This can be useful for changing these values across different + invocations of optimizer functions. @end_compatibility + """ + super(AdamGSOptimizer, self).__init__(use_locking, name) + self._lr = learning_rate + self._beta1 = beta1 + self._beta2 = beta2 + self._epsilon = epsilon + self._global_step = global_step + self._global_step_on_worker = None + + # Tensor versions of the constructor arguments, created in _prepare(). + self._lr_t = None + self._beta1_t = None + self._beta2_t = None + self._epsilon_t = None + + def _get_beta_accumulators(self): + return (math_ops.pow(self._beta1_t, self._global_step_on_worker), + math_ops.pow(self._beta2_t, self._global_step_on_worker)) + + def _create_slots(self, var_list): + # Create slots for the first and second moments. + for v in var_list: + self._zeros_slot(v, "m", self._name) + self._zeros_slot(v, "v", self._name) + + def _prepare(self): + lr = self._call_if_callable(self._lr) + beta1 = self._call_if_callable(self._beta1) + beta2 = self._call_if_callable(self._beta2) + epsilon = self._call_if_callable(self._epsilon) + + self._lr_t = ops.convert_to_tensor(lr, name="learning_rate") + self._beta1_t = ops.convert_to_tensor(beta1, name="beta1") + self._beta2_t = ops.convert_to_tensor(beta2, name="beta2") + self._epsilon_t = ops.convert_to_tensor(epsilon, name="epsilon") + + # Performance optimization so that worker creates a copy of the global step + # to avoid overloading the parameter server holding the global step.
+ self._global_step_on_worker = math_ops.cast( + array_ops.identity(self._global_step) + 1, dtypes.float32) + + def _apply_dense(self, grad, var): + m = self.get_slot(var, "m") + v = self.get_slot(var, "v") + beta1_power, beta2_power = self._get_beta_accumulators() + return training_ops.apply_adam( + var, + m, + v, + math_ops.cast(beta1_power, var.dtype.base_dtype), + math_ops.cast(beta2_power, var.dtype.base_dtype), + math_ops.cast(self._lr_t, var.dtype.base_dtype), + math_ops.cast(self._beta1_t, var.dtype.base_dtype), + math_ops.cast(self._beta2_t, var.dtype.base_dtype), + math_ops.cast(self._epsilon_t, var.dtype.base_dtype), + grad, + use_locking=self._use_locking).op + + def _resource_apply_dense(self, grad, var): + m = self.get_slot(var, "m") + v = self.get_slot(var, "v") + beta1_power, beta2_power = self._get_beta_accumulators() + return training_ops.resource_apply_adam( + var.handle, + m.handle, + v.handle, + math_ops.cast(beta1_power, grad.dtype.base_dtype), + math_ops.cast(beta2_power, grad.dtype.base_dtype), + math_ops.cast(self._lr_t, grad.dtype.base_dtype), + math_ops.cast(self._beta1_t, grad.dtype.base_dtype), + math_ops.cast(self._beta2_t, grad.dtype.base_dtype), + math_ops.cast(self._epsilon_t, grad.dtype.base_dtype), + grad, + use_locking=self._use_locking) + + def _apply_sparse_shared(self, grad, var, indices, scatter_add): + beta1_power, beta2_power = self._get_beta_accumulators() + beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype) + beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype) + lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype) + beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype) + beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype) + epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype) + lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power)) + # m_t = beta1 * m + (1 - beta1) * g_t + m = self.get_slot(var, "m") + m_scaled_g_values = grad * (1 - beta1_t) + m_t =
state_ops.assign(m, m * beta1_t, use_locking=self._use_locking) + with ops.control_dependencies([m_t]): + m_t = scatter_add(m, indices, m_scaled_g_values) + # v_t = beta2 * v + (1 - beta2) * (g_t * g_t) + v = self.get_slot(var, "v") + v_scaled_g_values = (grad * grad) * (1 - beta2_t) + v_t = state_ops.assign(v, v * beta2_t, use_locking=self._use_locking) + with ops.control_dependencies([v_t]): + v_t = scatter_add(v, indices, v_scaled_g_values) + v_sqrt = math_ops.sqrt(v_t) + var_update = state_ops.assign_sub( + var, lr * m_t / (v_sqrt + epsilon_t), use_locking=self._use_locking) + return control_flow_ops.group(*[var_update, m_t, v_t]) + + def _apply_sparse(self, grad, var): + return self._apply_sparse_shared( + grad.values, + var, + grad.indices, + lambda x, i, v: state_ops.scatter_add( # pylint: disable=g-long-lambda + x, + i, + v, + use_locking=self._use_locking)) + + def _resource_scatter_add(self, x, i, v): + with ops.control_dependencies( + [resource_variable_ops.resource_scatter_add(x.handle, i, v)]): + return x.value() + + def _resource_apply_sparse(self, grad, var, indices): + return self._apply_sparse_shared(grad, var, indices, + self._resource_scatter_add) diff --git a/tensorflow/contrib/opt/python/training/adam_gs_optimizer_test.py b/tensorflow/contrib/opt/python/training/adam_gs_optimizer_test.py new file mode 100644 index 0000000000000000000000000000000000000000..c68c965aef3729bebe7d0e0dd707c344321d9e3f --- /dev/null +++ b/tensorflow/contrib/opt/python/training/adam_gs_optimizer_test.py @@ -0,0 +1,382 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Tests for AdamGS.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.opt.python.training import adam_gs_optimizer +from tensorflow.python.client import session +from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +def adam_update_numpy(param, + g_t, + t, + m, + v, + alpha=0.001, + beta1=0.9, + beta2=0.999, + epsilon=1e-8): + alpha_t = alpha * np.sqrt(1 - beta2**t) / (1 - beta1**t) + + m_t = beta1 * m + (1 - beta1) * g_t + v_t = beta2 * v + (1 - beta2) * g_t * g_t + + param_t = param - alpha_t * m_t / (np.sqrt(v_t) + epsilon) + return param_t, m_t, v_t + + +class AdamGSOptimizerTest(test.TestCase): + + def doTestSparse(self, use_resource=False): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.cached_session(): + # Initialize variables for numpy implementation.
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + if use_resource: + global_step = resource_variable_ops.ResourceVariable( + array_ops.zeros([], dtypes.int64)) + var0 = resource_variable_ops.ResourceVariable(var0_np) + var1 = resource_variable_ops.ResourceVariable(var1_np) + else: + global_step = variables.Variable(array_ops.zeros([], dtypes.int64)) + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + grads0_np_indices = np.array([0, 1], dtype=np.int32) + grads0 = ops.IndexedSlices( + constant_op.constant(grads0_np), + constant_op.constant(grads0_np_indices), constant_op.constant([2])) + grads1_np_indices = np.array([0, 1], dtype=np.int32) + grads1 = ops.IndexedSlices( + constant_op.constant(grads1_np), + constant_op.constant(grads1_np_indices), constant_op.constant([2])) + opt = adam_gs_optimizer.AdamGSOptimizer(global_step=global_step) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]), + global_step=global_step) + variables.global_variables_initializer().run() + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) + + beta1_power, beta2_power = opt._get_beta_accumulators() + + # Run 3 steps of Adam + for t in range(1, 4): + self.assertAllCloseAccordingToType(0.9**t, self.evaluate(beta1_power)) + self.assertAllCloseAccordingToType(0.999**t, + self.evaluate(beta2_power)) + update.run() + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) + +
def testSparse(self): + self.doTestSparse(use_resource=False) + + def testResourceSparse(self): + self.doTestSparse(use_resource=True) + + def testSparseDevicePlacement(self): + for index_dtype in [dtypes.int32, dtypes.int64]: + with self.cached_session(force_gpu=test.is_gpu_available()): + # If a GPU is available, tests that all optimizer ops can be placed on + # it (i.e. they have GPU kernels). + var = variables.Variable([[1.0], [2.0]]) + indices = constant_op.constant([0, 1], dtype=index_dtype) + gathered_sum = math_ops.reduce_sum(array_ops.gather(var, indices)) + optimizer = adam_gs_optimizer.AdamGSOptimizer(learning_rate=3.0) + minimize_op = optimizer.minimize(gathered_sum) + variables.global_variables_initializer().run() + minimize_op.run() + + def testSparseRepeatedIndices(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.cached_session(): + repeated_index_global_step = variables.Variable( + array_ops.zeros([], dtypes.int64)) + aggregated_global_step = variables.Variable( + array_ops.zeros([], dtypes.int64)) + repeated_index_update_var = variables.Variable( + [[1.0], [2.0]], dtype=dtype) + aggregated_update_var = variables.Variable( + [[1.0], [2.0]], dtype=dtype) + grad_repeated_index = ops.IndexedSlices( + constant_op.constant( + [0.1, 0.1], shape=[2, 1], dtype=dtype), + constant_op.constant([1, 1]), + constant_op.constant([2, 1])) + grad_aggregated = ops.IndexedSlices( + constant_op.constant( + [0.2], shape=[1, 1], dtype=dtype), + constant_op.constant([1]), + constant_op.constant([2, 1])) + repeated_update = adam_gs_optimizer.AdamGSOptimizer( + global_step=repeated_index_global_step).apply_gradients( + [(grad_repeated_index, repeated_index_update_var)], + global_step=repeated_index_global_step) + aggregated_update = adam_gs_optimizer.AdamGSOptimizer( + global_step=aggregated_global_step).apply_gradients( + [(grad_aggregated, aggregated_update_var)], + global_step=aggregated_global_step) + variables.global_variables_initializer().run() +
self.assertAllClose(aggregated_update_var.eval(), + self.evaluate(repeated_index_update_var)) + for _ in range(3): + repeated_update.run() + aggregated_update.run() + self.assertAllClose(aggregated_update_var.eval(), + self.evaluate(repeated_index_update_var)) + + def doTestBasic(self, use_resource=False, use_callable_params=False): + for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): + with self.session(graph=ops.Graph()): + # Initialize variables for numpy implementation. + m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + if use_resource: + global_step = resource_variable_ops.ResourceVariable( + array_ops.zeros([], dtypes.int64), name="global_step_%d" % i) + var0 = resource_variable_ops.ResourceVariable( + var0_np, name="var0_%d" % i) + var1 = resource_variable_ops.ResourceVariable( + var1_np, name="var1_%d" % i) + else: + global_step = variables.Variable(array_ops.zeros([], dtypes.int64)) + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + + learning_rate = lambda: 0.001 + beta1 = lambda: 0.9 + beta2 = lambda: 0.999 + epsilon = lambda: 1e-8 + if not use_callable_params: + learning_rate = learning_rate() + beta1 = beta1() + beta2 = beta2() + epsilon = epsilon() + + opt = adam_gs_optimizer.AdamGSOptimizer(global_step=global_step, + learning_rate=learning_rate) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]), + global_step=global_step) + opt_variables = opt.variables() + beta1_power, beta2_power = opt._get_beta_accumulators() + self.assertTrue(beta1_power is not None) + self.assertTrue(beta2_power is not None) + self.assertNotIn(beta1_power, opt_variables) +
self.assertNotIn(beta2_power, opt_variables) + + if not context.executing_eagerly(): + with ops.Graph().as_default(): + # Shouldn't return non-slot variables from other graphs. + self.assertEqual(0, len(opt.variables())) + self.evaluate(variables.global_variables_initializer()) + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) + + # Run 3 steps of Adam + for t in range(1, 4): + if not context.executing_eagerly(): + self.evaluate(update) + self.assertAllCloseAccordingToType( + 0.9**(t + 1), self.evaluate(beta1_power)) + self.assertAllCloseAccordingToType( + 0.999**(t + 1), self.evaluate(beta2_power)) + else: + if t > 1: + opt.apply_gradients(zip([grads0, grads1], [var0, var1]), + global_step=global_step) + beta1_power, beta2_power = opt._get_beta_accumulators() + self.assertAllCloseAccordingToType( + 0.9**t, self.evaluate(beta1_power)) + self.assertAllCloseAccordingToType( + 0.999**t, self.evaluate(beta2_power)) + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) + if use_resource: + self.assertEqual("var0_%d/Adam:0" % (i,), + opt.get_slot(var=var0, name="m").name) + + def testBasic(self): + with self.cached_session(): + self.doTestBasic(use_resource=False) + + @test_util.run_in_graph_and_eager_modes(reset_test=True) + def testResourceBasic(self): + self.doTestBasic(use_resource=True) + + def testBasicCallableParams(self): + with context.eager_mode(): + self.doTestBasic(use_resource=True, use_callable_params=True) + + def testTensorLearningRate(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.cached_session(): + global_step = variables.Variable(array_ops.zeros([], dtypes.int64)) + #
Initialize variables for numpy implementation. + m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + opt = adam_gs_optimizer.AdamGSOptimizer( + global_step=global_step, learning_rate=constant_op.constant(0.001)) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]), + global_step=global_step) + variables.global_variables_initializer().run() + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) + + beta1_power, beta2_power = opt._get_beta_accumulators() + + # Run 3 steps of Adam + for t in range(1, 4): + self.assertAllCloseAccordingToType(0.9**t, self.evaluate(beta1_power)) + self.assertAllCloseAccordingToType(0.999**t, + self.evaluate(beta2_power)) + update.run() + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) + + def testSharing(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.cached_session(): + global_step = variables.Variable(array_ops.zeros([], dtypes.int64)) + # Initialize variables for numpy implementation.
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + opt = adam_gs_optimizer.AdamGSOptimizer(global_step=global_step) + update1 = opt.apply_gradients(zip([grads0, grads1], [var0, var1]), + global_step=global_step) + update2 = opt.apply_gradients(zip([grads0, grads1], [var0, var1]), + global_step=global_step) + variables.global_variables_initializer().run() + + beta1_power, beta2_power = opt._get_beta_accumulators() + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) + + # Run 3 steps of intertwined Adam1 and Adam2.
+ for t in range(1, 4): + self.assertAllCloseAccordingToType(0.9**t, self.evaluate(beta1_power)) + self.assertAllCloseAccordingToType(0.999**t, + self.evaluate(beta2_power)) + if t % 2 == 0: + update1.run() + else: + update2.run() + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) + + def testTwoSessions(self): + optimizer = adam_gs_optimizer.AdamGSOptimizer() + + with context.eager_mode(): + var0 = variables.Variable(np.array([1.0, 2.0]), name="v0") + grads0 = constant_op.constant(np.array([0.1, 0.1])) + optimizer.apply_gradients([(grads0, var0)]) + + g = ops.Graph() + with g.as_default(): + with session.Session(): + var0 = variables.Variable(np.array([1.0, 2.0]), name="v0") + grads0 = constant_op.constant(np.array([0.1, 0.1])) + optimizer.apply_gradients([(grads0, var0)]) + + gg = ops.Graph() + with gg.as_default(): + with session.Session(): + var0 = variables.Variable(np.array([1.0, 2.0]), name="v0") + grads0 = constant_op.constant(np.array([0.1, 0.1])) + + # If the optimizer saves any state not keyed by graph the following line + # fails. + optimizer.apply_gradients([(grads0, var0)]) + + def testSlotsUniqueEager(self): + with context.eager_mode(): + v1 = resource_variable_ops.ResourceVariable(1.) + v2 = resource_variable_ops.ResourceVariable(1.) + opt = adam_gs_optimizer.AdamGSOptimizer(learning_rate=1.) + opt.minimize(lambda: v1 + v2) + # There should be two unique slot variables for v1 and v2 respectively.
+ self.assertEqual(4, len(set(opt.variables()))) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py b/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py index 6c203e5519e6a66d20e2509eca3c74eb66bf32c7..fa1a7aaff0aa59a6a64b1f0bf836a273926d785d 100644 --- a/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py +++ b/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py @@ -30,6 +30,7 @@ from tensorflow.python.ops import variables from tensorflow.python.training import optimizer from tensorflow.python.training import saver from tensorflow.python.training import session_run_hook +from tensorflow.python.training.saving import saveable_object_util LOCAL_VARIABLE_NAME = 'local_center_variable' GLOBAL_VARIABLE_NAME = 'global_center_variable' @@ -424,7 +425,7 @@ class ElasticAverageOptimizer(optimizer.Optimizer): if var_list is None: var_list = variables.trainable_variables() if not isinstance(var_list, dict): - var_list = saver.BaseSaverBuilder.OpListToDict(var_list) + var_list = saveable_object_util.op_list_to_dict(var_list) swapped_var_list = {} for key, var in var_list.items(): @@ -464,4 +465,4 @@ class _ElasticAverageOptimizerHook(session_run_hook.SessionRunHook): def after_create_session(self, session, coord): """Run initialization ops""" - session.run(self._variable_init_op) \ No newline at end of file + session.run(self._variable_init_op) diff --git a/tensorflow/contrib/opt/python/training/lazy_adam_gs_optimizer.py b/tensorflow/contrib/opt/python/training/lazy_adam_gs_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..8827007e4d7f6722398a8e36bd626377842d92ef --- /dev/null +++ b/tensorflow/contrib/opt/python/training/lazy_adam_gs_optimizer.py @@ -0,0 +1,114 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""LazyAdam rewrite to use global step for computing beta1 & beta2 accumulation. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.opt.python.training import adam_gs_optimizer +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import state_ops + + +class LazyAdamGSOptimizer(adam_gs_optimizer.AdamGSOptimizer): + """Variant of the Adam optimizer that handles sparse updates more efficiently. + + Branched from tf.contrib.opt.LazyAdamOptimizer. The only difference is to + pass global step for computing beta1 and beta2 accumulators, instead of having + optimizer keep its own independent beta1 and beta2 accumulators as non-slot + variables. + + The original Adam algorithm maintains two moving-average accumulators for + each trainable variable; the accumulators are updated at every step. + This class provides lazier handling of gradient updates for sparse variables. + It only updates moving-average accumulators for sparse variable indices that + appear in the current batch, rather than updating the accumulators for all + indices.
Compared with the original Adam optimizer, it can provide large + improvements in model training throughput for some applications. However, it + provides slightly different semantics than the original Adam algorithm, and + may lead to different empirical results. + """ + + def _apply_sparse(self, grad, var): + beta1_power, beta2_power = self._get_beta_accumulators() + beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype) + beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype) + lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype) + beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype) + beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype) + epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype) + lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power)) + + # \\(m := beta1 * m + (1 - beta1) * g_t\\) + m = self.get_slot(var, "m") + m_t = state_ops.scatter_update(m, grad.indices, + beta1_t * array_ops.gather(m, grad.indices) + + (1 - beta1_t) * grad.values, + use_locking=self._use_locking) + + # \\(v := beta2 * v + (1 - beta2) * (g_t * g_t)\\) + v = self.get_slot(var, "v") + v_t = state_ops.scatter_update(v, grad.indices, + beta2_t * array_ops.gather(v, grad.indices) + + (1 - beta2_t) * math_ops.square(grad.values), + use_locking=self._use_locking) + + # \\(variable -= learning_rate * m_t / (epsilon_t + sqrt(v_t))\\) + m_t_slice = array_ops.gather(m_t, grad.indices) + v_t_slice = array_ops.gather(v_t, grad.indices) + denominator_slice = math_ops.sqrt(v_t_slice) + epsilon_t + var_update = state_ops.scatter_sub(var, grad.indices, + lr * m_t_slice / denominator_slice, + use_locking=self._use_locking) + return control_flow_ops.group(var_update, m_t, v_t) + + def _resource_apply_sparse(self, grad, var, indices): + beta1_power, beta2_power = self._get_beta_accumulators() + beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype) + beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype) + lr_t =
math_ops.cast(self._lr_t, var.dtype.base_dtype) + beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype) + beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype) + epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype) + lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power)) + + # \\(m := beta1 * m + (1 - beta1) * g_t\\) + m = self.get_slot(var, "m") + m_t_slice = beta1_t * array_ops.gather(m, indices) + (1 - beta1_t) * grad + m_update_op = resource_variable_ops.resource_scatter_update(m.handle, + indices, + m_t_slice) + + # \\(v := beta2 * v + (1 - beta2) * (g_t * g_t)\\) + v = self.get_slot(var, "v") + v_t_slice = (beta2_t * array_ops.gather(v, indices) + + (1 - beta2_t) * math_ops.square(grad)) + v_update_op = resource_variable_ops.resource_scatter_update(v.handle, + indices, + v_t_slice) + + # \\(variable -= learning_rate * m_t / (epsilon_t + sqrt(v_t))\\) + var_slice = lr * m_t_slice / (math_ops.sqrt(v_t_slice) + epsilon_t) + var_update_op = resource_variable_ops.resource_scatter_sub(var.handle, + indices, + var_slice) + + return control_flow_ops.group(var_update_op, m_update_op, v_update_op) diff --git a/tensorflow/contrib/opt/python/training/lazy_adam_gs_optimizer_test.py b/tensorflow/contrib/opt/python/training/lazy_adam_gs_optimizer_test.py new file mode 100644 index 0000000000000000000000000000000000000000..bdc9a02a546c8399172d0c5b58941b4d80179955 --- /dev/null +++ b/tensorflow/contrib/opt/python/training/lazy_adam_gs_optimizer_test.py @@ -0,0 +1,402 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Tests for LazyAdamGSOptimizer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy as np + +from tensorflow.contrib.opt.python.training import lazy_adam_gs_optimizer +from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +def adam_update_numpy(param, + g_t, + t, + m, + v, + alpha=0.001, + beta1=0.9, + beta2=0.999, + epsilon=1e-8): + alpha_t = alpha * np.sqrt(1 - beta2**t) / (1 - beta1**t) + + m_t = beta1 * m + (1 - beta1) * g_t + v_t = beta2 * v + (1 - beta2) * g_t * g_t + + param_t = param - alpha_t * m_t / (np.sqrt(v_t) + epsilon) + return param_t, m_t, v_t + + +class LazyAdamGSOptimizerTest(test.TestCase, parameterized.TestCase): + + @parameterized.parameters([False, True]) + def testSparse(self, use_resource): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.cached_session(): + # Initialize variables for numpy implementation. 
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + if use_resource: + global_step = resource_variable_ops.ResourceVariable( + array_ops.zeros([], dtypes.int64)) + var0 = resource_variable_ops.ResourceVariable(var0_np) + var1 = resource_variable_ops.ResourceVariable(var1_np) + else: + global_step = variables.Variable(array_ops.zeros([], dtypes.int64)) + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + + grads0_np_indices = np.array([0, 1], dtype=np.int32) + grads0 = ops.IndexedSlices( + constant_op.constant(grads0_np), + constant_op.constant(grads0_np_indices), constant_op.constant([2])) + grads1_np_indices = np.array([0, 1], dtype=np.int32) + grads1 = ops.IndexedSlices( + constant_op.constant(grads1_np), + constant_op.constant(grads1_np_indices), constant_op.constant([2])) + opt = lazy_adam_gs_optimizer.LazyAdamGSOptimizer( + global_step=global_step) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]), + global_step=global_step) + variables.global_variables_initializer().run() + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + + beta1_power, beta2_power = opt._get_beta_accumulators() + + # Run 3 steps of Adam + for t in range(1, 4): + self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval()) + self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval()) + update.run() + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, var0.eval()) + self.assertAllCloseAccordingToType(var1_np, var1.eval()) + + @parameterized.parameters([False, 
True]) + def testSparseDevicePlacement(self, use_resource): + for index_dtype in [dtypes.int32, dtypes.int64]: + with self.cached_session(force_gpu=test.is_gpu_available()): + # If a GPU is available, tests that all optimizer ops can be placed on + # it (i.e. they have GPU kernels). + if use_resource: + global_step = resource_variable_ops.ResourceVariable( + array_ops.zeros([], dtypes.int64)) + var = resource_variable_ops.ResourceVariable([[1.0], [2.0]]) + else: + global_step = variables.Variable(array_ops.zeros([], dtypes.int64)) + var = variables.Variable([[1.0], [2.0]]) + + indices = constant_op.constant([0, 1], dtype=index_dtype) + gathered_sum = math_ops.reduce_sum(array_ops.gather(var, indices)) + optimizer = lazy_adam_gs_optimizer.LazyAdamGSOptimizer( + global_step=global_step, learning_rate=3.0) + minimize_op = optimizer.minimize(gathered_sum, global_step=global_step) + variables.global_variables_initializer().run() + minimize_op.run() + + @parameterized.parameters([False, True]) + def testSparseRepeatedIndices(self, use_resource): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.cached_session(): + if use_resource: + repeated_index_global_step = resource_variable_ops.ResourceVariable( + array_ops.zeros([], dtypes.int64)) + aggregated_global_step = resource_variable_ops.ResourceVariable( + array_ops.zeros([], dtypes.int64)) + repeated_index_update_var = resource_variable_ops.ResourceVariable( + [[1.0], [2.0]], dtype=dtype) + aggregated_update_var = resource_variable_ops.ResourceVariable( + [[1.0], [2.0]], dtype=dtype) + else: + repeated_index_global_step = variables.Variable( + array_ops.zeros([], dtypes.int64)) + aggregated_global_step = variables.Variable( + array_ops.zeros([], dtypes.int64)) + repeated_index_update_var = variables.Variable( + [[1.0], [2.0]], dtype=dtype) + aggregated_update_var = variables.Variable( + [[1.0], [2.0]], dtype=dtype) + + grad_repeated_index = ops.IndexedSlices( + constant_op.constant( + [0.1, 0.1], 
shape=[2, 1], dtype=dtype), + constant_op.constant([1, 1]), + constant_op.constant([2, 1])) + grad_aggregated = ops.IndexedSlices( + constant_op.constant( + [0.2], shape=[1, 1], dtype=dtype), + constant_op.constant([1]), + constant_op.constant([2, 1])) + repeated_update_opt = lazy_adam_gs_optimizer.LazyAdamGSOptimizer( + global_step=repeated_index_global_step) + repeated_update = repeated_update_opt.apply_gradients( + [(grad_repeated_index, repeated_index_update_var)], + global_step=repeated_index_global_step) + aggregated_update_opt = lazy_adam_gs_optimizer.LazyAdamGSOptimizer( + global_step=aggregated_global_step) + aggregated_update = aggregated_update_opt.apply_gradients( + [(grad_aggregated, aggregated_update_var)], + global_step=aggregated_global_step) + variables.global_variables_initializer().run() + self.assertAllClose(aggregated_update_var.eval(), + repeated_index_update_var.eval()) + for _ in range(3): + repeated_update.run() + aggregated_update.run() + self.assertAllClose(aggregated_update_var.eval(), + repeated_index_update_var.eval()) + + def doTestBasic(self, use_resource=False, use_callable_params=False): + for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): + with self.session(graph=ops.Graph()): + # Initialize variables for numpy implementation. 
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + if use_resource: + global_step = resource_variable_ops.ResourceVariable( + array_ops.zeros([], dtypes.int64), name="global_step_%d" % i) + var0 = resource_variable_ops.ResourceVariable( + var0_np, name="var0_%d" % i) + var1 = resource_variable_ops.ResourceVariable( + var1_np, name="var1_%d" % i) + else: + global_step = variables.Variable(array_ops.zeros([], dtypes.int64)) + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + + learning_rate = lambda: 0.001 + beta1 = lambda: 0.9 + beta2 = lambda: 0.999 + epsilon = lambda: 1e-8 + if not use_callable_params: + learning_rate = learning_rate() + beta1 = beta1() + beta2 = beta2() + epsilon = epsilon() + + opt = lazy_adam_gs_optimizer.LazyAdamGSOptimizer( + global_step=global_step, learning_rate=learning_rate) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]), + global_step=global_step) + opt_variables = opt.variables() + beta1_power, beta2_power = opt._get_beta_accumulators() + self.assertIsNotNone(beta1_power) + self.assertIsNotNone(beta2_power) + self.assertNotIn(beta1_power, opt_variables) + self.assertNotIn(beta2_power, opt_variables) + + if not context.executing_eagerly(): + with ops.Graph().as_default(): + # Shouldn't return non-slot variables from other graphs. 
+ self.assertEqual(0, len(opt.variables())) + self.evaluate(variables.global_variables_initializer()) + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) + + # Run 3 steps of Adam + for t in range(1, 4): + if not context.executing_eagerly(): + self.evaluate(update) + self.assertAllCloseAccordingToType( + 0.9**(t + 1), self.evaluate(beta1_power)) + self.assertAllCloseAccordingToType( + 0.999**(t + 1), self.evaluate(beta2_power)) + else: + if t > 1: + opt.apply_gradients(zip([grads0, grads1], [var0, var1]), + global_step=global_step) + beta1_power, beta2_power = opt._get_beta_accumulators() + self.assertAllCloseAccordingToType( + 0.9**t, self.evaluate(beta1_power)) + self.assertAllCloseAccordingToType( + 0.999**t, self.evaluate(beta2_power)) + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) + if use_resource: + self.assertEqual("var0_%d/Adam:0" % (i,), + opt.get_slot(var=var0, name="m").name) + + def testBasic(self): + with self.cached_session(): + self.doTestBasic(use_resource=False) + + @test_util.run_in_graph_and_eager_modes(reset_test=True) + def testResourceBasic(self): + self.doTestBasic(use_resource=True) + + def testBasicCallableParams(self): + with context.eager_mode(): + self.doTestBasic(use_resource=True, use_callable_params=True) + + def testTensorLearningRate(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.cached_session(): + global_step = variables.Variable(array_ops.zeros([], dtypes.int64)) + # Initialize variables for numpy implementation. 
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + opt = lazy_adam_gs_optimizer.LazyAdamGSOptimizer( + global_step=global_step, learning_rate=constant_op.constant(0.001)) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]), + global_step=global_step) + variables.global_variables_initializer().run() + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + + beta1_power, beta2_power = opt._get_beta_accumulators() + + # Run 3 steps of Adam + for t in range(1, 4): + self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval()) + self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval()) + update.run() + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, var0.eval()) + self.assertAllCloseAccordingToType(var1_np, var1.eval()) + + def testSharing(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.cached_session(): + global_step = variables.Variable(array_ops.zeros([], dtypes.int64)) + # Initialize variables for numpy implementation. 
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + opt = lazy_adam_gs_optimizer.LazyAdamGSOptimizer( + global_step=global_step) + update1 = opt.apply_gradients(zip([grads0, grads1], [var0, var1]), + global_step=global_step) + update2 = opt.apply_gradients(zip([grads0, grads1], [var0, var1]), + global_step=global_step) + variables.global_variables_initializer().run() + + beta1_power, beta2_power = opt._get_beta_accumulators() + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + + # Run 3 steps of intertwined Adam1 and Adam2. 
+ for t in range(1, 4): + self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval()) + self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval()) + if t % 2 == 0: + update1.run() + else: + update2.run() + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, var0.eval()) + self.assertAllCloseAccordingToType(var1_np, var1.eval()) + + def testTwoSessions(self): + optimizer = lazy_adam_gs_optimizer.LazyAdamGSOptimizer() + + with context.eager_mode(): + var0 = variables.Variable(np.array([1.0, 2.0]), name="v0") + grads0 = constant_op.constant(np.array([0.1, 0.1])) + optimizer.apply_gradients([(grads0, var0)]) + + g = ops.Graph() + with g.as_default(): + with self.session(graph=g): + var0 = variables.Variable(np.array([1.0, 2.0]), name="v0") + grads0 = constant_op.constant(np.array([0.1, 0.1])) + optimizer.apply_gradients([(grads0, var0)]) + + gg = ops.Graph() + with gg.as_default(): + with self.session(graph=gg): + var0 = variables.Variable(np.array([1.0, 2.0]), name="v0") + grads0 = constant_op.constant(np.array([0.1, 0.1])) + + # If the optimizer saves any state not keyed by graph the following line + # fails. + optimizer.apply_gradients([(grads0, var0)]) + + def testSlotsUniqueEager(self): + with context.eager_mode(): + v1 = resource_variable_ops.ResourceVariable(1.) + v2 = resource_variable_ops.ResourceVariable(1.) + opt = lazy_adam_gs_optimizer.LazyAdamGSOptimizer(1.) + opt.minimize(lambda: v1 + v2) + # There should be two non-slot variables, and two unique slot variables + # for v1 and v2 respectively. 
+ self.assertLen(set(opt.variables()), 4) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/opt/python/training/moving_average_optimizer.py b/tensorflow/contrib/opt/python/training/moving_average_optimizer.py index b7fd2d2fb9db3eed15eb1cc2934199939790b1c0..bf3e5c51f78cc3ca3c7c77009c9cf428c4988953 100644 --- a/tensorflow/contrib/opt/python/training/moving_average_optimizer.py +++ b/tensorflow/contrib/opt/python/training/moving_average_optimizer.py @@ -26,6 +26,7 @@ from tensorflow.python.ops import variables from tensorflow.python.training import moving_averages from tensorflow.python.training import optimizer from tensorflow.python.training import saver +from tensorflow.python.training.saving import saveable_object_util class MovingAverageOptimizer(optimizer.Optimizer): @@ -165,7 +166,7 @@ class MovingAverageOptimizer(optimizer.Optimizer): if var_list is None: var_list = variables.global_variables() if not isinstance(var_list, dict): - var_list = saver.BaseSaverBuilder.OpListToDict(var_list) + var_list = saveable_object_util.op_list_to_dict(var_list) v_name_to_tensor = {} for k, tensor_or_list in six.iteritems(var_list): diff --git a/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py b/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py index 200b0d200826a6212a236680327f4daf7d07831f..8b8065c678e11e8fc237e71cf1d392ced5c22ada 100644 --- a/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py +++ b/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py @@ -59,6 +59,23 @@ class DecoupledWeightDecayExtension(object): Note that this extension decays weights BEFORE applying the update based on the gradient, i.e. this extension only has the desired behaviour for optimizers which do not depend on the value of'var' in the update step! + + Note: when applying a decay to the learning rate, be sure to manually apply + the decay to the `weight_decay` as well. 
For example: + + ```python + schedule = tf.train.piecewise_constant(tf.train.get_global_step(), + [10000, 15000], [1e-0, 1e-1, 1e-2]) + lr = 1e-1 * schedule() + wd = lambda: 1e-4 * schedule() + + # ... + + optimizer = tf.contrib.opt.MomentumWOptimizer(learning_rate=lr, + weight_decay=wd, + momentum=0.9, + use_nesterov=True) + ``` """ def __init__(self, weight_decay, **kwargs): diff --git a/tensorflow/contrib/optimizer_v2/adam.py b/tensorflow/contrib/optimizer_v2/adam.py index 248ffb1f7eb5dc27112ddf9b8670344904065ed0..1b7800f324b908e3c88fe90d31a2a08cbbd5ccf2 100644 --- a/tensorflow/contrib/optimizer_v2/adam.py +++ b/tensorflow/contrib/optimizer_v2/adam.py @@ -36,7 +36,7 @@ class AdamOptimizer(optimizer_v2.OptimizerV2): def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8, use_locking=False, name="Adam"): - """Construct a new Adam optimizer. + r"""Construct a new Adam optimizer. Initialization: diff --git a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py index 72019b31540a943582ebb4699013d9dcfc10769f..0243927ce44aec626973744507e75b20a42253e9 100644 --- a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py +++ b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py @@ -48,7 +48,7 @@ from tensorflow.python.training.checkpointable import tracking from tensorflow.python.training.checkpointable import util -class NonLayerCheckpointable(tracking.Checkpointable): +class NonLayerCheckpointable(tracking.AutoCheckpointable): def __init__(self): super(NonLayerCheckpointable, self).__init__() @@ -440,7 +440,7 @@ class CheckpointingTests(test.TestCase): def testDeferredSlotRestoration(self): checkpoint_directory = self.get_temp_dir() - root = tracking.Checkpointable() + root = tracking.AutoCheckpointable() root.var = util.add_variable( root, name="var", initializer=0.) 
optimizer = adam.AdamOptimizer(0.1) @@ -463,7 +463,7 @@ class CheckpointingTests(test.TestCase): 14.)) slots_path = util.CheckpointableSaver(root).save( os.path.join(checkpoint_directory, "with_slots")) - new_root = tracking.Checkpointable() + new_root = tracking.AutoCheckpointable() # Load the slot-containing checkpoint (deferred), then immediately overwrite # the non-slot variable (also deferred). slot_status = util.CheckpointableSaver( @@ -508,7 +508,7 @@ class CheckpointingTests(test.TestCase): with graph.as_default(), self.session(graph): checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") - obj = tracking.Checkpointable() + obj = tracking.AutoCheckpointable() obj.var = variable_scope.get_variable(name="v", initializer=0.) obj.opt = adam.AdamOptimizer(0.1) obj.opt.minimize(obj.var.read_value()) @@ -526,7 +526,7 @@ class CheckpointingTests(test.TestCase): with graph.as_default(), self.session(graph): checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") - obj = tracking.Checkpointable() + obj = tracking.AutoCheckpointable() obj.var = variable_scope.get_variable(name="v", initializer=0.) 
obj.opt = adam.AdamOptimizer(0.1) obj.opt.minimize(obj.var.read_value()) diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2.py b/tensorflow/contrib/optimizer_v2/optimizer_v2.py index 73a556f0b299614b098ceef0fb9d32f148227b03..1323ed014c9e51e273491694fa44a8e36cc723d0 100644 --- a/tensorflow/contrib/optimizer_v2/optimizer_v2.py +++ b/tensorflow/contrib/optimizer_v2/optimizer_v2.py @@ -25,6 +25,7 @@ import abc import six from tensorflow.python.distribute import distribute_lib +from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx from tensorflow.python.distribute import reduce_util as ds_reduce_util from tensorflow.python.eager import backprop from tensorflow.python.eager import context @@ -36,7 +37,6 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables -from tensorflow.python.training import distribution_strategy_context as distribute_ctx from tensorflow.python.training import optimizer as optimizer_v1 from tensorflow.python.training import slot_creator from tensorflow.python.training.checkpointable import base as checkpointable @@ -843,8 +843,7 @@ class OptimizerV2(optimizer_v1.Optimizer): scale_loss_by_num_replicas = ( distribute_lib.get_loss_reduction() == ds_reduce_util.ReduceOp.MEAN) if scale_loss_by_num_replicas: - num_replicas = \ - distribute_ctx.get_distribution_strategy().num_replicas_in_sync + num_replicas = distribute_ctx.get_strategy().num_replicas_in_sync if num_replicas > 1: loss_value *= 1. / num_replicas return loss_value @@ -997,10 +996,10 @@ class OptimizerV2(optimizer_v1.Optimizer): with ops.control_dependencies([update_ops]): finish_updates = distribution.extended.update_non_slot( non_slot_devices, finish, group=False) - # We said grouped=False, which means finish_updates is always a list. - # It will be [None] when finish() returns None. 
- if finish_updates == [None]: - finish_updates = [update_ops] + # We said group=False, which means finish_updates is always a tuple. + # It will be (None,) when finish() returns None. + if finish_updates == (None,): + finish_updates = (update_ops,) # Update `global_step` (if any). if global_step is None: diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2_test.py b/tensorflow/contrib/optimizer_v2/optimizer_v2_test.py index dd7f2f44055a2e48e8a48d01c1da3a8e7513255d..2fc0b5ea4de2332ff3bf32f9a12a15eee566d5c4 100644 --- a/tensorflow/contrib/optimizer_v2/optimizer_v2_test.py +++ b/tensorflow/contrib/optimizer_v2/optimizer_v2_test.py @@ -26,7 +26,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import clip_ops -from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import gradients_util from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variables @@ -71,7 +71,7 @@ class OptimizerTest(test.TestCase): opt_op = sgd_op.minimize( cost, global_step, [var0, var1], - aggregation_method=gradients_impl.AggregationMethod. + aggregation_method=gradients_util.AggregationMethod. 
EXPERIMENTAL_ACCUMULATE_N) variables.global_variables_initializer().run() diff --git a/tensorflow/contrib/predictor/BUILD b/tensorflow/contrib/predictor/BUILD index d50b52b8ff1ce8188ab52c6968d716378efd9daa..53a3bc63e1d770b451846c45370fdee9ffa72d70 100644 --- a/tensorflow/contrib/predictor/BUILD +++ b/tensorflow/contrib/predictor/BUILD @@ -42,6 +42,7 @@ py_library( name = "saved_model_predictor", srcs = ["saved_model_predictor.py"], srcs_version = "PY2AND3", + visibility = ["//learning/brain/contrib/learn/tpu:__subpackages__"], deps = [ ":base_predictor", "//tensorflow/contrib/saved_model:saved_model_py", diff --git a/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test_base.py b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test_base.py index 17b69c7b35dce130c45ab0aadb28be330b4bfb88..c8524e9871864e0b4fffbd549d1fe347714f072a 100644 --- a/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test_base.py +++ b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test_base.py @@ -84,7 +84,10 @@ class DecodeProtoOpTestBase(test_base.ProtoOpTestBase, parameterized.TestCase): values = field_dict[field.name] self.assertEqual(dtypes.as_dtype(values.dtype), field.dtype) - fd = field.value.DESCRIPTOR.fields_by_name[field.name] + if 'ext_value' in field.name: + fd = test_example_pb2.PrimitiveValue() + else: + fd = field.value.DESCRIPTOR.fields_by_name[field.name] # Values has the same shape as the input plus an extra # dimension for repeats. @@ -92,13 +95,16 @@ class DecodeProtoOpTestBase(test_base.ProtoOpTestBase, parameterized.TestCase): # Nested messages are represented as TF strings, requiring # some special handling. 
- if field.name == 'message_value': + if field.name == 'message_value' or 'ext_value' in field.name: vs = [] for buf in values.flat: msg = test_example_pb2.PrimitiveValue() msg.ParseFromString(buf) vs.append(msg) - evs = getattr(field.value, field.name) + if 'ext_value' in field.name: + evs = field.value.Extensions[test_example_pb2.ext_value] + else: + evs = getattr(field.value, field.name) if len(vs) != len(evs): self.fail('Field %s decoded %d outputs, expected %d' % (fd.name, len(vs), len(evs))) @@ -223,7 +229,8 @@ class DecodeProtoOpTestBase(test_base.ProtoOpTestBase, parameterized.TestCase): sanitize=False, force_disordered=True) - @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters()) + @parameterized.named_parameters( + *test_base.ProtoOpTestBase.named_parameters(extension=False)) def testPacked(self, case): # Now try with the packed serialization. # @@ -235,8 +242,7 @@ class DecodeProtoOpTestBase(test_base.ProtoOpTestBase, parameterized.TestCase): # Note: float_format='.17g' is necessary to ensure preservation of # doubles and floats in text format. text_format.Parse( - text_format.MessageToString( - value, float_format='.17g'), + text_format.MessageToString(value, float_format='.17g'), test_example_pb2.PackedTestValue()).SerializeToString() for value in case.values ] diff --git a/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test_base.py b/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test_base.py index 01b3ccc7fd3918c4ff910281289e31177e5a8097..5ec681ff55dbd18580761bb23e7017cfc9767b89 100644 --- a/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test_base.py +++ b/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test_base.py @@ -15,9 +15,6 @@ # ============================================================================= """Table-driven test for encode_proto op. -This test is run once with each of the *.TestCase.pbtxt files -in the test directory. 
- It tests that encode_proto is a lossless inverse of decode_proto (for the specified fields). """ @@ -145,7 +142,8 @@ class EncodeProtoOpTestBase(test_base.ProtoOpTestBase, parameterized.TestCase): # loss of packing in the encoding). self.assertEqual(in_buf, out_buf) - @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters()) + @parameterized.named_parameters( + *test_base.ProtoOpTestBase.named_parameters(extension=False)) def testRoundtrip(self, case): in_bufs = [value.SerializeToString() for value in case.values] @@ -154,7 +152,8 @@ class EncodeProtoOpTestBase(test_base.ProtoOpTestBase, parameterized.TestCase): return self._testRoundtrip( in_bufs, 'tensorflow.contrib.proto.TestValue', case.fields) - @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters()) + @parameterized.named_parameters( + *test_base.ProtoOpTestBase.named_parameters(extension=False)) def testRoundtripPacked(self, case): # Now try with the packed serialization. # We test the packed representations by loading the same test cases using diff --git a/tensorflow/contrib/proto/python/kernel_tests/proto_op_test_base.py b/tensorflow/contrib/proto/python/kernel_tests/proto_op_test_base.py index 2950c7dfdc59a11ba7d2c07d8406bd4af26b5bd9..1a636486a1765ad9544b5cb5e52961cc47f92950 100644 --- a/tensorflow/contrib/proto/python/kernel_tests/proto_op_test_base.py +++ b/tensorflow/contrib/proto/python/kernel_tests/proto_op_test_base.py @@ -38,17 +38,18 @@ class ProtoOpTestBase(test.TestCase): ct.cdll.LoadLibrary(lib) @staticmethod - def named_parameters(): - return ( - ("defaults", ProtoOpTestBase.defaults_test_case()), - ("minmax", ProtoOpTestBase.minmax_test_case()), - ("nested", ProtoOpTestBase.nested_test_case()), - ("optional", ProtoOpTestBase.optional_test_case()), - ("promote", ProtoOpTestBase.promote_test_case()), - ("ragged", ProtoOpTestBase.ragged_test_case()), - ("shaped_batch", ProtoOpTestBase.shaped_batch_test_case()), - ("simple", 
ProtoOpTestBase.simple_test_case()), - ) + def named_parameters(extension=True): + parameters = [("defaults", ProtoOpTestBase.defaults_test_case()), + ("minmax", ProtoOpTestBase.minmax_test_case()), + ("nested", ProtoOpTestBase.nested_test_case()), + ("optional", ProtoOpTestBase.optional_test_case()), + ("promote", ProtoOpTestBase.promote_test_case()), + ("ragged", ProtoOpTestBase.ragged_test_case()), + ("shaped_batch", ProtoOpTestBase.shaped_batch_test_case()), + ("simple", ProtoOpTestBase.simple_test_case())] + if extension: + parameters.append(("extension", ProtoOpTestBase.extension_test_case())) + return parameters @staticmethod def defaults_test_case(): @@ -399,6 +400,21 @@ class ProtoOpTestBase(test.TestCase): field.value.bool_value.append(True) return test_case + @staticmethod + def extension_test_case(): + test_case = test_example_pb2.TestCase() + value = test_case.values.add() + message_value = value.Extensions[test_example_pb2.ext_value].add() + message_value.double_value = 23.5 + test_case.shapes.append(1) + test_case.sizes.append(1) + field = test_case.fields.add() + field.name = test_example_pb2.ext_value.full_name + field.dtype = types_pb2.DT_STRING + message_value = field.value.Extensions[test_example_pb2.ext_value].add() + message_value.double_value = 23.5 + return test_case + @staticmethod def simple_test_case(): test_case = test_example_pb2.TestCase() diff --git a/tensorflow/contrib/proto/python/kernel_tests/test_example.proto b/tensorflow/contrib/proto/python/kernel_tests/test_example.proto index 674d881220a1113631def47c5111e3ef401b99f3..b1ce66de4feb9c6666ca9ccf39403b4e12840fcf 100644 --- a/tensorflow/contrib/proto/python/kernel_tests/test_example.proto +++ b/tensorflow/contrib/proto/python/kernel_tests/test_example.proto @@ -61,6 +61,8 @@ message TestValue { optional sfixed64 sfixed64_value_with_default = 32 [default = 11]; optional sint32 sint32_value_with_default = 33 [default = 12]; optional sint64 sint64_value_with_default = 34 [default = 
13]; + + extensions 100 to 199; } // A PackedTestValue looks exactly the same as a TestValue in the text format, @@ -68,7 +70,7 @@ message TestValue { // by loading the same test cases using this definition instead of TestValue. // // NOTE: This definition must be kept in sync with TestValue in every way except -// the packed=true declaration. +// the packed=true declaration and the lack of extensions. message PackedTestValue { repeated double double_value = 1 [packed = true]; repeated float float_value = 2 [packed = true]; @@ -132,6 +134,10 @@ message ExtraFields { optional bool bool_value = 1777; } +extend TestValue { + repeated PrimitiveValue ext_value = 100; +} + // The messages below are for yet-to-be created tests. message EnumValue { diff --git a/tensorflow/contrib/quantize/BUILD b/tensorflow/contrib/quantize/BUILD index b35c4fde1a2c704880e023a0c3ac1e0766493514..b67e68ea96a15f94e62050c92405eec4fe4be70f 100644 --- a/tensorflow/contrib/quantize/BUILD +++ b/tensorflow/contrib/quantize/BUILD @@ -202,8 +202,9 @@ py_test( py_test( name = "quantize_parameterized_test", - size = "large", + size = "medium", srcs = ["python/quantize_parameterized_test.py"], + shard_count = 4, srcs_version = "PY2AND3", # TODO(b/118839526): Re-enable msan test. tags = [ diff --git a/tensorflow/contrib/quantize/README.md b/tensorflow/contrib/quantize/README.md index 9085d9fa719520ac84ef6f8e07d7fa335bef5605..b335e1af69b7b2e6020f8e745c43bb1bdc95a62d 100644 --- a/tensorflow/contrib/quantize/README.md +++ b/tensorflow/contrib/quantize/README.md @@ -8,9 +8,9 @@ for both training and inference. There are two aspects to this: For efficient inference, TensorFlow combines batch normalization with the preceding convolutional and fully-connected layers prior to quantization by -[folding batch norm layers](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/quantize/python/fold_batch_norms.py){:.external}. 
+[folding batch norm layers](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/quantize/python/fold_batch_norms.py){:.external}. -The quantization error is modeled using [fake quantization](../api_guides/python/array_ops.md#Fake_quantization) +The quantization error is modeled using [fake quantization](../../api_guides/python/array_ops.md#Fake_quantization) nodes to simulate the effect of quantization in the forward and backward passes. The forward-pass models quantization, while the backward-pass models quantization as a straight-through estimator. Both the forward- and backward-pass simulate the quantization @@ -105,12 +105,12 @@ toco \ --std_value=127.5 --mean_value=127.5 ``` -See the documentation for `tf.contrib.quantize` and [TensorFlow Lite](../lite/). +See the documentation for `tf.contrib.quantize` and [TensorFlow Lite](../../lite/). ## Quantized accuracy results -The following are results of trainiing some popular CNN models (Mobilenet-v1, +The following are results of training some popular CNN models (Mobilenet-v1, Mobilenet-v2, and Inception-v3) using this tool:

diff --git a/tensorflow/contrib/quantize/python/quant_ops.py b/tensorflow/contrib/quantize/python/quant_ops.py index 8619708cdaecd78bcc7de0e8e0cbf2baa11bf6a2..39082cacf9770619cf5fb529ac9a0aad6e955c6d 100644 --- a/tensorflow/contrib/quantize/python/quant_ops.py +++ b/tensorflow/contrib/quantize/python/quant_ops.py @@ -224,8 +224,8 @@ def MovingAvgQuantize(inputs, None, default_name=name_prefix, values=[inputs], reuse=reuse) as scope: scope.set_partitioner(None) input_shape = inputs.get_shape() - input_dim = len(input_shape) if per_channel: + input_dim = len(input_shape) # Only support quantizing 1-, 2- and 4-dimensional tensors. assert input_dim in [1, 2, 4], ('Expected 1D, 2D or 4D input, was: %s in ' ' scope: %s' % (input_shape, name_prefix)) diff --git a/tensorflow/contrib/quantize/python/quant_ops_test.py b/tensorflow/contrib/quantize/python/quant_ops_test.py index 36d2af94e059cdc75b758bbf607d26c4e1ee73e9..c636c90d23a0f5a6de9d14085c824283cb41f6ca 100644 --- a/tensorflow/contrib/quantize/python/quant_ops_test.py +++ b/tensorflow/contrib/quantize/python/quant_ops_test.py @@ -63,6 +63,12 @@ class QuantOpsTest(googletest.TestCase): self.assertAlmostEqual(min_value, -0.5, delta=1e-3) self.assertAlmostEqual(max_value, 0.5, delta=1e-3) + def testMovingAvgQuantizeTrainingAssignNoShape(self): + min_value, max_value = self._GetMinMaxValues( + quant_ops.MovingAvgQuantize, [[-1, 1], [0, 0]], shape=None) + self.assertAlmostEqual(min_value, -0.5, delta=1e-3) + self.assertAlmostEqual(max_value, 0.5, delta=1e-3) + def testMovingAvgSymmetricQuantizeTrainingAssign(self): min_value, max_value = self._GetMinMaxValues( quant_ops.MovingAvgQuantize, [[-1, 0.5], [0, 0]], symmetric=True) @@ -109,10 +115,10 @@ class QuantOpsTest(googletest.TestCase): is_training=True, vars_collection=_MIN_MAX_VARS) - def _GetMinMaxValues(self, quantize_fn, input_values, **kwds): + def _GetMinMaxValues(self, quantize_fn, input_values, shape=(2), **kwds): g = ops.Graph() with session.Session(graph=g) as 
sess: - x = array_ops.placeholder(dtypes.float32, shape=[2]) + x = array_ops.placeholder(dtypes.float32, shape=shape) y = quantize_fn( x, init_min=0.0, diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py index 21d1b1213090273b5abd8e012f8711db98c94347..7c973fe597181b822e617db1f85a08f1b678e26f 100644 --- a/tensorflow/contrib/quantize/python/quantize.py +++ b/tensorflow/contrib/quantize/python/quantize.py @@ -685,7 +685,7 @@ def _InsertQuantOp(context, [1; 2^bits - 1] or wide range [0; 2^bits - 1]. producer_scope: The restriction of producer scope. If not None, the new op will be inserted only when the producer is in this scope. - consumer_scope: The restriction of producer scope. If not None, the new op + consumer_scope: The restriction of consumer scope. If not None, the new op will be inserted only when all the consumers are in this scope. Raises: ValueError: When producer operation is not directly connected to the diff --git a/tensorflow/contrib/rate/BUILD b/tensorflow/contrib/rate/BUILD index c461a7145e27c4238161cec989448be807acd543..76db9aecf615d0a94f65cd7ea799db245828db1c 100644 --- a/tensorflow/contrib/rate/BUILD +++ b/tensorflow/contrib/rate/BUILD @@ -34,6 +34,11 @@ py_test( name = "rate_test", size = "small", srcs = ["rate_test.py"], + tags = [ + "manual", # TODO(b/120555555) + "no_oss", # TODO(b/120555555) + "notap", # TODO(b/120555555) + ], deps = [ ":rate", "//tensorflow/python:array_ops", diff --git a/tensorflow/contrib/receptive_field/README.md b/tensorflow/contrib/receptive_field/README.md index 79b015a9163f5727caa40b54579c71e57621c92f..d1c41e4c0a11028765c9fc0dc345cb29453baa31 100644 --- a/tensorflow/contrib/receptive_field/README.md +++ b/tensorflow/contrib/receptive_field/README.md @@ -185,5 +185,4 @@ Effective padding (vertical) = 1482 ## Authors -André Araujo (github id: andrefaraujo) and Mark Sandler (github id: -marksandler) +André Araujo (@andrefaraujo) and Mark Sandler (@marksandler) diff 
--git a/tensorflow/contrib/receptive_field/python/util/examples/compute_rf.py b/tensorflow/contrib/receptive_field/python/util/examples/compute_rf.py index d6fdd12bbe37fb0e0cb12f1d0adc3fce29b19e8a..72f98ccc32e945b48b5f1b570bcca323a5b5f48a 100644 --- a/tensorflow/contrib/receptive_field/python/util/examples/compute_rf.py +++ b/tensorflow/contrib/receptive_field/python/util/examples/compute_rf.py @@ -12,10 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Computes Receptive Field (RF) information given a graph protobuf. - -For an example of usage, see accompanying file compute_rf.sh -""" +"""Computes Receptive Field (RF) information given a graph protobuf.""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/receptive_field/python/util/examples/rf_benchmark.py b/tensorflow/contrib/receptive_field/python/util/examples/rf_benchmark.py index a298b4d49038468299b58140758c69675368e855..325929a5937ac60a6134fae064e7633a4c57473d 100644 --- a/tensorflow/contrib/receptive_field/python/util/examples/rf_benchmark.py +++ b/tensorflow/contrib/receptive_field/python/util/examples/rf_benchmark.py @@ -16,8 +16,6 @@ The receptive field (and related parameters) for the different models are printed to stdout, and may also optionally be written to a CSV file. - -For an example of usage, see rf_benchmark.sh """ from __future__ import absolute_import @@ -262,11 +260,11 @@ def _model_rf(graphdef, information will be computed. model_type: Type of model to be used, used only for printing purposes. csv_writer: A CSV writer for RF parameters, which is used if it is not None. - input_resolution: Input resolution to use when computing RF - parameters. This is important for the case where padding can only be - defined if the input resolution is known, which may happen if using SAME - padding. 
This is assumed the resolution for both height and width. If - None, we consider the resolution is unknown. + input_resolution: Input resolution to use when computing RF parameters. This + is important for the case where padding can only be defined if the input + resolution is known, which may happen if using SAME padding. This is + assumed the resolution for both height and width. If None, we consider the + resolution is unknown. """ for desired_end_point_key in desired_end_point_keys: print('- %s:' % desired_end_point_key) @@ -283,10 +281,10 @@ def _model_rf(graphdef, if (receptive_field_x == receptive_field_y) and ( effective_stride_x == effective_stride_y) and ( effective_padding_x == effective_padding_y): - print('Receptive field size = %5s, effective stride = %5s, effective ' - 'padding = %5s' % (str(receptive_field_x), - str(effective_stride_x), - str(effective_padding_x))) + print( + 'Receptive field size = %5s, effective stride = %5s, effective ' + 'padding = %5s' % (str(receptive_field_x), str(effective_stride_x), + str(effective_padding_x))) else: print('Receptive field size: horizontal = %5s, vertical = %5s. ' 'Effective stride: horizontal = %5s, vertical = %5s. Effective ' @@ -362,9 +360,8 @@ def _process_model_rf(model_type='resnet_v1_50', defined if the input resolution is known, which may happen if using SAME padding. The entries in the list are assumed the resolution for both height and width. If one of the elements in the list is None, we consider - it to mean that the resolution is unknown. If the list itself is None, - we use the default list [None, 224, 321]. - + it to mean that the resolution is unknown. If the list itself is None, we + use the default list [None, 224, 321]. """ # Process default value for this list. if input_resolutions is None: @@ -477,8 +474,8 @@ def _mobilenet_v1_rf(csv_writer=None): csv_writer: A CSV writer for RF parameters, which is used if it is not None. 
""" for model_type in _SUPPORTED_MOBILENETV1_VARIANTS: - with slim.arg_scope( - [slim.batch_norm, slim.dropout], is_training=False) as arg_sc: + with slim.arg_scope([slim.batch_norm, slim.dropout], + is_training=False) as arg_sc: _process_model_rf(model_type, csv_writer, arg_sc) diff --git a/tensorflow/contrib/receptive_field/python/util/parse_layer_parameters.py b/tensorflow/contrib/receptive_field/python/util/parse_layer_parameters.py index 0e3c46f17d2e2a277418d39e31927db73a509670..92ae1021bc8f8fbf19ca7f7cbe208ecea18128e8 100644 --- a/tensorflow/contrib/receptive_field/python/util/parse_layer_parameters.py +++ b/tensorflow/contrib/receptive_field/python/util/parse_layer_parameters.py @@ -27,7 +27,8 @@ from tensorflow.python.platform import tf_logging as logging _UNCHANGED_RF_LAYER_OPS = [ "Add", "BiasAdd", "Cast", "Ceil", "ConcatV2", "Const", "Floor", "FusedBatchNorm", "Identity", "Log", "Mul", "Pow", "RealDiv", "Relu", - "Relu6", "Round", "Rsqrt", "Softplus", "Sub", "VariableV2", "LRN" + "Relu6", "Round", "Rsqrt", "Softplus", "Sub", "VariableV2", "LRN", + "GreaterEqual" ] # Different ways in which padding modes may be spelled. @@ -276,11 +277,11 @@ def get_layer_params(node, name_to_node, input_resolution=None, force=False): kernel_size_x, kernel_size_y = _conv_kernel_size(node, name_to_node) # Compute the padding for this node separately for each direction. total_padding_x, padding_x = _padding_size_conv_pool( - node, kernel_size_x, stride_x, input_resolution[1] - if input_resolution is not None else None) + node, kernel_size_x, stride_x, + input_resolution[1] if input_resolution is not None else None) total_padding_y, padding_y = _padding_size_conv_pool( - node, kernel_size_y, stride_y, input_resolution[0] - if input_resolution is not None else None) + node, kernel_size_y, stride_y, + input_resolution[0] if input_resolution is not None else None) elif node.op == "Pad": # Kernel and stride are simply 1 in this case. 
kernel_size_x = 1 @@ -294,11 +295,11 @@ def get_layer_params(node, name_to_node, input_resolution=None, force=False): kernel_size_x, kernel_size_y = _pool_kernel_size(node, name_to_node) # Compute the padding for this node separately for each direction. total_padding_x, padding_x = _padding_size_conv_pool( - node, kernel_size_x, stride_x, input_resolution[1] - if input_resolution is not None else None) + node, kernel_size_x, stride_x, + input_resolution[1] if input_resolution is not None else None) total_padding_y, padding_y = _padding_size_conv_pool( - node, kernel_size_y, stride_y, input_resolution[0] - if input_resolution is not None else None) + node, kernel_size_y, stride_y, + input_resolution[0] if input_resolution is not None else None) elif node.op in _UNCHANGED_RF_LAYER_OPS: # These nodes do not modify the RF parameters. kernel_size_x = 1 @@ -320,7 +321,7 @@ def get_layer_params(node, name_to_node, input_resolution=None, force=False): total_padding_y = None padding_y = None else: - raise ValueError("Unknown layer for operation '%s': %s" % (node.name, - node.op)) + raise ValueError( + "Unknown layer for operation '%s': %s" % (node.name, node.op)) return (kernel_size_x, kernel_size_y, stride_x, stride_y, padding_x, padding_y, total_padding_x, total_padding_y) diff --git a/tensorflow/contrib/receptive_field/python/util/receptive_field.py b/tensorflow/contrib/receptive_field/python/util/receptive_field.py index b9bd2f09761ab10a62d37e8e2580b93b9b8a4453..9127c772c75279d9c8eacc5a17680beba9247d01 100644 --- a/tensorflow/contrib/receptive_field/python/util/receptive_field.py +++ b/tensorflow/contrib/receptive_field/python/util/receptive_field.py @@ -12,12 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Functions to compute receptive field of a fully-convolutional network. 
- -Please refer to the following g3doc for detailed explanation on how this -computation is performed, and why it is important: -g3doc/photos/vision/features/delf/g3doc/rf_computation.md -""" +"""Functions to compute receptive field of a fully-convolutional network.""" from __future__ import absolute_import from __future__ import division @@ -96,8 +91,8 @@ class ReceptiveField(object): Args: y: An array of feature coordinates with shape `(..., d)`, where `d` is the number of dimensions of the coordinates. - axis: The dimensions for which to compute the input center coordinates. - If `None` (the default), compute the input center coordinates for all + axis: The dimensions for which to compute the input center coordinates. If + `None` (the default), compute the input center coordinates for all dimensions. Returns: @@ -127,8 +122,8 @@ class ReceptiveField(object): Args: x: An array of input center coordinates with shape `(..., d)`, where `d` is the number of dimensions of the coordinates. - axis: The dimensions for which to compute the feature coordinates. - If `None` (the default), compute the feature coordinates for all + axis: The dimensions for which to compute the feature coordinates. If + `None` (the default), compute the feature coordinates for all dimensions. Returns: @@ -274,14 +269,15 @@ def compute_receptive_field_from_graph_def(graph_def, continue # Get params for this layer. 
- (kernel_size_x, kernel_size_y, stride_x, stride_y, padding_x, - padding_y, _, _) = parse_layer_parameters.get_layer_params( + (kernel_size_x, kernel_size_y, stride_x, stride_y, padding_x, padding_y, + _, _) = parse_layer_parameters.get_layer_params( node, name_to_node, node_info[node.name].input_size) - logging.vlog(3, "kernel_size_x = %s, kernel_size_y = %s, " - "stride_x = %s, stride_y = %s, " - "padding_x = %s, padding_y = %s, input size = %s" % - (kernel_size_x, kernel_size_y, stride_x, stride_y, padding_x, - padding_y, node_info[node.name].input_size)) + logging.vlog( + 3, "kernel_size_x = %s, kernel_size_y = %s, " + "stride_x = %s, stride_y = %s, " + "padding_x = %s, padding_y = %s, input size = %s" % + (kernel_size_x, kernel_size_y, stride_x, stride_y, padding_x, + padding_y, node_info[node.name].input_size)) if padding_x is None or padding_y is None: undefined_padding = True @@ -352,15 +348,15 @@ def compute_receptive_field_from_graph_def(graph_def, raise ValueError( "Graph is not aligned since effective stride from different " "paths is different in vertical direction") - if (rf_sizes_x[inp_name] - 1 - ) / 2 - effective_paddings_x[inp_name] != ( - rf_size_input_x - 1) / 2 - effective_padding_input_x: + if (rf_sizes_x[inp_name] - + 1) / 2 - effective_paddings_x[inp_name] != ( + rf_size_input_x - 1) / 2 - effective_padding_input_x: raise ValueError( "Graph is not aligned since center shift from different " "paths is different in horizontal direction") - if (rf_sizes_y[inp_name] - 1 - ) / 2 - effective_paddings_y[inp_name] != ( - rf_size_input_y - 1) / 2 - effective_padding_input_y: + if (rf_sizes_y[inp_name] - + 1) / 2 - effective_paddings_y[inp_name] != ( + rf_size_input_y - 1) / 2 - effective_padding_input_y: raise ValueError( "Graph is not aligned since center shift from different " "paths is different in vertical direction") diff --git a/tensorflow/contrib/remote_fused_graph/pylib/python/ops/remote_fused_graph_ops.py 
b/tensorflow/contrib/remote_fused_graph/pylib/python/ops/remote_fused_graph_ops.py index 2054367f0d1461c8868e3332d82322a8a3dd38af..7e79785d2867de586f0730373d4864602ef770ae 100644 --- a/tensorflow/contrib/remote_fused_graph/pylib/python/ops/remote_fused_graph_ops.py +++ b/tensorflow/contrib/remote_fused_graph/pylib/python/ops/remote_fused_graph_ops.py @@ -50,13 +50,13 @@ def remote_fused_graph_execute(inputs, if default_graph_input_tensor_type_shapes: for type_shape in default_graph_input_tensor_type_shapes: type_shape_proto = info_proto.default_graph_input_tensor_shape.add() - type_shape_proto.dtype = int(dtypes.as_dtype(type_shape[0])) + type_shape_proto.dtype = dtypes.as_dtype(type_shape[0]).as_datatype_enum for dim in type_shape[1]: type_shape_proto.shape.dim.add().size = dim if default_graph_output_tensor_type_shapes: for type_shape in default_graph_output_tensor_type_shapes: type_shape_proto = info_proto.default_graph_output_tensor_shape.add() - type_shape_proto.dtype = int(dtypes.as_dtype(type_shape[0])) + type_shape_proto.dtype = dtypes.as_dtype(type_shape[0]).as_datatype_enum for dim in type_shape[1]: type_shape_proto.shape.dim.add().size = dim diff --git a/tensorflow/contrib/resampler/xla/resampler_ops_xla_test.py b/tensorflow/contrib/resampler/xla/resampler_ops_xla_test.py index d8ca0eab276b39f025d018edebb78eed7a8433bb..cec4c3c23305034d167a248a637425507750064e 100644 --- a/tensorflow/contrib/resampler/xla/resampler_ops_xla_test.py +++ b/tensorflow/contrib/resampler/xla/resampler_ops_xla_test.py @@ -164,6 +164,15 @@ class ResamplerOpsTest(xla_test.XLATestCase): expected = [[[0.0], [27.62]]] self._assertForwardOpMatchesExpected(input_np, warp_np, expected) + expected_grad_data = [[[[0.12], [0.27999997]], [[0.18000001], + [0.42000002]]]] + expected_grad_warp = [[[0., 0.], [22.60000038, 35.20000076]]] + + grad_output = np.ones([1, 2, 1], dtype=dtype) + self._assertBackwardOpMatchesExpected(input_np, warp_np, grad_output, + expected_grad_data, + 
expected_grad_warp) + # One of (x, y) is less than 0. for dtype in self.float_types: input_shape = [1, 2, 2, 1] @@ -171,11 +180,21 @@ class ResamplerOpsTest(xla_test.XLATestCase): input_np = np.array(input_data, dtype=dtype).reshape(input_shape) warp_shape = [1, 2, 2] + # -1 is out of bound for grad_warp. warp_data = [-1, 0.1, 0.7, 0.6] warp_np = np.array(warp_data, dtype=dtype).reshape(warp_shape) expected = [[[0.0], [27.62]]] self._assertForwardOpMatchesExpected(input_np, warp_np, expected) + expected_grad_data = [[[[0.12], [0.27999997]], [[0.18000001], + [0.42000002]]]] + expected_grad_warp = [[[0., 0.], [22.60000038, 35.20000076]]] + + grad_output = np.ones([1, 2, 1], dtype=dtype) + self._assertBackwardOpMatchesExpected(input_np, warp_np, grad_output, + expected_grad_data, + expected_grad_warp) + # Both of (x, y) are greater than image size. for dtype in self.float_types: input_shape = [1, 2, 2, 1] @@ -183,11 +202,20 @@ class ResamplerOpsTest(xla_test.XLATestCase): input_np = np.array(input_data, dtype=dtype).reshape(input_shape) warp_shape = [1, 2, 2] + # -0.1 is *inbound* for grad_warp and grad_data, 2.1 is out of bound. warp_data = [-0.1, 0.1, 1.2, 2.1] warp_np = np.array(warp_data, dtype=dtype).reshape(warp_shape) expected = [[[0.0], [0.0]]] self._assertForwardOpMatchesExpected(input_np, warp_np, expected) + expected_grad_data = [[[[0.81], [0.0]], [[0.09], [0.0]]]] + expected_grad_warp = [[[10.30, 2.7], [0.0, 0.0]]] + + grad_output = np.ones([1, 2, 1], dtype=dtype) + self._assertBackwardOpMatchesExpected(input_np, warp_np, grad_output, + expected_grad_data, + expected_grad_warp) + # One of (x, y) is greater than image size. 
for dtype in self.float_types: input_shape = [1, 2, 2, 1] @@ -200,6 +228,14 @@ class ResamplerOpsTest(xla_test.XLATestCase): expected = [[[0.0], [0.0]]] self._assertForwardOpMatchesExpected(input_np, warp_np, expected) + expected_grad_data = [[[[0.81], [0.81]], [[0.0], [0.08]]]] + expected_grad_warp = [[[-4.5, 9.5], [-9.9, 39.20]]] + + grad_output = np.ones([1, 2, 1], dtype=dtype) + self._assertBackwardOpMatchesExpected(input_np, warp_np, grad_output, + expected_grad_data, + expected_grad_warp) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/rnn/BUILD b/tensorflow/contrib/rnn/BUILD index e124867415f94fb5052f34f50363ea718d71053b..d65d80df8073ef70d591c4ae2af99132f1c318ef 100644 --- a/tensorflow/contrib/rnn/BUILD +++ b/tensorflow/contrib/rnn/BUILD @@ -118,6 +118,7 @@ cuda_py_tests( "//tensorflow/python:rnn_cell", "//tensorflow/python:variable_scope", "//tensorflow/python:variables", + "@absl_py//absl/testing:parameterized", ], ) @@ -226,7 +227,10 @@ tf_custom_op_library( "kernels/lstm_ops_gpu.cu.cc", "kernels/lstm_ops.h", ], - deps = ["//tensorflow/core/kernels:eigen_helpers"], + deps = [ + "//tensorflow/core/kernels:eigen_contraction_kernel", + "//tensorflow/core/kernels:eigen_helpers", + ], ) tf_gen_op_wrapper_py( @@ -248,7 +252,10 @@ tf_custom_op_library( "kernels/gru_ops_gpu.cu.cc", "kernels/gru_ops.h", ], - deps = ["//tensorflow/core/kernels:eigen_helpers"], + deps = [ + "//tensorflow/core/kernels:eigen_contraction_kernel", + "//tensorflow/core/kernels:eigen_helpers", + ], ) tf_gen_op_wrapper_py( @@ -345,6 +352,7 @@ tf_kernel_library( deps = [ "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core/kernels:eigen_contraction_kernel", "//tensorflow/core/kernels:eigen_helpers", "//third_party/eigen3", ], @@ -380,6 +388,13 @@ py_binary( name = "checkpoint_convert", srcs = ["python/tools/checkpoint_convert.py"], srcs_version = "PY2AND3", + deps = [":checkpoint_convert_lib"], +) + +py_library( + name = 
"checkpoint_convert_lib", + srcs = ["python/tools/checkpoint_convert.py"], + srcs_version = "PY2AND3", deps = [ "//tensorflow/core:protos_all_py", "//tensorflow/python:framework_ops", @@ -398,7 +413,7 @@ py_test( srcs_version = "PY2AND3", tags = ["no_pip"], deps = [ - ":checkpoint_convert", + ":checkpoint_convert_lib", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_ops", "//tensorflow/python:session", diff --git a/tensorflow/contrib/rnn/kernels/blas_gemm.h b/tensorflow/contrib/rnn/kernels/blas_gemm.h index d37210d4b81203287fb633adc309688a35d093bb..12f3182a6a8878aa27ee143fa6405903e3fc4ef3 100644 --- a/tensorflow/contrib/rnn/kernels/blas_gemm.h +++ b/tensorflow/contrib/rnn/kernels/blas_gemm.h @@ -21,6 +21,10 @@ limitations under the License. #include "tensorflow/core/kernels/eigen_activations.h" #include "tensorflow/core/platform/types.h" +#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL) +#include "tensorflow/core/kernels/eigen_contraction_kernel.h" +#endif + namespace tensorflow { class OpKernelContext; namespace functor { diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py index 7d57b0413a3bb51c35e670ce3fdb2cc818f44a58..a70e806211c644c703f49610414854fe3e16a9b7 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py @@ -20,6 +20,7 @@ from __future__ import print_function import os +from absl.testing import parameterized import numpy as np from tensorflow.contrib import rnn as contrib_rnn @@ -31,6 +32,8 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.framework import test_util +from tensorflow.python.keras import layers as keras_layers +from tensorflow.python.layers import base as base_layer from tensorflow.python.ops import array_ops from 
tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops @@ -45,7 +48,7 @@ from tensorflow.python.training.checkpointable import util as checkpointable_uti Linear = core_rnn_cell._Linear # pylint: disable=invalid-name -class RNNCellTest(test.TestCase): +class RNNCellTest(test.TestCase, parameterized.TestCase): def testLinear(self): with self.cached_session() as sess: @@ -207,6 +210,35 @@ class RNNCellTest(test.TestCase): # Smoke test self.assertAllClose(res[0], [[0.509682, 0.509682]]) + def testSRUCellKerasRNN(self): + """Tests that SRUCell works with keras RNN layer.""" + cell = contrib_rnn_cell.SRUCell(10) + seq_input = ops.convert_to_tensor( + np.random.rand(2, 3, 5), name="seq_input", dtype=dtypes.float32) + rnn_layer = keras_layers.RNN(cell=cell) + rnn_outputs_keras = rnn_layer(seq_input) + with self.cached_session() as sess: + sess.run([variables_lib.global_variables_initializer()]) + self.assertEqual(sess.run(rnn_outputs_keras).shape, (2, 10)) + + def testSRUCellBiasType(self): + """Tests that the bias' dtype is properly set.""" + cell = contrib_rnn_cell.SRUCell(10) + cell.build((2, 3, 5)) + self.assertEqual(cell._bias.dtype, dtypes.float32_ref) + + cell = contrib_rnn_cell.SRUCell(10, dtype=dtypes.int32) + cell.build((2, 3, 5)) + self.assertEqual(cell._bias.dtype, dtypes.int32_ref) + + cell_input = ops.convert_to_tensor( + np.random.rand(2, 5), name="cell_input", dtype=dtypes.float16) + cell_state = ops.convert_to_tensor( + np.random.rand(2, 10), name="cell_state", dtype=dtypes.float16) + cell = contrib_rnn_cell.SRUCell(10) + cell(cell_input, [cell_state]) + self.assertEqual(cell._bias.dtype, dtypes.float16_ref) + def testSRUCellWithDiffSize(self): with self.cached_session() as sess: with variable_scope.variable_scope( @@ -610,58 +642,54 @@ class RNNCellTest(test.TestCase): # The numbers in results were not calculated, this is just a smoke test. 
self.assertAllClose(res[0], [[0.154605, 0.154605, 0.154605]]) - def testResidualWrapper(self): - with self.cached_session() as sess: - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([1, 3]) - m = array_ops.zeros([1, 3]) - base_cell = rnn_cell_impl.GRUCell(3) - g, m_new = base_cell(x, m) - variable_scope.get_variable_scope().reuse_variables() - wrapper_object = rnn_cell_impl.ResidualWrapper(base_cell) - (name, dep), = wrapper_object._checkpoint_dependencies - wrapper_object.get_config() # Should not throw an error - self.assertIs(dep, base_cell) - self.assertEqual("cell", name) - - g_res, m_new_res = wrapper_object(x, m) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run([g, g_res, m_new, m_new_res], { - x: np.array([[1., 1., 1.]]), - m: np.array([[0.1, 0.1, 0.1]]) - }) - # Residual connections - self.assertAllClose(res[1], res[0] + [1., 1., 1.]) - # States are left untouched - self.assertAllClose(res[2], res[3]) - - def testResidualWrapperWithSlice(self): - with self.cached_session() as sess: - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([1, 5]) - m = array_ops.zeros([1, 3]) - base_cell = rnn_cell_impl.GRUCell(3) - g, m_new = base_cell(x, m) - variable_scope.get_variable_scope().reuse_variables() - - def residual_with_slice_fn(inp, out): - inp_sliced = array_ops.slice(inp, [0, 0], [-1, 3]) - return inp_sliced + out - - g_res, m_new_res = rnn_cell_impl.ResidualWrapper( - base_cell, residual_with_slice_fn)(x, m) - sess.run([variables_lib.global_variables_initializer()]) - res_g, res_g_res, res_m_new, res_m_new_res = sess.run( - [g, g_res, m_new, m_new_res], { - x: np.array([[1., 1., 1., 1., 1.]]), - m: np.array([[0.1, 0.1, 0.1]]) - }) - # Residual connections - self.assertAllClose(res_g_res, res_g + [1., 1., 1.]) - # States are left untouched - self.assertAllClose(res_m_new, res_m_new_res) + 
@parameterized.parameters( + [rnn_cell_impl.ResidualWrapper, rnn_cell_impl.ResidualWrapperV2]) + @test_util.run_in_graph_and_eager_modes + def testResidualWrapper(self, wrapper_type): + x = ops.convert_to_tensor(np.array([[1., 1., 1.]])) + m = ops.convert_to_tensor(np.array([[0.1, 0.1, 0.1]])) + base_cell = rnn_cell_impl.GRUCell( + 3, kernel_initializer=init_ops.constant_initializer(0.5), + bias_initializer=init_ops.constant_initializer(0.5)) + g, m_new = base_cell(x, m) + wrapper_object = wrapper_type(base_cell) + (name, dep), = wrapper_object._checkpoint_dependencies + wrapper_object.get_config() # Should not throw an error + self.assertIs(dep, base_cell) + self.assertEqual("cell", name) + + g_res, m_new_res = wrapper_object(x, m) + self.evaluate([variables_lib.global_variables_initializer()]) + res = self.evaluate([g, g_res, m_new, m_new_res]) + # Residual connections + self.assertAllClose(res[1], res[0] + [1., 1., 1.]) + # States are left untouched + self.assertAllClose(res[2], res[3]) + + @parameterized.parameters( + [rnn_cell_impl.ResidualWrapper, rnn_cell_impl.ResidualWrapperV2]) + @test_util.run_in_graph_and_eager_modes + def testResidualWrapperWithSlice(self, wrapper_type): + x = ops.convert_to_tensor(np.array([[1., 1., 1., 1., 1.]])) + m = ops.convert_to_tensor(np.array([[0.1, 0.1, 0.1]])) + base_cell = rnn_cell_impl.GRUCell( + 3, kernel_initializer=init_ops.constant_initializer(0.5), + bias_initializer=init_ops.constant_initializer(0.5)) + g, m_new = base_cell(x, m) + + def residual_with_slice_fn(inp, out): + inp_sliced = array_ops.slice(inp, [0, 0], [-1, 3]) + return inp_sliced + out + + g_res, m_new_res = wrapper_type( + base_cell, residual_with_slice_fn)(x, m) + self.evaluate([variables_lib.global_variables_initializer()]) + res_g, res_g_res, res_m_new, res_m_new_res = self.evaluate( + [g, g_res, m_new, m_new_res]) + # Residual connections + self.assertAllClose(res_g_res, res_g + [1., 1., 1.]) + # States are left untouched + 
self.assertAllClose(res_m_new, res_m_new_res) def testDeviceWrapper(self): with variable_scope.variable_scope( @@ -804,57 +832,166 @@ class RNNCellTest(test.TestCase): self.assertAllClose(res[0], [[0.175991, 0.175991]]) self.assertAllClose(res[1], [[0.13248, 0.13248]]) + @parameterized.parameters( + [[rnn_cell_impl.DropoutWrapper, rnn_cell_impl.DropoutWrapperV2], + [rnn_cell_impl.ResidualWrapper, rnn_cell_impl.ResidualWrapperV2]]) + @test_util.run_in_graph_and_eager_modes + def testWrapperKerasStyle(self, wrapper, wrapper_v2): + """Tests if wrapper cell is instantiated in keras style scope.""" + wrapped_cell_v2 = wrapper_v2(rnn_cell_impl.BasicRNNCell(1)) + self.assertTrue(wrapped_cell_v2._keras_style) + + wrapped_cell = wrapper(rnn_cell_impl.BasicRNNCell(1)) + self.assertFalse(wrapped_cell._keras_style) + + @parameterized.parameters( + [rnn_cell_impl.DropoutWrapperV2, rnn_cell_impl.ResidualWrapperV2]) + @test_util.run_in_graph_and_eager_modes + def testWrapperV2VariableNames(self, wrapper): + """Tests that variables names do not depend on wrapper in RNN layer.""" + + def _rnn_input(apply_wrapper, name): + """Creates a RNN layer with/without wrapper and returns built rnn cell.""" + with base_layer.keras_style_scope(): + base_cell = rnn_cell_impl.MultiRNNCell( + [rnn_cell_impl.BasicRNNCell(1, name="basic_rnn_cell") + for _ in range(2)]) + if apply_wrapper: + rnn_cell = wrapper(base_cell) + else: + rnn_cell = base_cell + rnn_layer = keras_layers.RNN(rnn_cell, name=name) + inputs = ops.convert_to_tensor([[[1]]], dtype=dtypes.float32) + _ = rnn_layer(inputs) + return base_cell._cells[0] + + rnn_1 = _rnn_input(True, name="rnn_0") + rnn_2 = _rnn_input(False, name="rnn_1") + + for i, cell in enumerate([rnn_1, rnn_2]): + var_prefix = "rnn_{}/cell_0/basic_rnn_cell/".format(i) + self.assertCountEqual([v.name for v in cell.weights], + (var_prefix + "kernel:0", var_prefix + "bias:0")) + + @parameterized.parameters( + [rnn_cell_impl.DropoutWrapperV2, 
rnn_cell_impl.ResidualWrapperV2]) + @test_util.run_in_graph_and_eager_modes + def testWrapperWeights(self, wrapper): + """Tests that wrapper weights contain wrapped cells weights.""" + + with base_layer.keras_style_scope(): + base_cell = rnn_cell_impl.BasicRNNCell(1, name="basic_rnn_cell") + rnn_cell = wrapper(base_cell) + rnn_layer = keras_layers.RNN(rnn_cell) + inputs = ops.convert_to_tensor([[[1]]], dtype=dtypes.float32) + rnn_layer(inputs) + + expected_weights = ["rnn/" + var for var in ("kernel:0", "bias:0")] + self.assertEqual(len(rnn_cell.weights), 2) + self.assertCountEqual([v.name for v in rnn_cell.weights], expected_weights) + self.assertCountEqual([v.name for v in rnn_cell.trainable_variables], + expected_weights) + self.assertCountEqual([v.name for v in rnn_cell.non_trainable_variables], + []) + self.assertCountEqual([v.name for v in rnn_cell._cell.weights], + expected_weights) + + @parameterized.parameters( + [rnn_cell_impl.DropoutWrapperV2, rnn_cell_impl.ResidualWrapperV2]) + @test_util.run_in_graph_and_eager_modes + def testWrapperV2Caller(self, wrapper): + """Tests that wrapper V2 is using the LayerRNNCell's caller.""" + + with base_layer.keras_style_scope(): + base_cell = rnn_cell_impl.MultiRNNCell( + [rnn_cell_impl.BasicRNNCell(1) for _ in range(2)]) + rnn_cell = wrapper(base_cell) + inputs = ops.convert_to_tensor([[1]], dtype=dtypes.float32) + state = ops.convert_to_tensor([[1]], dtype=dtypes.float32) + _ = rnn_cell(inputs, [state, state]) + weights = base_cell._cells[0].weights + self.assertLen(weights, expected_len=2) + self.assertTrue(all(["_wrapper" in v.name for v in weights])) + + @parameterized.parameters( + [rnn_cell_impl.DropoutWrapperV2, rnn_cell_impl.ResidualWrapperV2]) + @test_util.run_in_graph_and_eager_modes + def testWrapperV2Build(self, wrapper): + cell = rnn_cell_impl.LSTMCell(10) + wrapper = wrapper(cell) + wrapper.build((1,)) + self.assertTrue(cell.built) + -class DropoutWrapperTest(test.TestCase): 
+@test_util.run_all_in_graph_and_eager_modes +class DropoutWrapperTest(test.TestCase, parameterized.TestCase): def _testDropoutWrapper(self, batch_size=None, time_steps=None, parallel_iterations=None, + wrapper_type=None, + scope="root", **kwargs): - with self.cached_session() as sess: - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - if batch_size is None and time_steps is None: - # 2 time steps, batch size 1, depth 3 - batch_size = 1 - time_steps = 2 - x = constant_op.constant( - [[[2., 2., 2.]], [[1., 1., 1.]]], dtype=dtypes.float32) - m = rnn_cell_impl.LSTMStateTuple( - *[constant_op.constant([[0.1, 0.1, 0.1]], dtype=dtypes.float32 - )] * 2) - else: - x = constant_op.constant( - np.random.randn(time_steps, batch_size, 3).astype(np.float32)) - m = rnn_cell_impl.LSTMStateTuple(*[ - constant_op. - constant([[0.1, 0.1, 0.1]] * batch_size, dtype=dtypes.float32) - ] * 2) - outputs, final_state = rnn.dynamic_rnn( - cell=rnn_cell_impl.DropoutWrapper( - rnn_cell_impl.LSTMCell(3), dtype=x.dtype, **kwargs), - time_major=True, - parallel_iterations=parallel_iterations, - inputs=x, - initial_state=m) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run([outputs, final_state]) - self.assertEqual(res[0].shape, (time_steps, batch_size, 3)) - self.assertEqual(res[1].c.shape, (batch_size, 3)) - self.assertEqual(res[1].h.shape, (batch_size, 3)) - return res - - def testWrappedCellProperty(self): + if batch_size is None and time_steps is None: + # 2 time steps, batch size 1, depth 3 + batch_size = 1 + time_steps = 2 + x = constant_op.constant( + [[[2., 2., 2.]], [[1., 1., 1.]]], dtype=dtypes.float32) + m = rnn_cell_impl.LSTMStateTuple( + *[constant_op.constant([[0.1, 0.1, 0.1]], dtype=dtypes.float32)] * 2) + else: + x = constant_op.constant( + np.random.randn(time_steps, batch_size, 3).astype(np.float32)) + m = rnn_cell_impl.LSTMStateTuple(*[ + constant_op. 
+ constant([[0.1, 0.1, 0.1]] * batch_size, dtype=dtypes.float32)] * 2) + outputs, final_state = rnn.dynamic_rnn( + cell=wrapper_type( + rnn_cell_impl.LSTMCell( + 3, initializer=init_ops.constant_initializer(0.5)), + dtype=x.dtype, **kwargs), + time_major=True, + parallel_iterations=parallel_iterations, + inputs=x, + initial_state=m, + scope=scope) + self.evaluate([variables_lib.global_variables_initializer()]) + res = self.evaluate([outputs, final_state]) + self.assertEqual(res[0].shape, (time_steps, batch_size, 3)) + self.assertEqual(res[1].c.shape, (batch_size, 3)) + self.assertEqual(res[1].h.shape, (batch_size, 3)) + return res + + @parameterized.parameters( + [rnn_cell_impl.DropoutWrapper, rnn_cell_impl.DropoutWrapperV2]) + def testDropoutWrapperProperties(self, wrapper_type): cell = rnn_cell_impl.BasicRNNCell(10) - wrapper = rnn_cell_impl.DropoutWrapper(cell) + wrapper = wrapper_type(cell) # Github issue 15810 self.assertEqual(wrapper.wrapped_cell, cell) - - def testDropoutWrapperKeepAllConstantInput(self): + self.assertEqual(wrapper.state_size, 10) + self.assertEqual(wrapper.output_size, 10) + + @parameterized.parameters( + [rnn_cell_impl.DropoutWrapper, rnn_cell_impl.DropoutWrapperV2]) + def testDropoutWrapperZeroState(self, wrapper_type): + class _Cell(rnn_cell_impl.BasicRNNCell): + + def zero_state(self, batch_size=None, dtype=None): + return "wrapped_cell_zero_state" + wrapper = wrapper_type(_Cell(10)) + self.assertEqual(wrapper.zero_state(10, dtypes.float32), + "wrapped_cell_zero_state") + + @parameterized.parameters( + [rnn_cell_impl.DropoutWrapper, rnn_cell_impl.DropoutWrapperV2]) + def testDropoutWrapperKeepAllConstantInput(self, wrapper_type): keep = array_ops.ones([]) res = self._testDropoutWrapper( - input_keep_prob=keep, output_keep_prob=keep, state_keep_prob=keep) + input_keep_prob=keep, output_keep_prob=keep, state_keep_prob=keep, + wrapper_type=wrapper_type) true_full_output = np.array( [[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 
0.895509]]], dtype=np.float32) @@ -864,10 +1001,13 @@ class DropoutWrapperTest(test.TestCase): self.assertAllClose(true_full_output[1], res[1].h) self.assertAllClose(true_full_final_c, res[1].c) - def testDropoutWrapperKeepAll(self): + @parameterized.parameters( + [rnn_cell_impl.DropoutWrapper, rnn_cell_impl.DropoutWrapperV2]) + def testDropoutWrapperKeepAll(self, wrapper_type): keep = variable_scope.get_variable("all", initializer=1.0) res = self._testDropoutWrapper( - input_keep_prob=keep, output_keep_prob=keep, state_keep_prob=keep) + input_keep_prob=keep, output_keep_prob=keep, state_keep_prob=keep, + wrapper_type=wrapper_type) true_full_output = np.array( [[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]], dtype=np.float32) @@ -877,7 +1017,9 @@ class DropoutWrapperTest(test.TestCase): self.assertAllClose(true_full_output[1], res[1].h) self.assertAllClose(true_full_final_c, res[1].c) - def testDropoutWrapperWithSeed(self): + @parameterized.parameters( + [rnn_cell_impl.DropoutWrapper, rnn_cell_impl.DropoutWrapperV2]) + def testDropoutWrapperWithSeed(self, wrapper_type): keep_some = 0.5 random_seed.set_random_seed(2) ## Use parallel_iterations = 1 in both calls to @@ -889,28 +1031,32 @@ class DropoutWrapperTest(test.TestCase): output_keep_prob=keep_some, state_keep_prob=keep_some, seed=10, - parallel_iterations=1) - # Clear away the graph and the test session (which keeps variables around) - ops.reset_default_graph() - self._ClearCachedSession() + parallel_iterations=1, + wrapper_type=wrapper_type, + scope="root_1") random_seed.set_random_seed(2) res_standard_2 = self._testDropoutWrapper( input_keep_prob=keep_some, output_keep_prob=keep_some, state_keep_prob=keep_some, seed=10, - parallel_iterations=1) + parallel_iterations=1, + wrapper_type=wrapper_type, + scope="root_2") self.assertAllClose(res_standard_1[0], res_standard_2[0]) self.assertAllClose(res_standard_1[1].c, res_standard_2[1].c) self.assertAllClose(res_standard_1[1].h, 
res_standard_2[1].h) - def testDropoutWrapperKeepNoOutput(self): + @parameterized.parameters( + [rnn_cell_impl.DropoutWrapper, rnn_cell_impl.DropoutWrapperV2]) + def testDropoutWrapperKeepNoOutput(self, wrapper_type): keep_all = variable_scope.get_variable("all", initializer=1.0) keep_none = variable_scope.get_variable("none", initializer=1e-6) res = self._testDropoutWrapper( input_keep_prob=keep_all, output_keep_prob=keep_none, - state_keep_prob=keep_all) + state_keep_prob=keep_all, + wrapper_type=wrapper_type) true_full_output = np.array( [[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]], dtype=np.float32) @@ -920,7 +1066,9 @@ class DropoutWrapperTest(test.TestCase): self.assertAllClose(true_full_output[1], res[1].h) self.assertAllClose(true_full_final_c, res[1].c) - def testDropoutWrapperKeepNoStateExceptLSTMCellMemory(self): + @parameterized.parameters( + [rnn_cell_impl.DropoutWrapper, rnn_cell_impl.DropoutWrapperV2]) + def testDropoutWrapperKeepNoStateExceptLSTMCellMemory(self, wrapper_type): keep_all = variable_scope.get_variable("all", initializer=1.0) keep_none = variable_scope.get_variable("none", initializer=1e-6) # Even though we dropout state, by default DropoutWrapper never @@ -928,7 +1076,8 @@ class DropoutWrapperTest(test.TestCase): res = self._testDropoutWrapper( input_keep_prob=keep_all, output_keep_prob=keep_all, - state_keep_prob=keep_none) + state_keep_prob=keep_none, + wrapper_type=wrapper_type) true_c_state = np.array([[1.713925, 1.713925, 1.713925]], dtype=np.float32) true_full_output = np.array( [[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]], @@ -941,7 +1090,9 @@ class DropoutWrapperTest(test.TestCase): # c state of an LSTMStateTuple is NEVER modified. 
self.assertAllClose(true_c_state, res[1].c) - def testDropoutWrapperKeepNoInput(self): + @parameterized.parameters( + [rnn_cell_impl.DropoutWrapper, rnn_cell_impl.DropoutWrapperV2]) + def testDropoutWrapperKeepNoInput(self, wrapper_type): keep_all = variable_scope.get_variable("all", initializer=1.0) keep_none = variable_scope.get_variable("none", initializer=1e-6) true_full_output = np.array( @@ -953,12 +1104,15 @@ class DropoutWrapperTest(test.TestCase): res = self._testDropoutWrapper( input_keep_prob=keep_none, output_keep_prob=keep_all, - state_keep_prob=keep_all) + state_keep_prob=keep_all, + wrapper_type=wrapper_type) self.assertGreater(np.linalg.norm(res[0] - true_full_output), 1e-4) self.assertGreater(np.linalg.norm(res[1].h - true_full_output[1]), 1e-4) self.assertGreater(np.linalg.norm(res[1].c - true_full_final_c), 1e-4) - def testDropoutWrapperRecurrentOutput(self): + @parameterized.parameters( + [rnn_cell_impl.DropoutWrapper, rnn_cell_impl.DropoutWrapperV2]) + def testDropoutWrapperRecurrentOutput(self, wrapper_type): keep_some = 0.8 keep_all = variable_scope.get_variable("all", initializer=1.0) res = self._testDropoutWrapper( @@ -966,6 +1120,7 @@ class DropoutWrapperTest(test.TestCase): output_keep_prob=keep_some, state_keep_prob=keep_all, variational_recurrent=True, + wrapper_type=wrapper_type, input_size=3, batch_size=5, time_steps=7) @@ -974,13 +1129,16 @@ class DropoutWrapperTest(test.TestCase): for m in output_mask[1:]: self.assertAllClose(output_mask[0], m) - def testDropoutWrapperRecurrentStateInputAndOutput(self): + @parameterized.parameters( + [rnn_cell_impl.DropoutWrapper, rnn_cell_impl.DropoutWrapperV2]) + def testDropoutWrapperRecurrentStateInputAndOutput(self, wrapper_type): keep_some = 0.9 res = self._testDropoutWrapper( input_keep_prob=keep_some, output_keep_prob=keep_some, state_keep_prob=keep_some, variational_recurrent=True, + wrapper_type=wrapper_type, input_size=3, batch_size=5, time_steps=7) @@ -1002,7 +1160,10 @@ class 
DropoutWrapperTest(test.TestCase): for batch_entry in state_h_mask: self.assertAllClose(batch_entry, state_h_mask[0]) - def testDropoutWrapperRecurrentStateInputAndOutputWithSeed(self): + @parameterized.parameters( + [rnn_cell_impl.DropoutWrapper, rnn_cell_impl.DropoutWrapperV2]) + def testDropoutWrapperRecurrentStateInputAndOutputWithSeed( + self, wrapper_type): keep_some = 0.9 random_seed.set_random_seed(2347) np.random.seed(23487) @@ -1011,12 +1172,12 @@ class DropoutWrapperTest(test.TestCase): output_keep_prob=keep_some, state_keep_prob=keep_some, variational_recurrent=True, + wrapper_type=wrapper_type, input_size=3, batch_size=5, time_steps=7, - seed=-234987) - ops.reset_default_graph() - self._ClearCachedSession() + seed=-234987, + scope="root_0") random_seed.set_random_seed(2347) np.random.seed(23487) res1 = self._testDropoutWrapper( @@ -1024,10 +1185,12 @@ class DropoutWrapperTest(test.TestCase): output_keep_prob=keep_some, state_keep_prob=keep_some, variational_recurrent=True, + wrapper_type=wrapper_type, input_size=3, batch_size=5, time_steps=7, - seed=-234987) + seed=-234987, + scope="root_1") output_mask = np.abs(res0[0]) > 1e-6 for time_step in output_mask: diff --git a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py index aa1d7d2b01b4595bbb03ba8e867e93db759cbd52..d7ee7fb8faacb0876218a983d68f007e1905c11e 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py @@ -29,7 +29,9 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed +from tensorflow.python.framework import test_util from tensorflow.python.keras import initializers +from tensorflow.python.keras import layers as keras_layers from tensorflow.python.keras import testing_utils from 
tensorflow.python.keras import utils from tensorflow.python.ops import array_ops @@ -763,6 +765,17 @@ class RNNCellTest(test.TestCase): self.assertEqual(new_h.shape[1], num_proj) self.assertAllClose(np.concatenate(res[1], axis=1), expected_state) + @test_util.run_in_graph_and_eager_modes + def testNASCellKerasRNN(self): + """Tests that NASCell works with keras RNN layer.""" + cell = contrib_rnn_cell.NASCell(10) + seq_input = ops.convert_to_tensor( + np.random.rand(2, 3, 5), name="seq_input", dtype=dtypes.float32) + rnn_layer = keras_layers.RNN(cell=cell) + rnn_outputs = rnn_layer(seq_input) + self.evaluate([variables.global_variables_initializer()]) + self.assertEqual(self.evaluate(rnn_outputs).shape, (2, 10)) + def testUGRNNCell(self): num_units = 2 batch_size = 3 diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py index 8a1c09f171e6108174671e3122d5ff4c0b236003..482e547a16be85804beec88a91fa03b053d09b27 100644 --- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py +++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py @@ -1462,7 +1462,7 @@ class LayerNormBasicLSTMCell(rnn_cell_impl.RNNCell): return new_h, new_state -class NASCell(rnn_cell_impl.RNNCell): +class NASCell(rnn_cell_impl.LayerRNNCell): """Neural Architecture Search (NAS) recurrent network cell. This implements the recurrent cell from the paper: @@ -1475,23 +1475,28 @@ class NASCell(rnn_cell_impl.RNNCell): The class uses an optional projection layer. """ - def __init__(self, num_units, num_proj=None, use_biases=False, reuse=None): + # NAS cell's architecture base. + _NAS_BASE = 8 + + def __init__(self, num_units, num_proj=None, use_bias=False, reuse=None, + **kwargs): """Initialize the parameters for a NAS cell. Args: - num_units: int, The number of units in the NAS cell + num_units: int, The number of units in the NAS cell. num_proj: (optional) int, The output dimensionality for the projection matrices. If None, no projection is performed. 
- use_biases: (optional) bool, If True then use biases within the cell. This + use_bias: (optional) bool, If True then use biases within the cell. This is False by default. reuse: (optional) Python boolean describing whether to reuse variables in an existing scope. If not `True`, and the existing scope already has the given variables, an error is raised. + **kwargs: Additional keyword arguments. """ - super(NASCell, self).__init__(_reuse=reuse) + super(NASCell, self).__init__(_reuse=reuse, **kwargs) self._num_units = num_units self._num_proj = num_proj - self._use_biases = use_biases + self._use_bias = use_bias self._reuse = reuse if num_proj is not None: @@ -1509,6 +1514,33 @@ class NASCell(rnn_cell_impl.RNNCell): def output_size(self): return self._output_size + def build(self, inputs_shape): + input_size = tensor_shape.dimension_value( + tensor_shape.TensorShape(inputs_shape).with_rank(2)[1]) + if input_size is None: + raise ValueError("Could not infer input size from inputs.get_shape()[-1]") + + num_proj = self._num_units if self._num_proj is None else self._num_proj + + # Variables for the NAS cell. `recurrent_kernel` is all matrices multiplying + # the hiddenstate and `kernel` is all matrices multiplying the inputs. + self.recurrent_kernel = self.add_variable( + "recurrent_kernel", [num_proj, self._NAS_BASE * self._num_units]) + self.kernel = self.add_variable( + "kernel", [input_size, self._NAS_BASE * self._num_units]) + + if self._use_bias: + self.bias = self.add_variable("bias", + shape=[self._NAS_BASE * self._num_units], + initializer=init_ops.zeros_initializer) + + # Projection layer if specified + if self._num_proj is not None: + self.projection_weights = self.add_variable( + "projection_weights", [self._num_units, self._num_proj]) + + self.built = True + def call(self, inputs, state): """Run one step of NAS Cell. 
@@ -1535,38 +1567,20 @@ class NASCell(rnn_cell_impl.RNNCell): tanh = math_ops.tanh relu = nn_ops.relu - num_proj = self._num_units if self._num_proj is None else self._num_proj - (c_prev, m_prev) = state - dtype = inputs.dtype - input_size = inputs.get_shape().with_rank(2).dims[1] - if input_size.value is None: - raise ValueError("Could not infer input size from inputs.get_shape()[-1]") - # Variables for the NAS cell. W_m is all matrices multiplying the - # hiddenstate and W_inputs is all matrices multiplying the inputs. - concat_w_m = vs.get_variable("recurrent_kernel", - [num_proj, 8 * self._num_units], dtype) - concat_w_inputs = vs.get_variable( - "kernel", [input_size.value, 8 * self._num_units], dtype) - - m_matrix = math_ops.matmul(m_prev, concat_w_m) - inputs_matrix = math_ops.matmul(inputs, concat_w_inputs) - - if self._use_biases: - b = vs.get_variable( - "bias", - shape=[8 * self._num_units], - initializer=init_ops.zeros_initializer(), - dtype=dtype) - m_matrix = nn_ops.bias_add(m_matrix, b) + m_matrix = math_ops.matmul(m_prev, self.recurrent_kernel) + inputs_matrix = math_ops.matmul(inputs, self.kernel) + + if self._use_bias: + m_matrix = nn_ops.bias_add(m_matrix, self.bias) # The NAS cell branches into 8 different splits for both the hiddenstate # and the input m_matrix_splits = array_ops.split( - axis=1, num_or_size_splits=8, value=m_matrix) + axis=1, num_or_size_splits=self._NAS_BASE, value=m_matrix) inputs_matrix_splits = array_ops.split( - axis=1, num_or_size_splits=8, value=inputs_matrix) + axis=1, num_or_size_splits=self._NAS_BASE, value=inputs_matrix) # First layer layer1_0 = sigmoid(inputs_matrix_splits[0] + m_matrix_splits[0]) @@ -1598,9 +1612,7 @@ class NASCell(rnn_cell_impl.RNNCell): # Projection layer if specified if self._num_proj is not None: - concat_w_proj = vs.get_variable("projection_weights", - [self._num_units, self._num_proj], dtype) - new_m = math_ops.matmul(new_m, concat_w_proj) + new_m = math_ops.matmul(new_m, 
self.projection_weights) new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_m) return new_m, new_state @@ -2071,7 +2083,7 @@ class ConvLSTMCell(rnn_cell_impl.RNNCell): conv_ndims: Convolution dimensionality (1, 2 or 3). input_shape: Shape of the input as int tuple, excluding the batch size. output_channels: int, number of output channels of the conv LSTM. - kernel_shape: Shape of kernel as in tuple (of size 1,2 or 3). + kernel_shape: Shape of kernel as an int tuple (of size 1, 2 or 3). use_bias: (bool) Use bias in convolutions. skip_connection: If set to `True`, concatenate the input to the output of the conv LSTM. Default: `False`. @@ -2092,7 +2104,7 @@ class ConvLSTMCell(rnn_cell_impl.RNNCell): self._conv_ndims = conv_ndims self._input_shape = input_shape self._output_channels = output_channels - self._kernel_shape = kernel_shape + self._kernel_shape = list(kernel_shape) self._use_bias = use_bias self._forget_bias = forget_bias self._skip_connection = skip_connection @@ -2172,7 +2184,7 @@ def _conv(args, filter_size, num_features, bias, bias_start=0.0): Args: args: a Tensor or a list of Tensors of dimension 3D, 4D or 5D, batch x n, Tensors. - filter_size: int tuple of filter height and width. + filter_size: int tuple of filter shape (of size 1, 2 or 3). num_features: int, number of features. bias: Whether to use biases in the convolution layer. bias_start: starting value to initialize the bias; 0 by default. @@ -2744,10 +2756,12 @@ class SRUCell(rnn_cell_impl.LayerRNNCell): name: (optional) String, the name of the layer. Layers with the same name will share weights, but to avoid mistakes we require reuse=True in such cases. + **kwargs: Additional keyword arguments. 
""" - def __init__(self, num_units, activation=None, reuse=None, name=None): - super(SRUCell, self).__init__(_reuse=reuse, name=name) + def __init__(self, num_units, activation=None, reuse=None, name=None, + **kwargs): + super(SRUCell, self).__init__(_reuse=reuse, name=name, **kwargs) self._num_units = num_units self._activation = activation or math_ops.tanh @@ -2777,7 +2791,7 @@ class SRUCell(rnn_cell_impl.LayerRNNCell): self._bias = self.add_variable( rnn_cell_impl._BIAS_VARIABLE_NAME, # pylint: disable=protected-access shape=[2 * self._num_units], - initializer=init_ops.constant_initializer(0.0, dtype=self.dtype)) + initializer=init_ops.zeros_initializer) self._built = True diff --git a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test.py b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test.py index 3fc6bfbb4d03a39906d4441e48b2788423caa234..d8ab9eba7049e468b373a1641f92dc781aa22558 100644 --- a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test.py +++ b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test.py @@ -61,10 +61,7 @@ class RpcOpTest(test.TestCase, rpc_op_test_base.RpcOpTestBase): self._server = server def tearDown(self): - # TODO(ebrevdo): Figure out why this sometimes times out. 
- # self._service.ExitLoop() - # self._service_thread.join() - # self._server.stop() + self._server.stop(grace=None) super(RpcOpTest, self).tearDown() diff --git a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py index 0d615923e04915a8429252317025ac8e79f9bb4e..d6148715be91c78e6e5a99fc0f3caa905b5c1a7d 100644 --- a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py +++ b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py @@ -176,7 +176,9 @@ class RpcOpTestBase(object): expected_message_values = np.where( status_code_values == errors.INVALID_ARGUMENT, I_WARNED_YOU.encode('ascii'), b'') - self.assertAllEqual(expected_message_values, status_message_values) + for msg, expected in zip(status_message_values, expected_message_values): + self.assertTrue(expected in msg, + '"%s" did not contain "%s"' % (msg, expected)) def testVecHostPortRpc(self): with self.cached_session() as sess: diff --git a/tensorflow/contrib/saved_model/BUILD b/tensorflow/contrib/saved_model/BUILD index 269443b2c6508bb618d30f64487b1a6a84e8646f..f0242a3b40fd566ec0f477d462426d5f550d1620 100644 --- a/tensorflow/contrib/saved_model/BUILD +++ b/tensorflow/contrib/saved_model/BUILD @@ -84,35 +84,6 @@ py_library( srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = [ - "//tensorflow/python:array_ops", - "//tensorflow/python:framework_ops", - "//tensorflow/python:lib", - "//tensorflow/python:metrics", - "//tensorflow/python:platform", - "//tensorflow/python:saver", - "//tensorflow/python:util", - "//tensorflow/python/estimator:estimator_py", - "//tensorflow/python/keras:engine", - "//tensorflow/python/saved_model", - ], -) - -py_test( - name = "keras_saved_model_test", - size = "medium", - srcs = ["python/saved_model/keras_saved_model_test.py"], - srcs_version = "PY2AND3", - tags = [ - "no_oss", # TODO(b/119349471): Re-enable - "no_windows", - ], - deps = [ - ":keras_saved_model", - 
"//tensorflow/python:client_testlib", - "//tensorflow/python:training", - "//tensorflow/python/estimator:estimator_py", "//tensorflow/python/keras", - "//third_party/py/numpy", - "@absl_py//absl/testing:parameterized", ], ) diff --git a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py index ffba514bb96f5ce8d963cb0a0482738eafe88355..0392ed9eee79391c60318faf68d8dfd6eb64a994 100644 --- a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py +++ b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py @@ -18,348 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os -import six +from tensorflow.python.keras import saving -from tensorflow.python.client import session -from tensorflow.python.estimator import keras as estimator_keras_util -from tensorflow.python.estimator import model_fn as model_fn_lib -from tensorflow.python.estimator.export import export as export_helpers -from tensorflow.python.framework import ops -from tensorflow.python.keras import backend as K -from tensorflow.python.keras import models as models_lib -from tensorflow.python.keras import optimizers -from tensorflow.python.keras.engine import sequential -from tensorflow.python.keras.metrics import Metric -from tensorflow.python.keras.models import model_from_json -from tensorflow.python.lib.io import file_io -from tensorflow.python.ops import variables -from tensorflow.python.platform import gfile -from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.saved_model import builder as saved_model_builder -from tensorflow.python.saved_model import constants -from tensorflow.python.saved_model import utils_impl as saved_model_utils -from tensorflow.python.training import saver as saver_lib -from tensorflow.python.training.checkpointable import util as checkpointable_utils 
-from tensorflow.python.util import compat - -def save_keras_model( - model, saved_model_path, custom_objects=None, as_text=None): - """Save a `tf.keras.Model` into Tensorflow SavedModel format. - - `save_model` generates new files/folders under the `saved_model_path` folder: - 1) an asset folder containing the json string of the model's - configuration (topology). - 2) a checkpoint containing the model weights. - 3) a saved_model.pb file containing the model's MetaGraphs. The prediction - graph is always exported. The evaluaton and training graphs are exported - if the following conditions are met: - - Evaluation: model loss is defined. - - Training: model is compiled with an optimizer defined under `tf.train`. - This is because `tf.keras.optimizers.Optimizer` instances cannot be - saved to checkpoints. - - Model Requirements: - - Model must be a sequential model or functional model. Subclassed models can - not be saved via this function, unless you provide an implementation for - get_config() and from_config(). - - All variables must be saveable by the model. In general, this condition is - met through the use of layers defined in the keras library. However, - there is currently a bug with variables created in Lambda layer functions - not being saved correctly (see - https://github.com/keras-team/keras/issues/9740). - - Note that each mode is exported in separate graphs, so different modes do not - share variables. To use the train graph with evaluation or prediction graphs, - create a new checkpoint if variable values have been updated. - - Example: - - ```python - import tensorflow as tf - - # Create a tf.keras model. - model = tf.keras.Sequential() - model.add(tf.keras.layers.Dense(1, input_shape=[10])) - model.summary() - - # Save the tf.keras model in the SavedModel format. - saved_to_path = tf.contrib.saved_model.save_keras_model( - model, '/tmp/my_simple_tf_keras_saved_model') - - # Load the saved keras model back. 
- model_prime = tf.contrib.saved_model.load_keras_model(saved_to_path) - model_prime.summary() - ``` - - Args: - model: A `tf.keras.Model` to be saved. - saved_model_path: a string specifying the path to the SavedModel directory. - The SavedModel will be saved to a timestamped folder created within this - directory. - custom_objects: Optional dictionary mapping string names to custom classes - or functions (e.g. custom loss functions). - as_text: whether to write the `SavedModel` proto in text format. - - Returns: - String path to the SavedModel folder, a subdirectory of `saved_model_path`. - - Raises: - NotImplementedError: If the model is a subclassed model. - ValueError: If a Sequential model does not have input shapes defined by the - user, and is not built. - """ - if not model._is_graph_network: - if isinstance(model, sequential.Sequential): - # If input shape is not directly set in the model, the exported model - # will assume that the inputs have the same shape as the shape the model - # was built model with. - if not model.built: - raise ValueError( - 'Sequential model must be built before it can be exported.') - else: - raise NotImplementedError( - 'Exporting subclassed models is not yet supported.') - - export_dir = export_helpers.get_timestamped_export_dir(saved_model_path) - temp_export_dir = export_helpers.get_temp_export_dir(export_dir) - - builder = saved_model_builder._SavedModelBuilder(temp_export_dir) - - # Manually save variables to export them in an object-based checkpoint. This - # skips the `builder.add_meta_graph_and_variables()` step, which saves a - # named-based checkpoint. - # TODO(b/113134168): Add fn to Builder to save with object-based saver. - # TODO(b/113178242): This should only export the model json structure. Only - # one save is needed once the weights can be copied from the model to clone. - checkpoint_path = _export_model_json_and_variables(model, temp_export_dir) - - # Export each mode. 
Use ModeKeys enums defined for `Estimator` to ensure that - # Keras models and `Estimator`s are exported with the same format. - # Every time a mode is exported, the code checks to see if new variables have - # been created (e.g. optimizer slot variables). If that is the case, the - # checkpoint is re-saved to include the new variables. - export_args = {'builder': builder, - 'model': model, - 'custom_objects': custom_objects, - 'checkpoint_path': checkpoint_path} - - has_saved_vars = False - if model.optimizer: - if isinstance(model.optimizer, optimizers.TFOptimizer): - _export_mode(model_fn_lib.ModeKeys.TRAIN, has_saved_vars, **export_args) - has_saved_vars = True - _export_mode(model_fn_lib.ModeKeys.EVAL, has_saved_vars, **export_args) - else: - logging.warning( - 'Model was compiled with an optimizer, but the optimizer is not from ' - '`tf.train` (e.g. `tf.train.AdagradOptimizer`). Only the serving ' - 'graph was exported. The train and evaluate graphs were not added to ' - 'the SavedModel.') - _export_mode(model_fn_lib.ModeKeys.PREDICT, has_saved_vars, **export_args) - - builder.save(as_text) - - gfile.Rename(temp_export_dir, export_dir) - return export_dir - - -def _export_model_json_and_variables(model, saved_model_path): - """Save model variables and json structure into SavedModel subdirectories.""" - # Save model configuration as a json string under assets folder. - model_json = model.to_json() - model_json_filepath = os.path.join( - saved_model_utils.get_or_create_assets_dir(saved_model_path), - compat.as_text(constants.SAVED_MODEL_FILENAME_JSON)) - file_io.write_string_to_file(model_json_filepath, model_json) - - # Save model weights in checkpoint format under variables folder. 
- saved_model_utils.get_or_create_variables_dir(saved_model_path) - checkpoint_prefix = saved_model_utils.get_variables_path(saved_model_path) - model.save_weights(checkpoint_prefix, save_format='tf', overwrite=True) - return checkpoint_prefix - - -def _get_var_list(model): - """Return list of all checkpointed saveable objects in the model.""" - return checkpointable_utils.named_saveables(model) - - -def _export_mode( - mode, has_saved_vars, builder, model, custom_objects, checkpoint_path): - """Export a model, and optionally save new vars from the clone model. - - Args: - mode: A `tf.estimator.ModeKeys` string. - has_saved_vars: A `boolean` indicating whether the SavedModel has already - exported variables. - builder: A `SavedModelBuilder` object. - model: A `tf.keras.Model` object. - custom_objects: A dictionary mapping string names to custom classes - or functions. - checkpoint_path: String path to checkpoint. - - Raises: - ValueError: If the train/eval mode is being exported, but the model does - not have an optimizer. - """ - compile_clone = (mode != model_fn_lib.ModeKeys.PREDICT) - if compile_clone and not model.optimizer: - raise ValueError( - 'Model does not have an optimizer. Cannot export mode %s' % mode) - - model_graph = ops.get_default_graph() - with ops.Graph().as_default() as g: - - K.set_learning_phase(mode == model_fn_lib.ModeKeys.TRAIN) - - # Clone the model into blank graph. This will create placeholders for inputs - # and targets. - clone = models_lib.clone_and_build_model( - model, custom_objects=custom_objects, compile_clone=compile_clone) - - # Make sure that iterations variable is added to the global step collection, - # to ensure that, when the SavedModel graph is loaded, the iterations - # variable is returned by `tf.train.get_global_step()`. This is required for - # compatibility with the SavedModelEstimator. 
- if compile_clone: - g.add_to_collection(ops.GraphKeys.GLOBAL_STEP, clone.optimizer.iterations) - - # Extract update and train ops from train/test/predict functions. - train_op = None - if mode == model_fn_lib.ModeKeys.TRAIN: - clone._make_train_function() - train_op = clone.train_function.updates_op - elif mode == model_fn_lib.ModeKeys.EVAL: - clone._make_test_function() - else: - clone._make_predict_function() - g.get_collection_ref(ops.GraphKeys.UPDATE_OPS).extend(clone.state_updates) - - clone_var_list = checkpointable_utils.named_saveables(clone) - - with session.Session().as_default(): - if has_saved_vars: - # Confirm all variables in the clone have an entry in the checkpoint. - status = clone.load_weights(checkpoint_path) - status.assert_existing_objects_matched() - else: - # Confirm that variables between the clone and model match up exactly, - # not counting optimizer objects. Optimizer objects are ignored because - # if the model has not trained, the slot variables will not have been - # created yet. - # TODO(b/113179535): Replace with checkpointable equivalence. - _assert_same_non_optimizer_objects(model, model_graph, clone, g) - - # TODO(b/113178242): Use value transfer for checkpointable objects. - clone.load_weights(checkpoint_path) - - # Add graph and variables to SavedModel. - # TODO(b/113134168): Switch to add_meta_graph_and_variables. - clone.save_weights(checkpoint_path, save_format='tf', overwrite=True) - builder._has_saved_variables = True - - # Add graph to the SavedModel builder. 
- builder.add_meta_graph( - model_fn_lib.EXPORT_TAG_MAP[mode], - signature_def_map=_create_signature_def_map(clone, mode), - saver=saver_lib.Saver(clone_var_list), - init_op=variables.local_variables_initializer(), - train_op=train_op) - return None - - -def _create_signature_def_map(model, mode): - """Create a SignatureDef map from a Keras model.""" - inputs_dict = {name: x for name, x in zip(model.input_names, model.inputs)} - if model.optimizer: - targets_dict = {x.name.split(':')[0]: x - for x in model.targets if x is not None} - inputs_dict.update(targets_dict) - outputs_dict = {name: x - for name, x in zip(model.output_names, model.outputs)} - metrics = estimator_keras_util._convert_keras_metrics_to_estimator(model) - - # Add metric variables to the `LOCAL_VARIABLES` collection. Metric variables - # are by default not added to any collections. We are doing this here, so - # that metric variables get initialized. - local_vars = set(ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES)) - vars_to_add = set() - if metrics is not None: - for key, value in six.iteritems(metrics): - if isinstance(value, Metric): - vars_to_add.update(value.variables) - # Convert Metric instances to (value_tensor, update_op) tuple. - metrics[key] = (value.result(), value.updates[0]) - # Remove variables that are in the local variables collection already. 
- vars_to_add = vars_to_add.difference(local_vars) - for v in vars_to_add: - ops.add_to_collection(ops.GraphKeys.LOCAL_VARIABLES, v) - - export_outputs = model_fn_lib.export_outputs_for_mode( - mode, - predictions=outputs_dict, - loss=model.total_loss if model.optimizer else None, - metrics=metrics) - return export_helpers.build_all_signature_defs( - inputs_dict, - export_outputs=export_outputs, - serving_only=(mode == model_fn_lib.ModeKeys.PREDICT)) - - -def _assert_same_non_optimizer_objects(model, model_graph, clone, clone_graph): # pylint: disable=unused-argument - """Assert model and clone contain the same checkpointable objects.""" - - # TODO(fchollet, kathywu): make sure this works in eager mode. - return True - - -def load_keras_model(saved_model_path): - """Load a keras.Model from SavedModel. - - load_model reinstantiates model state by: - 1) loading model topology from json (this will eventually come - from metagraph). - 2) loading model weights from checkpoint. - - Example: - - ```python - import tensorflow as tf - - # Create a tf.keras model. - model = tf.keras.Sequential() - model.add(tf.keras.layers.Dense(1, input_shape=[10])) - model.summary() - - # Save the tf.keras model in the SavedModel format. - saved_to_path = tf.contrib.saved_model.save_keras_model( - model, '/tmp/my_simple_tf_keras_saved_model') - - # Load the saved keras model back. - model_prime = tf.contrib.saved_model.load_keras_model(saved_to_path) - model_prime.summary() - ``` - - Args: - saved_model_path: a string specifying the path to an existing SavedModel. - - Returns: - a keras.Model instance. 
- """ - # restore model topology from json string - model_json_filepath = os.path.join( - compat.as_bytes(saved_model_path), - compat.as_bytes(constants.ASSETS_DIRECTORY), - compat.as_bytes(constants.SAVED_MODEL_FILENAME_JSON)) - model_json = file_io.read_file_to_string(model_json_filepath) - model = model_from_json(model_json) - - # restore model weights - checkpoint_prefix = os.path.join( - compat.as_text(saved_model_path), - compat.as_text(constants.VARIABLES_DIRECTORY), - compat.as_text(constants.VARIABLES_FILENAME)) - model.load_weights(checkpoint_prefix) - return model +# TODO(kathywu): Remove all contrib callers, switch to tf.keras. +save_keras_model = saving.export +load_keras_model = saving.load_from_saved_model diff --git a/tensorflow/contrib/seq2seq/BUILD b/tensorflow/contrib/seq2seq/BUILD index 18b56cd21942e28cb0dc3210df0bb04d55c1e16f..2a70b08f5c46e11e7fd83fe134741b9a241153f5 100644 --- a/tensorflow/contrib/seq2seq/BUILD +++ b/tensorflow/contrib/seq2seq/BUILD @@ -33,7 +33,6 @@ tf_custom_op_py_library( srcs_version = "PY2AND3", deps = [ ":beam_search_ops", - "//tensorflow/contrib/distributions:distributions_py", "//tensorflow/contrib/layers:layers_py", "//tensorflow/contrib/rnn:rnn_py", "//tensorflow/contrib/util:util_py", @@ -59,7 +58,6 @@ tf_custom_op_py_library( "//tensorflow/python:tensor_util", "//tensorflow/python:util", "//tensorflow/python:variable_scope", - "//tensorflow/python/ops/distributions", "//third_party/py/numpy", "@six_archive//:six", ], @@ -141,6 +139,27 @@ cuda_py_test( ], ) +cuda_py_test( + name = "basic_decoder_v2_test", + size = "medium", + srcs = ["python/kernel_tests/basic_decoder_v2_test.py"], + additional_deps = [ + ":seq2seq_py", + "@absl_py//absl/testing:parameterized", + "//tensorflow/contrib/layers:layers_py", + "//tensorflow/contrib/rnn:rnn_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + 
"//tensorflow/python:init_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:rnn", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + ], +) + cuda_py_test( name = "beam_search_ops_test", size = "medium", @@ -175,6 +194,27 @@ cuda_py_test( ], ) +cuda_py_test( + name = "decoder_v2_test", + size = "medium", + srcs = ["python/kernel_tests/decoder_v2_test.py"], + additional_deps = [ + ":seq2seq_py", + "@absl_py//absl/testing:parameterized", + "//tensorflow/contrib/layers:layers_py", + "//tensorflow/contrib/rnn:rnn_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:init_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:rnn", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + ], +) + cuda_py_test( name = "beam_search_decoder_test", size = "medium", @@ -215,3 +255,18 @@ cuda_py_test( "//tensorflow/python:variables", ], ) + +cuda_py_test( + name = "attention_wrapper_v2_test", + size = "medium", + srcs = ["python/kernel_tests/attention_wrapper_v2_test.py"], + additional_deps = [ + ":seq2seq_py", + "@absl_py//absl/testing:parameterized", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:init_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:variables", + ], +) diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py index 922f21b98b35dfff19c8c605a25e89c5d2da8d98..d815f81f847ad79ddcc6c6ecf5c050598e185d8d 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py @@ -35,6 +35,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import 
init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops +from tensorflow.python.ops import rnn from tensorflow.python.ops import rnn_cell from tensorflow.python.ops import variables from tensorflow.python.ops import variable_scope as vs @@ -992,5 +993,67 @@ class AttentionWrapperTest(test.TestCase): expected_final_alignment_history=expected_final_alignment_history, name='testMultiAttention') + def testCustomizedAttention(self): + batch_size = 2 + max_time = 3 + num_units = 2 + memory = constant_op.constant([[[1., 1.], [2., 2.], [3., 3.]], + [[4., 4.], [5., 5.], [6., 6.]]]) + memory_sequence_length = constant_op.constant([3, 2]) + attention_mechanism = wrapper.BahdanauAttention(num_units, memory, + memory_sequence_length) + + # Sets all returned values to be all ones. + def _customized_attention(unused_attention_mechanism, unused_cell_output, + unused_attention_state, unused_attention_layer): + """Customized attention. + + Returns: + attention: `Tensor` of shape [batch_size, num_units], attention output. + alignments: `Tensor` of shape [batch_size, max_time], sigma value for + each input memory (prob. function of input keys). + next_attention_state: A `Tensor` representing the next state for the + attention. + """ + attention = array_ops.ones([batch_size, num_units]) + alignments = array_ops.ones([batch_size, max_time]) + next_attention_state = alignments + return attention, alignments, next_attention_state + + attention_cell = wrapper.AttentionWrapper( + rnn_cell.LSTMCell(2), + attention_mechanism, + attention_layer_size=None, # don't use attention layer. 
+ output_attention=False, + alignment_history=(), + attention_fn=_customized_attention, + name='attention') + self.assertEqual(num_units, attention_cell.output_size) + + initial_state = attention_cell.zero_state( + batch_size=2, dtype=dtypes.float32) + source_input_emb = array_ops.ones([2, 3, 2]) + source_input_length = constant_op.constant([3, 2]) + + # 'state' is a tuple of + # (cell_state, h, attention, alignments, alignment_history, attention_state) + output, state = rnn.dynamic_rnn( + attention_cell, + inputs=source_input_emb, + sequence_length=source_input_length, + initial_state=initial_state, + dtype=dtypes.float32) + + with self.session() as sess: + sess.run(variables.global_variables_initializer()) + output_value, state_value = sess.run([output, state], feed_dict={}) + self.assertAllEqual(np.array([2, 3, 2]), output_value.shape) + self.assertAllClose(np.array([[1., 1.], [1., 1.]]), state_value.attention) + self.assertAllClose( + np.array([[1., 1., 1.], [1., 1., 1.]]), state_value.alignments) + self.assertAllClose( + np.array([[1., 1., 1.], [1., 1., 1.]]), state_value.attention_state) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_v2_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_v2_test.py new file mode 100644 index 0000000000000000000000000000000000000000..7ff04e1780c4c44df14d6e87c5afdbf533ca5c90 --- /dev/null +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_v2_test.py @@ -0,0 +1,94 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for contrib.seq2seq.python.ops.attention_wrapper.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy as np + +from tensorflow.contrib.seq2seq.python.ops import attention_wrapper as wrapper +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +@test_util.run_all_in_graph_and_eager_modes +class AttentionMechanismTest(test.TestCase, parameterized.TestCase): + + def setUp(self): + super(AttentionMechanismTest, self).setUp() + self.batch = 10 + self.timestep = 5 + self.memory_size = 6 + self.units = 8 + + self.memory = ops.convert_to_tensor( + np.random.random((self.batch, self.timestep, self.memory_size)), + dtype=np.float32) + self.query = ops.convert_to_tensor( + np.random.random((self.batch, self.units)), dtype=np.float32) + self.state = ops.convert_to_tensor( + np.random.random((self.batch, self.timestep)), dtype=np.float32) + + @parameterized.named_parameters( + ("luong", wrapper.LuongAttentionV2), + ("luong_monotonic", wrapper.LuongMonotonicAttentionV2), + ("bahdanau", wrapper.BahdanauAttentionV2), + ("bahdanau_monotonic", wrapper.BahdanauMonotonicAttentionV2), + ) + def test_attention_shape_inference(self, attention_cls): + attention = attention_cls(self.units) + attention_score = attention([self.query, 
self.state, self.memory]) + self.assertLen(attention_score, 2) + self.assertEqual(attention_score[0].shape, (self.batch, self.timestep)) + self.assertEqual(attention_score[1].shape, (self.batch, self.timestep)) + + @parameterized.named_parameters( + ("luong", wrapper.LuongAttentionV2), + ("luong_monotonic", wrapper.LuongMonotonicAttentionV2), + ("bahdanau", wrapper.BahdanauAttentionV2), + ("bahdanau_monotonic", wrapper.BahdanauMonotonicAttentionV2), + ) + def test_get_config(self, attention_cls): + attention = attention_cls(self.units) + config = attention.get_config() + + attention_from_config = attention_cls.from_config(config) + config_from_clone = attention_from_config.get_config() + + self.assertDictEqual(config, config_from_clone) + + @parameterized.named_parameters( + ("luong", wrapper.LuongAttentionV2), + ("luong_monotonic", wrapper.LuongMonotonicAttentionV2), + ("bahdanau", wrapper.BahdanauAttentionV2), + ("bahdanau_monotonic", wrapper.BahdanauMonotonicAttentionV2), + ) + def test_layer_output(self, attention_cls): + attention = attention_cls(self.units) + + score = attention([self.query, self.state, self.memory]) + self.evaluate(variables.variables_initializer(attention.variables)) + + score_val = self.evaluate(score) + self.assertLen(score_val, 2) + self.assertEqual(score_val[0].shape, (self.batch, self.timestep)) + self.assertEqual(score_val[1].shape, (self.batch, self.timestep)) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py index b7f9f3fb090356a1c8d2bfb5044712ff93e267ce..abcf71c61b6e6df9462bf06323b8b11d5cc0d9a8 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py @@ -34,8 +34,6 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import rnn_cell from tensorflow.python.ops import 
variables from tensorflow.python.ops import variable_scope -from tensorflow.python.ops.distributions import bernoulli -from tensorflow.python.ops.distributions import categorical from tensorflow.python.platform import test # pylint: enable=g-import-not-at-top @@ -517,7 +515,7 @@ class BasicDecoderTest(test.TestCase): vocabulary_size) # The sample function samples categorically from the logits. - sample_fn = lambda x: categorical.Categorical(logits=x).sample() + sample_fn = lambda x: helper_py.categorical_sample(logits=x) # The next inputs are a one-hot encoding of the sampled labels. next_inputs_fn = ( lambda x: array_ops.one_hot(x, vocabulary_size, dtype=dtypes.float32)) @@ -599,7 +597,7 @@ class BasicDecoderTest(test.TestCase): # The sample function samples independent bernoullis from the logits. sample_fn = ( - lambda x: bernoulli.Bernoulli(logits=x, dtype=dtypes.bool).sample()) + lambda x: helper_py.bernoulli_sample(logits=x, dtype=dtypes.bool)) # The next inputs are a one-hot encoding of the sampled labels. next_inputs_fn = math_ops.to_float end_fn = lambda sample_ids: sample_ids[:, end_token] diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_v2_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_v2_test.py new file mode 100644 index 0000000000000000000000000000000000000000..15fb688fc4dd4909e5bab36def7ac58e9d7be4ea --- /dev/null +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_v2_test.py @@ -0,0 +1,670 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for contrib.seq2seq.python.seq2seq.basic_decoder_v2.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy as np + +from tensorflow.contrib.seq2seq.python.ops import basic_decoder +from tensorflow.contrib.seq2seq.python.ops import sampler as sampler_py + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_shape +from tensorflow.python.keras import keras_parameterized +from tensorflow.python.layers import core as layers_core +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import rnn_cell +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +@keras_parameterized.run_all_keras_modes +class BasicDecoderTest(keras_parameterized.TestCase): + """Unit test for basic_decoder.BasicDecoderV2.""" + + @parameterized.named_parameters( + ("use_output_layer", True), + ("without_output_layer", False)) + def testStepWithTrainingHelperOutputLayer(self, use_output_layer): + sequence_length = [3, 4, 3, 1, 0] + batch_size = 5 + max_time = 8 + input_depth = 7 + cell_depth = 10 + output_layer_depth = 3 + + with self.cached_session(use_gpu=True): + inputs = np.random.randn(batch_size, max_time, + input_depth).astype(np.float32) + input_t = constant_op.constant(inputs) + 
cell = rnn_cell.LSTMCell(cell_depth) + sampler = sampler_py.TrainingSampler(time_major=False) + if use_output_layer: + output_layer = layers_core.Dense(output_layer_depth, use_bias=False) + expected_output_depth = output_layer_depth + else: + output_layer = None + expected_output_depth = cell_depth + initial_state = cell.zero_state(dtype=dtypes.float32, + batch_size=batch_size) + my_decoder = basic_decoder.BasicDecoderV2( + cell=cell, + sampler=sampler, + output_layer=output_layer) + + (first_finished, + first_inputs, + first_state) = my_decoder.initialize(input_t, + initial_state=initial_state, + sequence_length=sequence_length) + output_size = my_decoder.output_size + output_dtype = my_decoder.output_dtype + self.assertEqual( + basic_decoder.BasicDecoderOutput(expected_output_depth, + tensor_shape.TensorShape([])), + output_size) + self.assertEqual( + basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32), + output_dtype) + + (step_outputs, step_state, step_next_inputs, + step_finished) = my_decoder.step( + constant_op.constant(0), first_inputs, first_state) + batch_size_t = my_decoder.batch_size + + self.assertTrue(isinstance(first_state, rnn_cell.LSTMStateTuple)) + self.assertTrue(isinstance(step_state, rnn_cell.LSTMStateTuple)) + self.assertTrue( + isinstance(step_outputs, basic_decoder.BasicDecoderOutput)) + self.assertEqual((batch_size, expected_output_depth), + step_outputs[0].get_shape()) + self.assertEqual((batch_size,), step_outputs[1].get_shape()) + self.assertEqual((batch_size, cell_depth), first_state[0].get_shape()) + self.assertEqual((batch_size, cell_depth), first_state[1].get_shape()) + self.assertEqual((batch_size, cell_depth), step_state[0].get_shape()) + self.assertEqual((batch_size, cell_depth), step_state[1].get_shape()) + + if use_output_layer: + # The output layer was accessed + self.assertEqual(len(output_layer.variables), 1) + + self.evaluate(variables.global_variables_initializer()) + eval_result = self.evaluate({ + "batch_size": 
batch_size_t, + "first_finished": first_finished, + "first_inputs": first_inputs, + "first_state": first_state, + "step_outputs": step_outputs, + "step_state": step_state, + "step_next_inputs": step_next_inputs, + "step_finished": step_finished + }) + + self.assertAllEqual([False, False, False, False, True], + eval_result["first_finished"]) + self.assertAllEqual([False, False, False, True, True], + eval_result["step_finished"]) + self.assertEqual(output_dtype.sample_id, + eval_result["step_outputs"].sample_id.dtype) + self.assertAllEqual( + np.argmax(eval_result["step_outputs"].rnn_output, -1), + eval_result["step_outputs"].sample_id) + + def testStepWithGreedyEmbeddingHelper(self): + batch_size = 5 + vocabulary_size = 7 + cell_depth = vocabulary_size # cell's logits must match vocabulary size + input_depth = 10 + start_tokens = np.random.randint(0, vocabulary_size, size=batch_size) + end_token = 1 + + with self.cached_session(use_gpu=True): + embeddings = np.random.randn(vocabulary_size, + input_depth).astype(np.float32) + embeddings_t = constant_op.constant(embeddings) + cell = rnn_cell.LSTMCell(vocabulary_size) + sampler = sampler_py.GreedyEmbeddingSampler() + initial_state = cell.zero_state( + dtype=dtypes.float32, batch_size=batch_size) + my_decoder = basic_decoder.BasicDecoderV2( + cell=cell, + sampler=sampler) + (first_finished, first_inputs, first_state) = my_decoder.initialize( + embeddings_t, + start_tokens=start_tokens, + end_token=end_token, + initial_state=initial_state) + output_size = my_decoder.output_size + output_dtype = my_decoder.output_dtype + self.assertEqual( + basic_decoder.BasicDecoderOutput(cell_depth, + tensor_shape.TensorShape([])), + output_size) + self.assertEqual( + basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32), + output_dtype) + + (step_outputs, step_state, step_next_inputs, + step_finished) = my_decoder.step( + constant_op.constant(0), first_inputs, first_state) + batch_size_t = my_decoder.batch_size + + 
self.assertTrue(isinstance(first_state, rnn_cell.LSTMStateTuple)) + self.assertTrue(isinstance(step_state, rnn_cell.LSTMStateTuple)) + self.assertTrue( + isinstance(step_outputs, basic_decoder.BasicDecoderOutput)) + self.assertEqual((batch_size, cell_depth), step_outputs[0].get_shape()) + self.assertEqual((batch_size,), step_outputs[1].get_shape()) + self.assertEqual((batch_size, cell_depth), first_state[0].get_shape()) + self.assertEqual((batch_size, cell_depth), first_state[1].get_shape()) + self.assertEqual((batch_size, cell_depth), step_state[0].get_shape()) + self.assertEqual((batch_size, cell_depth), step_state[1].get_shape()) + + self.evaluate(variables.global_variables_initializer()) + eval_result = self.evaluate({ + "batch_size": batch_size_t, + "first_finished": first_finished, + "first_inputs": first_inputs, + "first_state": first_state, + "step_outputs": step_outputs, + "step_state": step_state, + "step_next_inputs": step_next_inputs, + "step_finished": step_finished + }) + + expected_sample_ids = np.argmax( + eval_result["step_outputs"].rnn_output, -1) + expected_step_finished = (expected_sample_ids == end_token) + expected_step_next_inputs = embeddings[expected_sample_ids] + self.assertAllEqual([False, False, False, False, False], + eval_result["first_finished"]) + self.assertAllEqual(expected_step_finished, eval_result["step_finished"]) + self.assertEqual(output_dtype.sample_id, + eval_result["step_outputs"].sample_id.dtype) + self.assertAllEqual(expected_sample_ids, + eval_result["step_outputs"].sample_id) + self.assertAllEqual(expected_step_next_inputs, + eval_result["step_next_inputs"]) + + def testStepWithSampleEmbeddingHelper(self): + batch_size = 5 + vocabulary_size = 7 + cell_depth = vocabulary_size # cell's logits must match vocabulary size + input_depth = 10 + np.random.seed(0) + start_tokens = np.random.randint(0, vocabulary_size, size=batch_size) + end_token = 1 + + with self.cached_session(use_gpu=True): + embeddings = 
np.random.randn(vocabulary_size, + input_depth).astype(np.float32) + embeddings_t = constant_op.constant(embeddings) + cell = rnn_cell.LSTMCell(vocabulary_size) + sampler = sampler_py.SampleEmbeddingSampler(seed=0) + initial_state = cell.zero_state( + dtype=dtypes.float32, batch_size=batch_size) + my_decoder = basic_decoder.BasicDecoderV2(cell=cell, sampler=sampler) + (first_finished, + first_inputs, + first_state) = my_decoder.initialize(embeddings_t, + start_tokens=start_tokens, + end_token=end_token, + initial_state=initial_state) + output_size = my_decoder.output_size + output_dtype = my_decoder.output_dtype + self.assertEqual( + basic_decoder.BasicDecoderOutput(cell_depth, + tensor_shape.TensorShape([])), + output_size) + self.assertEqual( + basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32), + output_dtype) + + (step_outputs, step_state, step_next_inputs, + step_finished) = my_decoder.step( + constant_op.constant(0), first_inputs, first_state) + batch_size_t = my_decoder.batch_size + + self.assertTrue(isinstance(first_state, rnn_cell.LSTMStateTuple)) + self.assertTrue(isinstance(step_state, rnn_cell.LSTMStateTuple)) + self.assertTrue( + isinstance(step_outputs, basic_decoder.BasicDecoderOutput)) + self.assertEqual((batch_size, cell_depth), step_outputs[0].get_shape()) + self.assertEqual((batch_size,), step_outputs[1].get_shape()) + self.assertEqual((batch_size, cell_depth), first_state[0].get_shape()) + self.assertEqual((batch_size, cell_depth), first_state[1].get_shape()) + self.assertEqual((batch_size, cell_depth), step_state[0].get_shape()) + self.assertEqual((batch_size, cell_depth), step_state[1].get_shape()) + + self.evaluate(variables.global_variables_initializer()) + eval_result = self.evaluate({ + "batch_size": batch_size_t, + "first_finished": first_finished, + "first_inputs": first_inputs, + "first_state": first_state, + "step_outputs": step_outputs, + "step_state": step_state, + "step_next_inputs": step_next_inputs, + "step_finished": 
step_finished + }) + + sample_ids = eval_result["step_outputs"].sample_id + self.assertEqual(output_dtype.sample_id, sample_ids.dtype) + expected_step_finished = (sample_ids == end_token) + expected_step_next_inputs = embeddings[sample_ids] + self.assertAllEqual(expected_step_finished, + eval_result["step_finished"]) + self.assertAllEqual(expected_step_next_inputs, + eval_result["step_next_inputs"]) + + def testStepWithScheduledEmbeddingTrainingHelper(self): + sequence_length = [3, 4, 3, 1, 0] + batch_size = 5 + max_time = 8 + input_depth = 7 + vocabulary_size = 10 + + with self.cached_session(use_gpu=True): + inputs = np.random.randn( + batch_size, max_time, input_depth).astype(np.float32) + input_t = constant_op.constant(inputs) + embeddings = np.random.randn( + vocabulary_size, input_depth).astype(np.float32) + half = constant_op.constant(0.5) + cell = rnn_cell.LSTMCell(vocabulary_size) + sampler = sampler_py.ScheduledEmbeddingTrainingSampler( + sampling_probability=half, + time_major=False) + initial_state = cell.zero_state( + dtype=dtypes.float32, batch_size=batch_size) + my_decoder = basic_decoder.BasicDecoderV2( + cell=cell, + sampler=sampler) + (first_finished, first_inputs, first_state) = my_decoder.initialize( + input_t, sequence_length=sequence_length, embedding=embeddings, + initial_state=initial_state) + output_size = my_decoder.output_size + output_dtype = my_decoder.output_dtype + self.assertEqual( + basic_decoder.BasicDecoderOutput(vocabulary_size, + tensor_shape.TensorShape([])), + output_size) + self.assertEqual( + basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32), + output_dtype) + + (step_outputs, step_state, step_next_inputs, + step_finished) = my_decoder.step( + constant_op.constant(0), first_inputs, first_state) + batch_size_t = my_decoder.batch_size + + self.assertTrue(isinstance(first_state, rnn_cell.LSTMStateTuple)) + self.assertTrue(isinstance(step_state, rnn_cell.LSTMStateTuple)) + self.assertTrue( + isinstance(step_outputs, 
basic_decoder.BasicDecoderOutput)) + self.assertEqual((batch_size, vocabulary_size), + step_outputs[0].get_shape()) + self.assertEqual((batch_size,), step_outputs[1].get_shape()) + self.assertEqual((batch_size, vocabulary_size), + first_state[0].get_shape()) + self.assertEqual((batch_size, vocabulary_size), + first_state[1].get_shape()) + self.assertEqual((batch_size, vocabulary_size), + step_state[0].get_shape()) + self.assertEqual((batch_size, vocabulary_size), + step_state[1].get_shape()) + self.assertEqual((batch_size, input_depth), + step_next_inputs.get_shape()) + + self.evaluate(variables.global_variables_initializer()) + eval_result = self.evaluate({ + "batch_size": batch_size_t, + "first_finished": first_finished, + "first_inputs": first_inputs, + "first_state": first_state, + "step_outputs": step_outputs, + "step_state": step_state, + "step_next_inputs": step_next_inputs, + "step_finished": step_finished + }) + + self.assertAllEqual([False, False, False, False, True], + eval_result["first_finished"]) + self.assertAllEqual([False, False, False, True, True], + eval_result["step_finished"]) + sample_ids = eval_result["step_outputs"].sample_id + self.assertEqual(output_dtype.sample_id, sample_ids.dtype) + batch_where_not_sampling = np.where(sample_ids == -1) + batch_where_sampling = np.where(sample_ids > -1) + self.assertAllClose( + eval_result["step_next_inputs"][batch_where_sampling], + embeddings[sample_ids[batch_where_sampling]]) + self.assertAllClose( + eval_result["step_next_inputs"][batch_where_not_sampling], + np.squeeze(inputs[batch_where_not_sampling, 1], axis=0)) + + def _testStepWithScheduledOutputTrainingHelper( + self, sampling_probability, use_next_inputs_fn, use_auxiliary_inputs): + sequence_length = [3, 4, 3, 1, 0] + batch_size = 5 + max_time = 8 + input_depth = 7 + cell_depth = input_depth + if use_auxiliary_inputs: + auxiliary_input_depth = 4 + auxiliary_inputs = np.random.randn( + batch_size, max_time, 
auxiliary_input_depth).astype(np.float32) + else: + auxiliary_inputs = None + + with self.cached_session(use_gpu=True): + inputs = np.random.randn(batch_size, max_time, + input_depth).astype(np.float32) + input_t = constant_op.constant(inputs) + cell = rnn_cell.LSTMCell(cell_depth) + sampling_probability = constant_op.constant(sampling_probability) + + if use_next_inputs_fn: + def next_inputs_fn(outputs): + # Use deterministic function for test. + samples = math_ops.argmax(outputs, axis=1) + return array_ops.one_hot(samples, cell_depth, dtype=dtypes.float32) + else: + next_inputs_fn = None + + sampler = sampler_py.ScheduledOutputTrainingSampler( + sampling_probability=sampling_probability, + time_major=False, + next_inputs_fn=next_inputs_fn) + initial_state = cell.zero_state( + dtype=dtypes.float32, batch_size=batch_size) + + my_decoder = basic_decoder.BasicDecoderV2( + cell=cell, + sampler=sampler) + + (first_finished, + first_inputs, + first_state) = my_decoder.initialize(input_t, + sequence_length=sequence_length, + initial_state=initial_state, + auxiliary_inputs=auxiliary_inputs) + output_size = my_decoder.output_size + output_dtype = my_decoder.output_dtype + self.assertEqual( + basic_decoder.BasicDecoderOutput(cell_depth, + tensor_shape.TensorShape([])), + output_size) + self.assertEqual( + basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32), + output_dtype) + + (step_outputs, step_state, step_next_inputs, + step_finished) = my_decoder.step( + constant_op.constant(0), first_inputs, first_state) + + if use_next_inputs_fn: + output_after_next_inputs_fn = next_inputs_fn(step_outputs.rnn_output) + + batch_size_t = my_decoder.batch_size + + self.assertTrue(isinstance(first_state, rnn_cell.LSTMStateTuple)) + self.assertTrue(isinstance(step_state, rnn_cell.LSTMStateTuple)) + self.assertTrue( + isinstance(step_outputs, basic_decoder.BasicDecoderOutput)) + self.assertEqual((batch_size, cell_depth), step_outputs[0].get_shape()) + 
self.assertEqual((batch_size,), step_outputs[1].get_shape()) + self.assertEqual((batch_size, cell_depth), first_state[0].get_shape()) + self.assertEqual((batch_size, cell_depth), first_state[1].get_shape()) + self.assertEqual((batch_size, cell_depth), step_state[0].get_shape()) + self.assertEqual((batch_size, cell_depth), step_state[1].get_shape()) + + self.evaluate(variables.global_variables_initializer()) + + fetches = { + "batch_size": batch_size_t, + "first_finished": first_finished, + "first_inputs": first_inputs, + "first_state": first_state, + "step_outputs": step_outputs, + "step_state": step_state, + "step_next_inputs": step_next_inputs, + "step_finished": step_finished + } + if use_next_inputs_fn: + fetches["output_after_next_inputs_fn"] = output_after_next_inputs_fn + + eval_result = self.evaluate(fetches) + + self.assertAllEqual([False, False, False, False, True], + eval_result["first_finished"]) + self.assertAllEqual([False, False, False, True, True], + eval_result["step_finished"]) + + sample_ids = eval_result["step_outputs"].sample_id + self.assertEqual(output_dtype.sample_id, sample_ids.dtype) + batch_where_not_sampling = np.where(np.logical_not(sample_ids)) + batch_where_sampling = np.where(sample_ids) + + auxiliary_inputs_to_concat = ( + auxiliary_inputs[:, 1] if use_auxiliary_inputs else + np.array([]).reshape(batch_size, 0).astype(np.float32)) + + expected_next_sampling_inputs = np.concatenate( + (eval_result["output_after_next_inputs_fn"][batch_where_sampling] + if use_next_inputs_fn else + eval_result["step_outputs"].rnn_output[batch_where_sampling], + auxiliary_inputs_to_concat[batch_where_sampling]), + axis=-1) + self.assertAllClose( + eval_result["step_next_inputs"][batch_where_sampling], + expected_next_sampling_inputs) + + self.assertAllClose( + eval_result["step_next_inputs"][batch_where_not_sampling], + np.concatenate( + (np.squeeze(inputs[batch_where_not_sampling, 1], axis=0), + auxiliary_inputs_to_concat[batch_where_not_sampling]), + 
axis=-1)) + + def testStepWithScheduledOutputTrainingHelperWithoutNextInputsFnOrAuxInputs( + self): + self._testStepWithScheduledOutputTrainingHelper( + sampling_probability=0.5, use_next_inputs_fn=False, + use_auxiliary_inputs=False) + + def testStepWithScheduledOutputTrainingHelperWithNextInputsFn(self): + self._testStepWithScheduledOutputTrainingHelper( + sampling_probability=0.5, use_next_inputs_fn=True, + use_auxiliary_inputs=False) + + def testStepWithScheduledOutputTrainingHelperWithAuxiliaryInputs(self): + self._testStepWithScheduledOutputTrainingHelper( + sampling_probability=0.5, use_next_inputs_fn=False, + use_auxiliary_inputs=True) + + def testStepWithScheduledOutputTrainingHelperWithNextInputsFnAndAuxInputs( + self): + self._testStepWithScheduledOutputTrainingHelper( + sampling_probability=0.5, use_next_inputs_fn=True, + use_auxiliary_inputs=True) + + def testStepWithScheduledOutputTrainingHelperWithNoSampling(self): + self._testStepWithScheduledOutputTrainingHelper( + sampling_probability=0.0, use_next_inputs_fn=True, + use_auxiliary_inputs=True) + + def testStepWithInferenceHelperCategorical(self): + batch_size = 5 + vocabulary_size = 7 + cell_depth = vocabulary_size + start_token = 0 + end_token = 6 + + start_inputs = array_ops.one_hot( + np.ones(batch_size, dtype=np.int32) * start_token, + vocabulary_size) + + # The sample function samples categorically from the logits. + sample_fn = lambda x: sampler_py.categorical_sample(logits=x) + # The next inputs are a one-hot encoding of the sampled labels. 
+ next_inputs_fn = ( + lambda x: array_ops.one_hot(x, vocabulary_size, dtype=dtypes.float32)) + end_fn = lambda sample_ids: math_ops.equal(sample_ids, end_token) + + with self.cached_session(use_gpu=True): + cell = rnn_cell.LSTMCell(vocabulary_size) + sampler = sampler_py.InferenceSampler( + sample_fn, sample_shape=(), sample_dtype=dtypes.int32, end_fn=end_fn, + next_inputs_fn=next_inputs_fn) + initial_state = cell.zero_state( + dtype=dtypes.float32, batch_size=batch_size) + my_decoder = basic_decoder.BasicDecoderV2( + cell=cell, + sampler=sampler) + (first_finished, first_inputs, first_state) = my_decoder.initialize( + start_inputs, initial_state=initial_state) + + output_size = my_decoder.output_size + output_dtype = my_decoder.output_dtype + self.assertEqual( + basic_decoder.BasicDecoderOutput(cell_depth, + tensor_shape.TensorShape([])), + output_size) + self.assertEqual( + basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.int32), + output_dtype) + + (step_outputs, step_state, step_next_inputs, + step_finished) = my_decoder.step( + constant_op.constant(0), first_inputs, first_state) + batch_size_t = my_decoder.batch_size + + self.assertTrue(isinstance(first_state, rnn_cell.LSTMStateTuple)) + self.assertTrue(isinstance(step_state, rnn_cell.LSTMStateTuple)) + self.assertTrue( + isinstance(step_outputs, basic_decoder.BasicDecoderOutput)) + self.assertEqual((batch_size, cell_depth), step_outputs[0].get_shape()) + self.assertEqual((batch_size,), step_outputs[1].get_shape()) + self.assertEqual((batch_size, cell_depth), first_state[0].get_shape()) + self.assertEqual((batch_size, cell_depth), first_state[1].get_shape()) + self.assertEqual((batch_size, cell_depth), step_state[0].get_shape()) + self.assertEqual((batch_size, cell_depth), step_state[1].get_shape()) + + self.evaluate(variables.global_variables_initializer()) + eval_result = self.evaluate({ + "batch_size": batch_size_t, + "first_finished": first_finished, + "first_inputs": first_inputs, + "first_state": 
first_state, + "step_outputs": step_outputs, + "step_state": step_state, + "step_next_inputs": step_next_inputs, + "step_finished": step_finished + }) + + sample_ids = eval_result["step_outputs"].sample_id + self.assertEqual(output_dtype.sample_id, sample_ids.dtype) + expected_step_finished = (sample_ids == end_token) + expected_step_next_inputs = np.zeros((batch_size, vocabulary_size)) + expected_step_next_inputs[np.arange(batch_size), sample_ids] = 1.0 + self.assertAllEqual(expected_step_finished, + eval_result["step_finished"]) + self.assertAllEqual(expected_step_next_inputs, + eval_result["step_next_inputs"]) + + def testStepWithInferenceHelperMultilabel(self): + batch_size = 5 + vocabulary_size = 7 + cell_depth = vocabulary_size + start_token = 0 + end_token = 6 + + start_inputs = array_ops.one_hot( + np.ones(batch_size, dtype=np.int32) * start_token, + vocabulary_size) + + # The sample function samples independent bernoullis from the logits. + sample_fn = ( + lambda x: sampler_py.bernoulli_sample(logits=x, dtype=dtypes.bool)) + # The next inputs are the sampled boolean labels cast to float. 
+ next_inputs_fn = math_ops.to_float + end_fn = lambda sample_ids: sample_ids[:, end_token] + + with self.cached_session(use_gpu=True): + cell = rnn_cell.LSTMCell(vocabulary_size) + sampler = sampler_py.InferenceSampler( + sample_fn, sample_shape=[cell_depth], sample_dtype=dtypes.bool, + end_fn=end_fn, next_inputs_fn=next_inputs_fn) + initial_state = cell.zero_state( + dtype=dtypes.float32, batch_size=batch_size) + my_decoder = basic_decoder.BasicDecoderV2( + cell=cell, + sampler=sampler) + (first_finished, first_inputs, first_state) = my_decoder.initialize( + start_inputs, initial_state=initial_state) + output_size = my_decoder.output_size + output_dtype = my_decoder.output_dtype + self.assertEqual( + basic_decoder.BasicDecoderOutput(cell_depth, cell_depth), + output_size) + self.assertEqual( + basic_decoder.BasicDecoderOutput(dtypes.float32, dtypes.bool), + output_dtype) + + (step_outputs, step_state, step_next_inputs, + step_finished) = my_decoder.step( + constant_op.constant(0), first_inputs, first_state) + batch_size_t = my_decoder.batch_size + + self.assertTrue(isinstance(first_state, rnn_cell.LSTMStateTuple)) + self.assertTrue(isinstance(step_state, rnn_cell.LSTMStateTuple)) + self.assertTrue( + isinstance(step_outputs, basic_decoder.BasicDecoderOutput)) + self.assertEqual((batch_size, cell_depth), step_outputs[0].get_shape()) + self.assertEqual((batch_size, cell_depth), step_outputs[1].get_shape()) + self.assertEqual((batch_size, cell_depth), first_state[0].get_shape()) + self.assertEqual((batch_size, cell_depth), first_state[1].get_shape()) + self.assertEqual((batch_size, cell_depth), step_state[0].get_shape()) + self.assertEqual((batch_size, cell_depth), step_state[1].get_shape()) + + self.evaluate(variables.global_variables_initializer()) + eval_result = self.evaluate({ + "batch_size": batch_size_t, + "first_finished": first_finished, + "first_inputs": first_inputs, + "first_state": first_state, + "step_outputs": step_outputs, + "step_state": step_state, 
+ "step_next_inputs": step_next_inputs, + "step_finished": step_finished + }) + + sample_ids = eval_result["step_outputs"].sample_id + self.assertEqual(output_dtype.sample_id, sample_ids.dtype) + expected_step_finished = sample_ids[:, end_token] + expected_step_next_inputs = sample_ids.astype(np.float32) + self.assertAllEqual(expected_step_finished, + eval_result["step_finished"]) + self.assertAllEqual(expected_step_next_inputs, + eval_result["step_next_inputs"]) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_v2_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_v2_test.py new file mode 100644 index 0000000000000000000000000000000000000000..b5bba2b32e940aa4d5984821ebd3845d7f272549 --- /dev/null +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_v2_test.py @@ -0,0 +1,169 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for contrib.seq2seq.python.seq2seq.decoder.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.seq2seq.python.ops import basic_decoder +from tensorflow.contrib.seq2seq.python.ops import sampler as sampler_py +from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.keras import keras_parameterized +from tensorflow.python.ops import rnn +from tensorflow.python.ops import rnn_cell +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +@keras_parameterized.run_all_keras_modes +class DecodeV2RNNTest(keras_parameterized.TestCase, test.TestCase): + """Tests for DecoderV2.""" + + def _testDecodeRNN(self, time_major, maximum_iterations=None): + + sequence_length = [3, 4, 3, 1, 0] + batch_size = 5 + max_time = 8 + input_depth = 7 + cell_depth = 10 + max_out = max(sequence_length) + + with self.cached_session(use_gpu=True): + if time_major: + inputs = np.random.randn(max_time, batch_size, + input_depth).astype(np.float32) + else: + inputs = np.random.randn(batch_size, max_time, + input_depth).astype(np.float32) + input_t = constant_op.constant(inputs) + cell = rnn_cell.LSTMCell(cell_depth) + sampler = sampler_py.TrainingSampler(time_major=time_major) + my_decoder = basic_decoder.BasicDecoderV2( + cell=cell, + sampler=sampler, + output_time_major=time_major, + maximum_iterations=maximum_iterations) + + initial_state = cell.zero_state( + dtype=dtypes.float32, batch_size=batch_size) + (final_outputs, unused_final_state, final_sequence_length) = my_decoder( + input_t, initial_state=initial_state, sequence_length=sequence_length) + + def _t(shape): + if time_major: + return (shape[1], shape[0]) + shape[2:] + return shape + + if 
not context.executing_eagerly(): + self.assertEqual((batch_size,), + tuple(final_sequence_length.get_shape().as_list())) + self.assertEqual( + _t((batch_size, None, cell_depth)), + tuple(final_outputs.rnn_output.get_shape().as_list())) + self.assertEqual( + _t((batch_size, None)), + tuple(final_outputs.sample_id.get_shape().as_list())) + + self.evaluate(variables.global_variables_initializer()) + final_outputs = self.evaluate(final_outputs) + final_sequence_length = self.evaluate(final_sequence_length) + + # Mostly a smoke test + time_steps = max_out + expected_length = sequence_length + if maximum_iterations is not None: + time_steps = min(max_out, maximum_iterations) + expected_length = [min(x, maximum_iterations) for x in expected_length] + if context.executing_eagerly() and maximum_iterations != 0: + # Only check the shape of output when maximum_iterations > 0, see + # b/123431432 for more details. + self.assertEqual( + _t((batch_size, time_steps, cell_depth)), + final_outputs.rnn_output.shape) + self.assertEqual( + _t((batch_size, time_steps)), final_outputs.sample_id.shape) + self.assertItemsEqual(expected_length, final_sequence_length) + + def testDynamicDecodeRNNBatchMajor(self): + self._testDecodeRNN(time_major=False) + + def testDynamicDecodeRNNTimeMajor(self): + self._testDecodeRNN(time_major=True) + + def testDynamicDecodeRNNZeroMaxIters(self): + self._testDecodeRNN(time_major=True, maximum_iterations=0) + + def testDynamicDecodeRNNOneMaxIter(self): + self._testDecodeRNN(time_major=True, maximum_iterations=1) + + def _testDynamicDecodeRNNWithTrainingHelperMatchesDynamicRNN( + self, use_sequence_length): + sequence_length = [3, 4, 3, 1, 0] + batch_size = 5 + max_time = 8 + input_depth = 7 + cell_depth = 10 + max_out = max(sequence_length) + + with self.cached_session(use_gpu=True): + inputs = np.random.randn(batch_size, max_time, + input_depth).astype(np.float32) + inputs = constant_op.constant(inputs) + + cell = rnn_cell.LSTMCell(cell_depth) + 
zero_state = cell.zero_state(dtype=dtypes.float32, batch_size=batch_size) + sampler = sampler_py.TrainingSampler() + my_decoder = basic_decoder.BasicDecoderV2( + cell=cell, sampler=sampler, impute_finished=use_sequence_length) + + final_decoder_outputs, final_decoder_state, _ = my_decoder( + inputs, initial_state=zero_state, sequence_length=sequence_length) + + final_rnn_outputs, final_rnn_state = rnn.dynamic_rnn( + cell, + inputs, + sequence_length=sequence_length if use_sequence_length else None, + initial_state=zero_state) + + self.evaluate(variables.global_variables_initializer()) + eval_result = self.evaluate({ + "final_decoder_outputs": final_decoder_outputs, + "final_decoder_state": final_decoder_state, + "final_rnn_outputs": final_rnn_outputs, + "final_rnn_state": final_rnn_state + }) + + # Decoder only runs out to max_out; ensure values are identical + # to dynamic_rnn, which also zeros out outputs and passes along state. + self.assertAllClose(eval_result["final_decoder_outputs"].rnn_output, + eval_result["final_rnn_outputs"][:, 0:max_out, :]) + if use_sequence_length: + self.assertAllClose(eval_result["final_decoder_state"], + eval_result["final_rnn_state"]) + + def testDynamicDecodeRNNWithTrainingHelperMatchesDynamicRNNWithSeqLen(self): + self._testDynamicDecodeRNNWithTrainingHelperMatchesDynamicRNN( + use_sequence_length=True) + + def testDynamicDecodeRNNWithTrainingHelperMatchesDynamicRNNNoSeqLen(self): + self._testDynamicDecodeRNNWithTrainingHelperMatchesDynamicRNN( + use_sequence_length=False) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/loss_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/loss_test.py index 5aa32b532ffcf5772f6ace26662f5e5471cf6923..41b2a53ca5b178be9b04446c81d832575e5ed75b 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/loss_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/loss_test.py @@ -14,80 +14,254 @@ # 
============================================================================== """Tests for contrib.seq2seq.python.seq2seq.loss_ops.""" -# pylint: disable=unused-import,g-bad-import-order from __future__ import absolute_import from __future__ import division from __future__ import print_function -# pylint: enable=unused-import import numpy as np from tensorflow.contrib.seq2seq.python.ops import loss from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import variable_scope from tensorflow.python.platform import test +@test_util.run_all_in_graph_and_eager_modes class LossTest(test.TestCase): + def setUp(self): + self.batch_size = 2 + self.sequence_length = 3 + self.number_of_classes = 5 + logits = [ + constant_op.constant(i + 0.5, shape=[self.batch_size, + self.number_of_classes]) + for i in range(self.sequence_length) + ] + self.logits = array_ops.stack(logits, axis=1) + targets = [ + constant_op.constant(i, dtypes.int32, shape=[self.batch_size]) + for i in range(self.sequence_length) + ] + self.targets = array_ops.stack(targets, axis=1) + weights = [ + constant_op.constant(1.0, shape=[self.batch_size]) + for _ in range(self.sequence_length) + ] + self.weights = array_ops.stack(weights, axis=1) + # expected_loss = sparse_softmax_cross_entropy_with_logits(targets, logits) + # where targets = [0, 1, 2], and logits = [[0.5] * 5, [1.5] * 5, [2.5] * 5] + self.expected_loss = 1.60944 + def testSequenceLoss(self): - with self.session(use_gpu=True) as sess: - with variable_scope.variable_scope( - 'root', initializer=init_ops.constant_initializer(0.5)): - batch_size = 2 - sequence_length = 3 - number_of_classes = 5 - logits = [ - constant_op.constant( - i + 0.5, shape=[batch_size, number_of_classes]) - for i in range(sequence_length) - ] - logits = 
array_ops.stack(logits, axis=1) - targets = [ - constant_op.constant( - i, dtypes.int32, shape=[batch_size]) - for i in range(sequence_length) - ] - targets = array_ops.stack(targets, axis=1) - weights = [ - constant_op.constant( - 1.0, shape=[batch_size]) for i in range(sequence_length) - ] - weights = array_ops.stack(weights, axis=1) - - average_loss_per_example = loss.sequence_loss( - logits, targets, weights, - average_across_timesteps=True, - average_across_batch=True) - res = sess.run(average_loss_per_example) - self.assertAllClose(1.60944, res) - - average_loss_per_sequence = loss.sequence_loss( - logits, targets, weights, - average_across_timesteps=False, - average_across_batch=True) - res = sess.run(average_loss_per_sequence) - compare_per_sequence = np.ones((sequence_length)) * 1.60944 - self.assertAllClose(compare_per_sequence, res) - - average_loss_per_batch = loss.sequence_loss( - logits, targets, weights, - average_across_timesteps=True, - average_across_batch=False) - res = sess.run(average_loss_per_batch) - compare_per_batch = np.ones((batch_size)) * 1.60944 - self.assertAllClose(compare_per_batch, res) - - total_loss = loss.sequence_loss( - logits, targets, weights, - average_across_timesteps=False, - average_across_batch=False) - res = sess.run(total_loss) - compare_total = np.ones((batch_size, sequence_length)) * 1.60944 - self.assertAllClose(compare_total, res) + with self.test_session(use_gpu=True): + average_loss_per_example = loss.sequence_loss( + self.logits, self.targets, self.weights, + average_across_timesteps=True, + average_across_batch=True) + res = self.evaluate(average_loss_per_example) + self.assertAllClose(self.expected_loss, res) + + average_loss_per_sequence = loss.sequence_loss( + self.logits, self.targets, self.weights, + average_across_timesteps=False, + average_across_batch=True) + res = self.evaluate(average_loss_per_sequence) + compare_per_sequence = np.full((self.sequence_length), self.expected_loss) + 
self.assertAllClose(compare_per_sequence, res) + + average_loss_per_batch = loss.sequence_loss( + self.logits, self.targets, self.weights, + average_across_timesteps=True, + average_across_batch=False) + res = self.evaluate(average_loss_per_batch) + compare_per_batch = np.full((self.batch_size), self.expected_loss) + self.assertAllClose(compare_per_batch, res) + + total_loss = loss.sequence_loss( + self.logits, self.targets, self.weights, + average_across_timesteps=False, + average_across_batch=False) + res = self.evaluate(total_loss) + compare_total = np.full((self.batch_size, self.sequence_length), + self.expected_loss) + self.assertAllClose(compare_total, res) + + def testSequenceLossClass(self): + with self.test_session(use_gpu=True): + seq_loss = loss.SequenceLoss(average_across_timesteps=True, + average_across_batch=True, + sum_over_timesteps=False, + sum_over_batch=False) + average_loss_per_example = seq_loss( + self.targets, self.logits, self.weights) + res = self.evaluate(average_loss_per_example) + self.assertAllClose(self.expected_loss, res) + + seq_loss = loss.SequenceLoss(average_across_timesteps=False, + average_across_batch=True, + sum_over_timesteps=False, + sum_over_batch=False) + average_loss_per_sequence = seq_loss( + self.targets, self.logits, self.weights) + res = self.evaluate(average_loss_per_sequence) + compare_per_sequence = np.full((self.sequence_length), self.expected_loss) + self.assertAllClose(compare_per_sequence, res) + + seq_loss = loss.SequenceLoss(average_across_timesteps=True, + average_across_batch=False, + sum_over_timesteps=False, + sum_over_batch=False) + average_loss_per_batch = seq_loss( + self.targets, self.logits, self.weights) + res = self.evaluate(average_loss_per_batch) + compare_per_batch = np.full((self.batch_size), self.expected_loss) + self.assertAllClose(compare_per_batch, res) + + seq_loss = loss.SequenceLoss(average_across_timesteps=False, + average_across_batch=False, + sum_over_timesteps=False, + 
sum_over_batch=False) + total_loss = seq_loss( + self.targets, self.logits, self.weights) + res = self.evaluate(total_loss) + compare_total = np.full((self.batch_size, self.sequence_length), + self.expected_loss) + self.assertAllClose(compare_total, res) + + def testSumReduction(self): + with self.test_session(use_gpu=True): + seq_loss = loss.SequenceLoss(average_across_timesteps=False, + average_across_batch=False, + sum_over_timesteps=True, + sum_over_batch=True) + average_loss_per_example = seq_loss( + self.targets, self.logits, self.weights) + res = self.evaluate(average_loss_per_example) + self.assertAllClose(self.expected_loss, res) + + seq_loss = loss.SequenceLoss(average_across_timesteps=False, + average_across_batch=False, + sum_over_timesteps=False, + sum_over_batch=True) + average_loss_per_sequence = seq_loss( + self.targets, self.logits, self.weights) + res = self.evaluate(average_loss_per_sequence) + compare_per_sequence = np.full((self.sequence_length), self.expected_loss) + self.assertAllClose(compare_per_sequence, res) + + seq_loss = loss.SequenceLoss(average_across_timesteps=False, + average_across_batch=False, + sum_over_timesteps=True, + sum_over_batch=False) + average_loss_per_batch = seq_loss( + self.targets, self.logits, self.weights) + res = self.evaluate(average_loss_per_batch) + compare_per_batch = np.full((self.batch_size), self.expected_loss) + self.assertAllClose(compare_per_batch, res) + + seq_loss = loss.SequenceLoss(average_across_timesteps=False, + average_across_batch=False, + sum_over_timesteps=False, + sum_over_batch=False) + total_loss = seq_loss( + self.targets, self.logits, self.weights) + res = self.evaluate(total_loss) + compare_total = np.full((self.batch_size, self.sequence_length), + self.expected_loss) + self.assertAllClose(compare_total, res) + + def testWeightedSumReduction(self): + weights = [ + constant_op.constant(1.0, shape=[self.batch_size]) + for _ in range(self.sequence_length) + ] + # Make the last element in 
the sequence to have zero weights. + weights[-1] = constant_op.constant(0.0, shape=[self.batch_size]) + self.weights = array_ops.stack(weights, axis=1) + with self.test_session(use_gpu=True): + seq_loss = loss.SequenceLoss(average_across_timesteps=False, + average_across_batch=False, + sum_over_timesteps=True, + sum_over_batch=True) + average_loss_per_example = seq_loss( + self.targets, self.logits, self.weights) + res = self.evaluate(average_loss_per_example) + self.assertAllClose(self.expected_loss, res) + + seq_loss = loss.SequenceLoss(average_across_timesteps=False, + average_across_batch=False, + sum_over_timesteps=False, + sum_over_batch=True) + average_loss_per_sequence = seq_loss( + self.targets, self.logits, self.weights) + res = self.evaluate(average_loss_per_sequence) + compare_per_sequence = np.full((self.sequence_length), self.expected_loss) + # The last timestep of every sequence has zero weight, so its expected loss is zero. + compare_per_sequence[-1] = 0. + self.assertAllClose(compare_per_sequence, res) + + seq_loss = loss.SequenceLoss(average_across_timesteps=False, + average_across_batch=False, + sum_over_timesteps=True, + sum_over_batch=False) + average_loss_per_batch = seq_loss(self.targets, self.logits, self.weights) + res = self.evaluate(average_loss_per_batch) + compare_per_batch = np.full((self.batch_size), self.expected_loss) + self.assertAllClose(compare_per_batch, res) + + seq_loss = loss.SequenceLoss(average_across_timesteps=False, + average_across_batch=False, + sum_over_timesteps=False, + sum_over_batch=False) + total_loss = seq_loss(self.targets, self.logits, self.weights) + res = self.evaluate(total_loss) + compare_total = np.full((self.batch_size, self.sequence_length), + self.expected_loss) + # The last timestep of every sequence has zero weight, so its expected loss is zero. 
+ compare_total[:, -1] = 0 + self.assertAllClose(compare_total, res) + + def testZeroWeights(self): + weights = [ + constant_op.constant(0.0, shape=[self.batch_size]) + for _ in range(self.sequence_length) + ] + weights = array_ops.stack(weights, axis=1) + with self.test_session(use_gpu=True): + average_loss_per_example = loss.sequence_loss( + self.logits, self.targets, weights, + average_across_timesteps=True, + average_across_batch=True) + res = self.evaluate(average_loss_per_example) + self.assertAllClose(0.0, res) + + average_loss_per_sequence = loss.sequence_loss( + self.logits, self.targets, weights, + average_across_timesteps=False, + average_across_batch=True) + res = self.evaluate(average_loss_per_sequence) + compare_per_sequence = np.zeros((self.sequence_length)) + self.assertAllClose(compare_per_sequence, res) + + average_loss_per_batch = loss.sequence_loss( + self.logits, self.targets, weights, + average_across_timesteps=True, + average_across_batch=False) + res = self.evaluate(average_loss_per_batch) + compare_per_batch = np.zeros((self.batch_size)) + self.assertAllClose(compare_per_batch, res) + + total_loss = loss.sequence_loss( + self.logits, self.targets, weights, + average_across_timesteps=False, + average_across_batch=False) + res = self.evaluate(total_loss) + compare_total = np.zeros((self.batch_size, self.sequence_length)) + self.assertAllClose(compare_total, res) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py index 77e9f848b137911b53e1b4df5dd740fe38af55bb..ae3e7f1b5d8c9f06b5defbaee9cad3810e58abd4 100644 --- a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py +++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py @@ -28,6 +28,7 @@ from tensorflow.contrib.framework.python.framework import tensor_util from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from 
tensorflow.python.framework import tensor_shape +from tensorflow.python.keras import layers from tensorflow.python.layers import base as layers_base from tensorflow.python.layers import core as layers_core from tensorflow.python.ops import array_ops @@ -72,77 +73,6 @@ class AttentionMechanism(object): raise NotImplementedError -def _prepare_memory(memory, memory_sequence_length, check_inner_dims_defined): - """Convert to tensor and possibly mask `memory`. - - Args: - memory: `Tensor`, shaped `[batch_size, max_time, ...]`. - memory_sequence_length: `int32` `Tensor`, shaped `[batch_size]`. - check_inner_dims_defined: Python boolean. If `True`, the `memory` - argument's shape is checked to ensure all but the two outermost - dimensions are fully defined. - - Returns: - A (possibly masked), checked, new `memory`. - - Raises: - ValueError: If `check_inner_dims_defined` is `True` and not - `memory.shape[2:].is_fully_defined()`. - """ - memory = nest.map_structure( - lambda m: ops.convert_to_tensor(m, name="memory"), memory) - if memory_sequence_length is not None: - memory_sequence_length = ops.convert_to_tensor( - memory_sequence_length, name="memory_sequence_length") - if check_inner_dims_defined: - def _check_dims(m): - if not m.get_shape()[2:].is_fully_defined(): - raise ValueError("Expected memory %s to have fully defined inner dims, " - "but saw shape: %s" % (m.name, m.get_shape())) - nest.map_structure(_check_dims, memory) - if memory_sequence_length is None: - seq_len_mask = None - else: - seq_len_mask = array_ops.sequence_mask( - memory_sequence_length, - maxlen=array_ops.shape(nest.flatten(memory)[0])[1], - dtype=nest.flatten(memory)[0].dtype) - seq_len_batch_size = ( - tensor_shape.dimension_value(memory_sequence_length.shape[0]) - or array_ops.shape(memory_sequence_length)[0]) - def _maybe_mask(m, seq_len_mask): - rank = m.get_shape().ndims - rank = rank if rank is not None else array_ops.rank(m) - extra_ones = array_ops.ones(rank - 2, dtype=dtypes.int32) - 
m_batch_size = tensor_shape.dimension_value( - m.shape[0]) or array_ops.shape(m)[0] - if memory_sequence_length is not None: - message = ("memory_sequence_length and memory tensor batch sizes do not " - "match.") - with ops.control_dependencies([ - check_ops.assert_equal( - seq_len_batch_size, m_batch_size, message=message)]): - seq_len_mask = array_ops.reshape( - seq_len_mask, - array_ops.concat((array_ops.shape(seq_len_mask), extra_ones), 0)) - return m * seq_len_mask - else: - return m - return nest.map_structure(lambda m: _maybe_mask(m, seq_len_mask), memory) - - -def _maybe_mask_score(score, memory_sequence_length, score_mask_value): - if memory_sequence_length is None: - return score - message = ("All values in memory_sequence_length must greater than zero.") - with ops.control_dependencies( - [check_ops.assert_positive(memory_sequence_length, message=message)]): - score_mask = array_ops.sequence_mask( - memory_sequence_length, maxlen=array_ops.shape(score)[1]) - score_mask_values = score_mask_value * array_ops.ones_like(score) - return array_ops.where(score_mask, score, score_mask_values) - - class _BaseAttentionMechanism(AttentionMechanism): """A base AttentionMechanism class providing common functionality. 
@@ -205,12 +135,14 @@ class _BaseAttentionMechanism(AttentionMechanism): self._memory_layer.dtype).as_numpy_dtype(-np.inf) self._probability_fn = lambda score, prev: ( # pylint:disable=g-long-lambda probability_fn( - _maybe_mask_score(score, memory_sequence_length, score_mask_value), + _maybe_mask_score(score, + memory_sequence_length=memory_sequence_length, + score_mask_value=score_mask_value), prev)) with ops.name_scope( name, "BaseAttentionMechanismInit", nest.flatten(memory)): self._values = _prepare_memory( - memory, memory_sequence_length, + memory, memory_sequence_length=memory_sequence_length, check_inner_dims_defined=check_inner_dims_defined) self._keys = ( self.memory_layer(self._values) if self.memory_layer # pylint: disable=not-callable @@ -286,6 +218,207 @@ class _BaseAttentionMechanism(AttentionMechanism): return self.initial_alignments(batch_size, dtype) +class _BaseAttentionMechanismV2(AttentionMechanism, layers.Layer): + """A base AttentionMechanism class providing common functionality. + + Common functionality includes: + 1. Storing the query and memory layers. + 2. Preprocessing and storing the memory. + + Note that this layer only support Keras functional API since it takes multiple + input tensors, which is not available in sequential model. + """ + + def __init__(self, + probability_fn, + query_layer=None, + memory_layer=None, + **kwargs): + """Construct base AttentionMechanism class. + + Args: + probability_fn: A `callable`. Converts the score and previous alignments + to probabilities. Its signature should be: + `probabilities = probability_fn(score, state)`. + query_layer: (optional): Instance of `tf.keras.Layer`. The layer's depth + must match the depth of `memory_layer`. If `query_layer` is not + provided, the shape of `query` must match that of `memory_layer`. + memory_layer: (optional): Instance of `tf.keras.Layer`. The layer's + depth must match the depth of `query_layer`. 
+ If `memory_layer` is not provided, the shape of `memory` must match + that of `query_layer`. + **kwargs: Dictionary that contains other common arguments for layer + creation. + """ + if (query_layer is not None + and not isinstance(query_layer, layers.Layer)): + raise TypeError( + "query_layer is not a Layer: %s" % type(query_layer).__name__) + if (memory_layer is not None + and not isinstance(memory_layer, layers.Layer)): + raise TypeError( + "memory_layer is not a Layer: %s" % type(memory_layer).__name__) + self.query_layer = query_layer + self.memory_layer = memory_layer + if self.memory_layer is not None and "dtype" not in kwargs: + kwargs["dtype"] = self.memory_layer.dtype + super(_BaseAttentionMechanismV2, self).__init__(**kwargs) + if not callable(probability_fn): + raise TypeError("probability_fn must be callable, saw type: %s" % + type(probability_fn).__name__) + self.probability_fn = probability_fn + + self.keys = None + self.values = None + self.batch_size = None + self._memory_initialized = False + self._check_inner_dims_defined = True + + def build(self, input_shape): + # The layer suppose to take 3 inputs, [query, state, memory]. + query_input_shape, _, memory_input_shape = input_shape + if self.query_layer is not None: + self.query_layer.build(query_input_shape) + if self.memory_layer is not None: + self.memory_layer.build(memory_input_shape) + # dtype of the layer is known at this moment, create the score_mask_value if + # needed. + self.score_mask_value = dtypes.as_dtype(self.dtype).as_numpy_dtype(-np.inf) + self.built = True + + def _setup_memory(self, memory, memory_mask=None): + """Pre-process the memory before actually query the memory. + + This should only be called once at the first invocation of call(). + + Args: + memory: The memory to query; usually the output of an RNN encoder. This + tensor should be shaped `[batch_size, max_time, ...]`. + memory_mask: The boolean tensor with shape `[batch_size, max_time]`. 
For + any value equal to False, the corresponding value in memory should be + ignored. + """ + if self._memory_initialized: + raise ValueError("The memory for the attention has already been setup.") + with ops.name_scope( + self.name, "BaseAttentionMechanismInit", nest.flatten(memory)): + self.values = _prepare_memory( + memory, memory_mask=memory_mask, + check_inner_dims_defined=self._check_inner_dims_defined) + if self.memory_layer is not None: + self.keys = self.memory_layer(self.values) + else: + self.keys = self.values + self.batch_size = ( + tensor_shape.dimension_value(self.keys.shape[0]) or + array_ops.shape(self.keys)[0]) + self._alignments_size = (tensor_shape.dimension_value(self.keys.shape[1]) + or array_ops.shape(self.keys)[1]) + if memory_mask is not None: + self.probability_fn = lambda score, prev: ( # pylint:disable=g-long-lambda + self.probability_fn(_maybe_mask_score( + score, self.score_mask_value, memory_mask=memory_mask), prev)) + self._memory_initialized = True + + def call(self, inputs, mask=None, **kwargs): + """Base method to calculate the attention score. + + Args: + inputs: a list of tensor that contains `query`, `state`, and `memory`. + `query` is the tensor of dtype matching `memory` and shape + `[batch_size, query_depth]`. + `state` is the tensor of dtype matching `memory` and shape + `[batch_size, alignments_size]`. (`alignments_size` is memory's + `max_time`). + `memory` is the memory to query; usually the output of an RNN encoder. + This tensor should be shaped `[batch_size, max_time, feature]`. + mask: optional bool tensor with shape `[batch, max_time]` for the mask of + memory. If it is not None, the corresponding item of the memory should + be filtered out during calculation. + **kwargs: Dict, other keyword arguments for the call method. 
+ """ + query, state, memory, memory_mask = self._process_inputs(inputs, mask) + if not self._memory_initialized: + self._setup_memory(memory, memory_mask=memory_mask) + return self.calculate_attention(query, state) + + def calculate_attention(self, query, state): + raise NotImplementedError( + "calculate_attention need to be implemented by subclasses.") + + def get_config(self): + config = {} + # Since the probability_fn is likely to be a wrapped function, the child + # class should preserve the original function and how its wrapped. + + if self.query_layer is not None: + config["query_layer"] = { + "class_name": self.query_layer.__class__.__name__, + "config": self.query_layer.get_config(), + } + if self.memory_layer is not None: + config["memory_layer"] = { + "class_name": self.memory_layer.__class__.__name__, + "config": self.memory_layer.get_config(), + } + base_config = super(_BaseAttentionMechanismV2, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + def _process_inputs(self, inputs, mask): + if len(inputs) != 3: + raise ValueError( + "Expect to have 3 inputs for attention, got %d" % len(inputs)) + query, state, memory = inputs + return query, state, memory, mask + + def _process_probability_fn(self, func_name): + """Helper method to retrieve the probably function by string input.""" + valid_probability_fns = { + "softmax": nn_ops.softmax, + "hardmax": hardmax, + } + if func_name not in valid_probability_fns.keys(): + raise ValueError("Invalid probability function: %s, options are %s" % + (func_name, valid_probability_fns.keys())) + return valid_probability_fns[func_name] + + @classmethod + def deserialize_inner_layer_from_config(cls, config, custom_objects): + """Helper method that reconstruct the query and memory from the config. + + In the get_config() method, the query and memory layer configs are + serialized into dict for persistence, this method perform the reverse action + to reconstruct the layer from the config. 
+ + Args: + config: dict, the configs that will be used to reconstruct the object. + custom_objects: dict mapping class names (or function names) of custom + (non-Keras) objects to class/functions. + Returns: + config: dict, the config with layer instance created, which is ready to be + used as init parameters. + """ + # Reconstruct the query and memory layer for parent class. + from tensorflow.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top + # Instead of updating the input, create a copy and use that. + config = config.copy() + query_layer_config = config.pop("query_layer", None) + if query_layer_config: + query_layer = deserialize_layer(query_layer_config, + custom_objects=custom_objects) + config["query_layer"] = query_layer + memory_layer_config = config.pop("memory_layer", None) + if memory_layer_config: + memory_layer = deserialize_layer(memory_layer_config, + custom_objects=custom_objects) + config["memory_layer"] = memory_layer + return config + + @property + def alignments_size(self): + return self._alignments_size + + def _luong_score(query, keys, scale): """Implements Luong-style (multiplicative) scoring function. @@ -304,7 +437,7 @@ def _luong_score(query, keys, scale): Args: query: Tensor, shape `[batch_size, num_units]` to compare to keys. keys: Processed memory, shape `[batch_size, max_time, num_units]`. - scale: Whether to apply a scale to the score function. + scale: the optional tensor to scale the attention score. Returns: A `[batch_size, max_time]` tensor of unnormalized score values. @@ -320,7 +453,6 @@ def _luong_score(query, keys, scale): "Query (%s) has units: %s. Keys (%s) have units: %s. " "Perhaps you need to set num_units to the keys' dimension (%s)?" % (query, depth, keys, key_units, key_units)) - dtype = query.dtype # Reshape from [batch_size, depth] to [batch_size, 1, depth] # for matmul. 
@@ -338,12 +470,8 @@ def _luong_score(query, keys, scale): score = math_ops.matmul(query, keys, transpose_b=True) score = array_ops.squeeze(score, [1]) - if scale: - # Scalar used in weight scaling - g = variable_scope.get_variable( - "attention_g", dtype=dtype, - initializer=init_ops.ones_initializer, shape=()) - score = g * score + if scale is not None: + score = scale * score return score @@ -354,8 +482,8 @@ class LuongAttention(_BaseAttentionMechanism): as described in: Minh-Thang Luong, Hieu Pham, Christopher D. Manning. - "Effective Approaches to Attention-based Neural Machine Translation." - EMNLP 2015. https://arxiv.org/abs/1508.04025 + [Effective Approaches to Attention-based Neural Machine Translation. + EMNLP 2015.](https://arxiv.org/abs/1508.04025) The second is the scaled form inspired partly by the normalized form of Bahdanau attention. @@ -429,13 +557,125 @@ class LuongAttention(_BaseAttentionMechanism): `max_time`). """ with variable_scope.variable_scope(None, "luong_attention", [query]): - score = _luong_score(query, self._keys, self._scale) + attention_g = None + if self._scale: + attention_g = variable_scope.get_variable( + "attention_g", dtype=query.dtype, + initializer=init_ops.ones_initializer, shape=()) + score = _luong_score(query, self._keys, attention_g) alignments = self._probability_fn(score, state) next_state = alignments return alignments, next_state -def _bahdanau_score(processed_query, keys, normalize): +class LuongAttentionV2(_BaseAttentionMechanismV2): + """Implements Luong-style (multiplicative) attention scoring. + + This attention has two forms. The first is standard Luong attention, + as described in: + + Minh-Thang Luong, Hieu Pham, Christopher D. Manning. + [Effective Approaches to Attention-based Neural Machine Translation. + EMNLP 2015.](https://arxiv.org/abs/1508.04025) + + The second is the scaled form inspired partly by the normalized form of + Bahdanau attention. 
+ + To enable the second form, construct the object with parameter + `scale=True`. + """ + + def __init__(self, + units, + scale=False, + probability_fn="softmax", + dtype=None, + name="LuongAttention", + **kwargs): + """Construct the AttentionMechanism mechanism. + + Args: + units: The depth of the attention mechanism. + scale: Python boolean. Whether to scale the energy term. + probability_fn: (optional) string, the name of function to convert the + attention score to probabilities. The default is `softmax` which is + `tf.nn.softmax`. Other options is `hardmax`, which is hardmax() within + this module. Any other value will result intovalidation error. Default + to use `softmax`. + dtype: The data type for the memory layer of the attention mechanism. + name: Name to use when creating ops. + **kwargs: Dictionary that contains other common arguments for layer + creation. + """ + # For LuongAttention, we only transform the memory layer; thus + # num_units **must** match expected the query depth. + self.probability_fn_name = probability_fn + probability_fn = self._process_probability_fn(self.probability_fn_name) + wrapped_probability_fn = lambda score, _: probability_fn(score) + if dtype is None: + dtype = dtypes.float32 + memory_layer = kwargs.pop("memory_layer", None) + if not memory_layer: + memory_layer = layers.Dense( + units, name="memory_layer", use_bias=False, dtype=dtype) + super(LuongAttentionV2, self).__init__( + query_layer=None, + memory_layer=memory_layer, + probability_fn=wrapped_probability_fn, + name=name, + dtype=dtype, + **kwargs) + self.units = units + self.scale = scale + + def build(self, input_shape): + super(LuongAttentionV2, self).build(input_shape) + if self.scale: + self.scale_weight = self.add_weight( + "attention_g", initializer=init_ops.ones_initializer, shape=()) + else: + self.scale_weight = None + self.built = True + + def calculate_attention(self, query, state): + """Score the query based on the keys and values. 
+ + Args: + query: Tensor of dtype matching `self.values` and shape + `[batch_size, query_depth]`. + state: Tensor of dtype matching `self.values` and shape + `[batch_size, alignments_size]` + (`alignments_size` is memory's `max_time`). + + Returns: + alignments: Tensor of dtype matching `self.values` and shape + `[batch_size, alignments_size]` (`alignments_size` is memory's + `max_time`). + next_state: Same as the alignments. + """ + score = _luong_score(query, self.keys, self.scale_weight) + alignments = self.probability_fn(score, state) + next_state = alignments + return alignments, next_state + + def get_config(self): + config = { + "units": self.units, + "scale": self.scale, + "probability_fn": self.probability_fn_name, + } + base_config = super(LuongAttentionV2, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + @classmethod + def from_config(cls, config, custom_objects=None): + config = _BaseAttentionMechanismV2.deserialize_inner_layer_from_config( + config, custom_objects=custom_objects) + return cls(**config) + + +def _bahdanau_score(processed_query, keys, attention_v, + attention_g=None, attention_b=None): """Implements Bahdanau-style (additive) scoring function. This attention has two forms. The first is Bhandanau attention, @@ -453,41 +693,28 @@ def _bahdanau_score(processed_query, keys, normalize): Training of Deep Neural Networks." https://arxiv.org/abs/1602.07868 - To enable the second form, set `normalize=True`. + To enable the second form, set please pass in attention_g and attention_b. Args: processed_query: Tensor, shape `[batch_size, num_units]` to compare to keys. keys: Processed memory, shape `[batch_size, max_time, num_units]`. - normalize: Whether to normalize the score function. + attention_v: Tensor, shape `[num_units]`. + attention_g: Optional scalar tensor for normalization. + attention_b: Optional tensor with shape `[num_units]` for normalization. 
Returns: A `[batch_size, max_time]` tensor of unnormalized score values. """ - dtype = processed_query.dtype - # Get the number of hidden units from the trailing dimension of keys - num_units = tensor_shape.dimension_value( - keys.shape[2]) or array_ops.shape(keys)[2] # Reshape from [batch_size, ...] to [batch_size, 1, ...] for broadcasting. processed_query = array_ops.expand_dims(processed_query, 1) - v = variable_scope.get_variable( - "attention_v", [num_units], dtype=dtype) - if normalize: - # Scalar used in weight normalization - g = variable_scope.get_variable( - "attention_g", dtype=dtype, - initializer=init_ops.constant_initializer(math.sqrt((1. / num_units))), - shape=()) - # Bias added prior to the nonlinearity - b = variable_scope.get_variable( - "attention_b", [num_units], dtype=dtype, - initializer=init_ops.zeros_initializer()) - # normed_v = g * v / ||v|| - normed_v = g * v * math_ops.rsqrt( - math_ops.reduce_sum(math_ops.square(v))) + if attention_g is not None and attention_b is not None: + normed_v = attention_g * attention_v * math_ops.rsqrt( + math_ops.reduce_sum(math_ops.square(attention_v))) return math_ops.reduce_sum( - normed_v * math_ops.tanh(keys + processed_query + b), [2]) + normed_v * math_ops.tanh(keys + processed_query + attention_b), [2]) else: - return math_ops.reduce_sum(v * math_ops.tanh(keys + processed_query), [2]) + return math_ops.reduce_sum( + attention_v * math_ops.tanh(keys + processed_query), [2]) class BahdanauAttention(_BaseAttentionMechanism): @@ -578,12 +805,152 @@ class BahdanauAttention(_BaseAttentionMechanism): """ with variable_scope.variable_scope(None, "bahdanau_attention", [query]): processed_query = self.query_layer(query) if self.query_layer else query - score = _bahdanau_score(processed_query, self._keys, self._normalize) + attention_v = variable_scope.get_variable( + "attention_v", [self._num_units], dtype=query.dtype) + if not self._normalize: + attention_g = None + attention_b = None + else: + attention_g = 
variable_scope.get_variable( + "attention_g", dtype=query.dtype, + initializer=init_ops.constant_initializer( + math.sqrt((1. / self._num_units))), + shape=()) + attention_b = variable_scope.get_variable( + "attention_b", [self._num_units], dtype=query.dtype, + initializer=init_ops.zeros_initializer()) + + score = _bahdanau_score(processed_query, self._keys, attention_v, + attention_g=attention_g, attention_b=attention_b) alignments = self._probability_fn(score, state) next_state = alignments return alignments, next_state +class BahdanauAttentionV2(_BaseAttentionMechanismV2): + """Implements Bahdanau-style (additive) attention. + + This attention has two forms. The first is Bahdanau attention, + as described in: + + Dzmitry Bahdanau, Kyunghyun Cho, Yoshua Bengio. + "Neural Machine Translation by Jointly Learning to Align and Translate." + ICLR 2015. https://arxiv.org/abs/1409.0473 + + The second is the normalized form. This form is inspired by the + weight normalization article: + + Tim Salimans, Diederik P. Kingma. + "Weight Normalization: A Simple Reparameterization to Accelerate + Training of Deep Neural Networks." + https://arxiv.org/abs/1602.07868 + + To enable the second form, construct the object with parameter + `normalize=True`. + """ + + def __init__(self, + units, + normalize=False, + probability_fn="softmax", + dtype=None, + name="BahdanauAttention", + **kwargs): + """Construct the Attention mechanism. + + Args: + units: The depth of the query mechanism. + normalize: Python boolean. Whether to normalize the energy term. + probability_fn: (optional) string, the name of function to convert the + attention score to probabilities. The default is `softmax` which is + `tf.nn.softmax`. Other options is `hardmax`, which is hardmax() within + this module. Any other value will result into validation error. Default + to use `softmax`. + dtype: The data type for the query and memory layers of the attention + mechanism. + name: Name to use when creating ops. 
+ **kwargs: Dictionary that contains other common arguments for layer + creation. + """ + self.probability_fn_name = probability_fn + probability_fn = self._process_probability_fn(self.probability_fn_name) + wrapped_probability_fn = lambda score, _: probability_fn(score) + if dtype is None: + dtype = dtypes.float32 + query_layer = kwargs.pop("query_layer", None) + if not query_layer: + query_layer = layers.Dense( + units, name="query_layer", use_bias=False, dtype=dtype) + memory_layer = kwargs.pop("memory_layer", None) + if not memory_layer: + memory_layer = layers.Dense( + units, name="memory_layer", use_bias=False, dtype=dtype) + super(BahdanauAttentionV2, self).__init__( + query_layer=query_layer, + memory_layer=memory_layer, + probability_fn=wrapped_probability_fn, + name=name, + dtype=dtype, + **kwargs) + self.units = units + self.normalize = normalize + + def build(self, input_shape): + super(BahdanauAttentionV2, self).build(input_shape) + self.attention_v = self.add_weight( + "attention_v", [self.units], dtype=self.dtype) + if self.normalize: + self.attention_g = self.add_weight( + "attention_g", initializer=init_ops.constant_initializer( + math.sqrt((1. / self.units))), shape=()) + self.attention_b = self.add_weight( + "attention_b", shape=[self.units], + initializer=init_ops.zeros_initializer()) + else: + self.attention_g = None + self.attention_b = None + self.built = True + + def calculate_attention(self, query, state): + """Score the query based on the keys and values. + + Args: + query: Tensor of dtype matching `self.values` and shape + `[batch_size, query_depth]`. + state: Tensor of dtype matching `self.values` and shape + `[batch_size, alignments_size]` + (`alignments_size` is memory's `max_time`). + + Returns: + alignments: Tensor of dtype matching `self.values` and shape + `[batch_size, alignments_size]` (`alignments_size` is memory's + `max_time`). + next_state: same as alignments. 
+ """ + processed_query = self.query_layer(query) if self.query_layer else query + score = _bahdanau_score(processed_query, self.keys, self.attention_v, + attention_g=self.attention_g, + attention_b=self.attention_b) + alignments = self.probability_fn(score, state) + next_state = alignments + return alignments, next_state + + def get_config(self): + config = { + "units": self.units, + "normalize": self.normalize, + "probability_fn": self.probability_fn_name, + } + base_config = super(BahdanauAttentionV2, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + @classmethod + def from_config(cls, config, custom_objects=None): + config = _BaseAttentionMechanismV2.deserialize_inner_layer_from_config( + config, custom_objects=custom_objects) + return cls(**config) + + def safe_cumprod(x, *args, **kwargs): """Computes cumprod of x in logspace using cumsum to avoid underflow. @@ -766,6 +1133,34 @@ class _BaseMonotonicAttentionMechanism(_BaseAttentionMechanism): dtype=dtype) +class _BaseMonotonicAttentionMechanismV2(_BaseAttentionMechanismV2): + """Base attention mechanism for monotonic attention. + + Simply overrides the initial_alignments function to provide a dirac + distribution, which is needed in order for the monotonic attention + distributions to have the correct behavior. + """ + + def initial_alignments(self, batch_size, dtype): + """Creates the initial alignment values for the monotonic attentions. + + Initializes to dirac distributions, i.e. [1, 0, 0, ...memory length..., 0] + for all entries in the batch. + + Args: + batch_size: `int32` scalar, the batch_size. + dtype: The `dtype`. + + Returns: + A `dtype` tensor shaped `[batch_size, alignments_size]` + (`alignments_size` is the values' `max_time`). 
+ """ + max_time = self._alignments_size + return array_ops.one_hot( + array_ops.zeros((batch_size,), dtype=dtypes.int32), max_time, + dtype=dtype) + + class BahdanauMonotonicAttention(_BaseMonotonicAttentionMechanism): """Monotonic attention mechanism with Bahadanau-style energy function. @@ -860,7 +1255,22 @@ class BahdanauMonotonicAttention(_BaseMonotonicAttentionMechanism): with variable_scope.variable_scope( None, "bahdanau_monotonic_attention", [query]): processed_query = self.query_layer(query) if self.query_layer else query - score = _bahdanau_score(processed_query, self._keys, self._normalize) + attention_v = variable_scope.get_variable( + "attention_v", [self._num_units], dtype=query.dtype) + if not self._normalize: + attention_g = None + attention_b = None + else: + attention_g = variable_scope.get_variable( + "attention_g", dtype=query.dtype, + initializer=init_ops.constant_initializer( + math.sqrt((1. / self._num_units))), + shape=()) + attention_b = variable_scope.get_variable( + "attention_b", [self._num_units], dtype=query.dtype, + initializer=init_ops.zeros_initializer()) + score = _bahdanau_score(processed_query, self._keys, attention_v, + attention_g=attention_g, attention_b=attention_b) score_bias = variable_scope.get_variable( "attention_score_bias", dtype=processed_query.dtype, initializer=self._score_bias_init) @@ -870,6 +1280,146 @@ class BahdanauMonotonicAttention(_BaseMonotonicAttentionMechanism): return alignments, next_state +class BahdanauMonotonicAttentionV2(_BaseMonotonicAttentionMechanismV2): + """Monotonic attention mechanism with Bahadanau-style energy function. + + This type of attention enforces a monotonic constraint on the attention + distributions; that is once the model attends to a given point in the memory + it can't attend to any prior points at subsequence output timesteps. It + achieves this by using the _monotonic_probability_fn instead of softmax to + construct its attention distributions. 
Since the attention scores are passed + through a sigmoid, a learnable scalar bias parameter is applied after the + score function and before the sigmoid. Otherwise, it is equivalent to + BahdanauAttention. This approach is proposed in + + Colin Raffel, Minh-Thang Luong, Peter J. Liu, Ron J. Weiss, Douglas Eck, + "Online and Linear-Time Attention by Enforcing Monotonic Alignments." + ICML 2017. https://arxiv.org/abs/1704.00784 + """ + + def __init__(self, + units, + normalize=False, + sigmoid_noise=0., + sigmoid_noise_seed=None, + score_bias_init=0., + mode="parallel", + dtype=None, + name="BahdanauMonotonicAttention", + **kwargs): + """Construct the Attention mechanism. + + Args: + units: The depth of the query mechanism. + normalize: Python boolean. Whether to normalize the energy term. + sigmoid_noise: Standard deviation of pre-sigmoid noise. See the docstring + for `_monotonic_probability_fn` for more information. + sigmoid_noise_seed: (optional) Random seed for pre-sigmoid noise. + score_bias_init: Initial value for score bias scalar. It's recommended to + initialize this to a negative value when the length of the memory is + large. + mode: How to compute the attention distribution. Must be one of + 'recursive', 'parallel', or 'hard'. See the docstring for + `tf.contrib.seq2seq.monotonic_attention` for more information. + dtype: The data type for the query and memory layers of the attention + mechanism. + name: Name to use when creating ops. + **kwargs: Dictionary that contains other common arguments for layer + creation. 
+ """ + # Set up the monotonic probability fn with supplied parameters + if dtype is None: + dtype = dtypes.float32 + wrapped_probability_fn = functools.partial( + _monotonic_probability_fn, sigmoid_noise=sigmoid_noise, mode=mode, + seed=sigmoid_noise_seed) + query_layer = kwargs.pop("query_layer", None) + if not query_layer: + query_layer = layers.Dense( + units, name="query_layer", use_bias=False, dtype=dtype) + memory_layer = kwargs.pop("memory_layer", None) + if not memory_layer: + memory_layer = layers.Dense( + units, name="memory_layer", use_bias=False, dtype=dtype) + super(BahdanauMonotonicAttentionV2, self).__init__( + query_layer=query_layer, + memory_layer=memory_layer, + probability_fn=wrapped_probability_fn, + name=name, + dtype=dtype, + **kwargs) + self.units = units + self.normalize = normalize + self.sigmoid_noise = sigmoid_noise + self.sigmoid_noise_seed = sigmoid_noise_seed + self.score_bias_init = score_bias_init + self.mode = mode + + def build(self, input_shape): + super(BahdanauMonotonicAttentionV2, self).build(input_shape) + self.attention_v = self.add_weight( + "attention_v", [self.units], dtype=self.dtype) + self.attention_score_bias = self.add_weight( + "attention_score_bias", shape=(), dtype=self.dtype, + initializer=init_ops.constant_initializer( + self.score_bias_init, dtype=self.dtype)) + if not self.normalize: + self.attention_g = None + self.attention_b = None + else: + self.attention_g = self.add_weight( + "attention_g", dtype=self.dtype, + initializer=init_ops.constant_initializer( + math.sqrt((1. / self.units))), + shape=()) + self.attention_b = self.add_weight( + "attention_b", [self.units], dtype=self.dtype, + initializer=init_ops.zeros_initializer()) + self.built = True + + def calculate_attention(self, query, state): + """Score the query based on the keys and values. + + Args: + query: Tensor of dtype matching `self.values` and shape + `[batch_size, query_depth]`. 
+ state: Tensor of dtype matching `self.values` and shape + `[batch_size, alignments_size]` + (`alignments_size` is memory's `max_time`). + + Returns: + alignments: Tensor of dtype matching `self.values` and shape + `[batch_size, alignments_size]` (`alignments_size` is memory's + `max_time`). + """ + processed_query = self.query_layer(query) if self.query_layer else query + score = _bahdanau_score(processed_query, self.keys, self.attention_v, + attention_g=self.attention_g, + attention_b=self.attention_b) + score += self.attention_score_bias + alignments = self.probability_fn(score, state) + next_state = alignments + return alignments, next_state + + def get_config(self): + config = { + "units": self.units, + "normalize": self.normalize, + "sigmoid_noise": self.sigmoid_noise, + "sigmoid_noise_seed": self.sigmoid_noise_seed, + "score_bias_init": self.score_bias_init, + "mode": self.mode, + } + base_config = super(BahdanauMonotonicAttentionV2, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + @classmethod + def from_config(cls, config, custom_objects=None): + config = _BaseAttentionMechanismV2.deserialize_inner_layer_from_config( + config, custom_objects=custom_objects) + return cls(**config) + + class LuongMonotonicAttention(_BaseMonotonicAttentionMechanism): """Monotonic attention mechanism with Luong-style energy function. 
@@ -960,7 +1510,12 @@ class LuongMonotonicAttention(_BaseMonotonicAttentionMechanism): """ with variable_scope.variable_scope(None, "luong_monotonic_attention", [query]): - score = _luong_score(query, self._keys, self._scale) + attention_g = None + if self._scale: + attention_g = variable_scope.get_variable( + "attention_g", dtype=query.dtype, + initializer=init_ops.ones_initializer, shape=()) + score = _luong_score(query, self._keys, attention_g) score_bias = variable_scope.get_variable( "attention_score_bias", dtype=query.dtype, initializer=self._score_bias_init) @@ -970,6 +1525,129 @@ class LuongMonotonicAttention(_BaseMonotonicAttentionMechanism): return alignments, next_state +class LuongMonotonicAttentionV2(_BaseMonotonicAttentionMechanismV2): + """Monotonic attention mechanism with Luong-style energy function. + + This type of attention enforces a monotonic constraint on the attention + distributions; that is once the model attends to a given point in the memory + it can't attend to any prior points at subsequence output timesteps. It + achieves this by using the _monotonic_probability_fn instead of softmax to + construct its attention distributions. Otherwise, it is equivalent to + LuongAttention. This approach is proposed in + + [Colin Raffel, Minh-Thang Luong, Peter J. Liu, Ron J. Weiss, Douglas Eck, + "Online and Linear-Time Attention by Enforcing Monotonic Alignments." + ICML 2017.](https://arxiv.org/abs/1704.00784) + """ + + def __init__(self, + units, + scale=False, + sigmoid_noise=0., + sigmoid_noise_seed=None, + score_bias_init=0., + mode="parallel", + dtype=None, + name="LuongMonotonicAttention", + **kwargs): + """Construct the Attention mechanism. + + Args: + units: The depth of the query mechanism. + scale: Python boolean. Whether to scale the energy term. + sigmoid_noise: Standard deviation of pre-sigmoid noise. See the docstring + for `_monotonic_probability_fn` for more information. 
+ sigmoid_noise_seed: (optional) Random seed for pre-sigmoid noise. + score_bias_init: Initial value for score bias scalar. It's recommended to + initialize this to a negative value when the length of the memory is + large. + mode: How to compute the attention distribution. Must be one of + 'recursive', 'parallel', or 'hard'. See the docstring for + `tf.contrib.seq2seq.monotonic_attention` for more information. + dtype: The data type for the query and memory layers of the attention + mechanism. + name: Name to use when creating ops. + **kwargs: Dictionary that contains other common arguments for layer + creation. + """ + # Set up the monotonic probability fn with supplied parameters + if dtype is None: + dtype = dtypes.float32 + wrapped_probability_fn = functools.partial( + _monotonic_probability_fn, sigmoid_noise=sigmoid_noise, mode=mode, + seed=sigmoid_noise_seed) + memory_layer = kwargs.pop("memory_layer", None) + if not memory_layer: + memory_layer = layers.Dense( + units, name="memory_layer", use_bias=False, dtype=dtype) + super(LuongMonotonicAttentionV2, self).__init__( + query_layer=None, + memory_layer=memory_layer, + probability_fn=wrapped_probability_fn, + name=name, + dtype=dtype, + **kwargs) + self.units = units + self.scale = scale + self.sigmoid_noise = sigmoid_noise + self.sigmoid_noise_seed = sigmoid_noise_seed + self.score_bias_init = score_bias_init + self.mode = mode + + def build(self, input_shape): + super(LuongMonotonicAttentionV2, self).build(input_shape) + if self.scale: + self.attention_g = self.add_weight( + "attention_g", initializer=init_ops.ones_initializer, shape=()) + else: + self.attention_g = None + self.attention_score_bias = self.add_weight( + "attention_score_bias", shape=(), + initializer=init_ops.constant_initializer( + self.score_bias_init, dtype=self.dtype)) + self.built = True + + def calculate_attention(self, query, state): + """Score the query based on the keys and values. 
+ + Args: + query: Tensor of dtype matching `self.values` and shape + `[batch_size, query_depth]`. + state: Tensor of dtype matching `self.values` and shape + `[batch_size, alignments_size]` + (`alignments_size` is memory's `max_time`). + + Returns: + alignments: Tensor of dtype matching `self.values` and shape + `[batch_size, alignments_size]` (`alignments_size` is memory's + `max_time`). + next_state: Same as alignments + """ + score = _luong_score(query, self.keys, self.attention_g) + score += self.attention_score_bias + alignments = self.probability_fn(score, state) + next_state = alignments + return alignments, next_state + + def get_config(self): + config = { + "units": self.units, + "scale": self.scale, + "sigmoid_noise": self.sigmoid_noise, + "sigmoid_noise_seed": self.sigmoid_noise_seed, + "score_bias_init": self.score_bias_init, + "mode": self.mode, + } + base_config = super(LuongMonotonicAttentionV2, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + @classmethod + def from_config(cls, config, custom_objects=None): + config = _BaseAttentionMechanismV2.deserialize_inner_layer_from_config( + config, custom_objects=custom_objects) + return cls(**config) + + class AttentionWrapperState( collections.namedtuple("AttentionWrapperState", ("cell_state", "attention", "time", "alignments", @@ -1026,6 +1704,97 @@ class AttentionWrapperState( super(AttentionWrapperState, self)._replace(**kwargs)) +def _prepare_memory(memory, memory_sequence_length=None, memory_mask=None, + check_inner_dims_defined=True): + """Convert to tensor and possibly mask `memory`. + + Args: + memory: `Tensor`, shaped `[batch_size, max_time, ...]`. + memory_sequence_length: `int32` `Tensor`, shaped `[batch_size]`. + memory_mask: `boolean` tensor with shape [batch_size, max_time]. The memory + should be skipped when the corresponding mask is False. + check_inner_dims_defined: Python boolean. 
If `True`, the `memory` + argument's shape is checked to ensure all but the two outermost + dimensions are fully defined. + + Returns: + A (possibly masked), checked, new `memory`. + + Raises: + ValueError: If `check_inner_dims_defined` is `True` and not + `memory.shape[2:].is_fully_defined()`. + """ + memory = nest.map_structure( + lambda m: ops.convert_to_tensor(m, name="memory"), memory) + if memory_sequence_length is not None and memory_mask is not None: + raise ValueError("memory_sequence_length and memory_mask can't be provided " + "at same time.") + if memory_sequence_length is not None: + memory_sequence_length = ops.convert_to_tensor( + memory_sequence_length, name="memory_sequence_length") + if check_inner_dims_defined: + def _check_dims(m): + if not m.get_shape()[2:].is_fully_defined(): + raise ValueError("Expected memory %s to have fully defined inner dims, " + "but saw shape: %s" % (m.name, m.get_shape())) + nest.map_structure(_check_dims, memory) + if memory_sequence_length is None and memory_mask is None: + seq_len_mask = None + seq_len_batch_size = None + elif memory_sequence_length is not None: + seq_len_mask = array_ops.sequence_mask( + memory_sequence_length, + maxlen=array_ops.shape(nest.flatten(memory)[0])[1], + dtype=nest.flatten(memory)[0].dtype) + seq_len_batch_size = ( + tensor_shape.dimension_value(memory_sequence_length.shape[0]) + or array_ops.shape(memory_sequence_length)[0]) + else: + # For memory_mask is not None + seq_len_mask = memory_mask + seq_len_batch_size = ( + tensor_shape.dimension_value(memory_mask.shape[0]) + or array_ops.shape(memory_mask)[0]) + def _maybe_mask(m, seq_len_mask): + """Mask the memory based on the memory mask.""" + rank = m.get_shape().ndims + rank = rank if rank is not None else array_ops.rank(m) + extra_ones = array_ops.ones(rank - 2, dtype=dtypes.int32) + m_batch_size = tensor_shape.dimension_value( + m.shape[0]) or array_ops.shape(m)[0] + if seq_len_batch_size is not None: + message = 
("memory_sequence_length and memory tensor batch sizes do not " + "match.") + with ops.control_dependencies([ + check_ops.assert_equal( + seq_len_batch_size, m_batch_size, message=message)]): + seq_len_mask = array_ops.reshape( + seq_len_mask, + array_ops.concat((array_ops.shape(seq_len_mask), extra_ones), 0)) + return m * seq_len_mask + else: + return m + return nest.map_structure(lambda m: _maybe_mask(m, seq_len_mask), memory) + + +def _maybe_mask_score(score, memory_sequence_length=None, memory_mask=None, + score_mask_value=None): + """Mask the attention score based on the masks.""" + if memory_sequence_length is None and memory_mask is None: + return score + if memory_sequence_length is not None and memory_mask is not None: + raise ValueError("memory_sequence_length and memory_mask can't be provided " + "at same time.") + if memory_sequence_length is not None: + message = "All values in memory_sequence_length must greater than zero." + with ops.control_dependencies( + [check_ops.assert_positive(memory_sequence_length, message=message)]): + memory_mask = array_ops.sequence_mask( + memory_sequence_length, maxlen=array_ops.shape(score)[1]) + score_mask_values = score_mask_value * array_ops.ones_like(score) + return array_ops.where(memory_mask, score, score_mask_values) + + def hardmax(logits, name=None): """Returns batched one-hot vectors. @@ -1088,7 +1857,8 @@ class AttentionWrapper(rnn_cell_impl.RNNCell): output_attention=True, initial_cell_state=None, name=None, - attention_layer=None): + attention_layer=None, + attention_fn=None): """Construct the `AttentionWrapper`. **NOTE** If you are using the `BeamSearchDecoder` with a cell wrapped in @@ -1132,7 +1902,9 @@ class AttentionWrapper(rnn_cell_impl.RNNCell): feed the context and cell output into the attention layer to generate attention at each time step. If attention_mechanism is a list, attention_layer_size must be a list of the same length. If - attention_layer is set, this must be None. 
+ attention_layer is set, this must be None. If attention_fn is set, + it must be guaranteed that the outputs of attention_fn also meet the + above requirements. alignment_history: Python boolean, whether to store alignment history from all time steps in the final output state (currently stored as a time major `TensorArray` on which you must call `stack()`). @@ -1158,6 +1930,12 @@ class AttentionWrapper(rnn_cell_impl.RNNCell): the context as attention at each time step. If attention_mechanism is a list, attention_layer must be a list of the same length. If attention_layers_size is set, this must be None. + attention_fn: An optional callable function that allows users to provide + their own customized attention function, which takes input + (attention_mechanism, cell_output, attention_state, attention_layer) and + outputs (attention, alignments, next_attention_state). If provided, + the attention_layer_size should be the size of the outputs of + attention_fn. Raises: TypeError: `attention_layer_size` is not None and (`attention_mechanism` @@ -1240,6 +2018,10 @@ class AttentionWrapper(rnn_cell_impl.RNNCell): tensor_shape.dimension_value(attention_mechanism.values.shape[-1]) for attention_mechanism in attention_mechanisms) + if attention_fn is None: + attention_fn = _compute_attention + self._attention_fn = attention_fn + self._cell = cell self._attention_mechanisms = attention_mechanisms self._cell_input_fn = cell_input_fn @@ -1443,7 +2225,7 @@ class AttentionWrapper(rnn_cell_impl.RNNCell): all_attention_states = [] maybe_all_histories = [] for i, attention_mechanism in enumerate(self._attention_mechanisms): - attention, alignments, next_attention_state = _compute_attention( + attention, alignments, next_attention_state = self._attention_fn( attention_mechanism, cell_output, previous_attention_state[i], self._attention_layers[i] if self._attention_layers else None) alignment_history = previous_alignment_history[i].write( diff --git 
a/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py b/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py index 7eb95e5a70de985dca0d4b565ba03bdf454b6161..16dfa7ed8268d761dee49ec0146efabcaaef1393 100644 --- a/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/basic_decoder.py @@ -23,8 +23,10 @@ import collections from tensorflow.contrib.seq2seq.python.ops import decoder from tensorflow.contrib.seq2seq.python.ops import helper as helper_py +from tensorflow.contrib.seq2seq.python.ops import sampler as sampler_py from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +from tensorflow.python.keras import layers from tensorflow.python.layers import base as layers_base from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.util import nest @@ -146,3 +148,102 @@ class BasicDecoder(decoder.Decoder): sample_ids=sample_ids) outputs = BasicDecoderOutput(cell_outputs, sample_ids) return (outputs, next_state, next_inputs, finished) + + +class BasicDecoderV2(decoder.BaseDecoder): + """Basic sampling decoder.""" + + def __init__(self, cell, sampler, output_layer=None, **kwargs): + """Initialize BasicDecoder. + + Args: + cell: An `RNNCell` instance. + sampler: A `Sampler` instance. + output_layer: (Optional) An instance of `tf.layers.Layer`, i.e., + `tf.layers.Dense`. Optional layer to apply to the RNN output prior to + storing the result or sampling. + **kwargs: Other keyword arguments for layer creation. + + Raises: + TypeError: if `cell`, `sampler` or `output_layer` have an incorrect type. 
+ """ + rnn_cell_impl.assert_like_rnncell("cell", cell) + if not isinstance(sampler, sampler_py.Sampler): + raise TypeError("sampler must be a Sampler, received: %s" % (sampler,)) + if (output_layer is not None and + not isinstance(output_layer, layers.Layer)): + raise TypeError( + "output_layer must be a Layer, received: %s" % (output_layer,)) + self.cell = cell + self.sampler = sampler + self.output_layer = output_layer + super(BasicDecoderV2, self).__init__(**kwargs) + + def initialize(self, inputs, initial_state=None, **kwargs): + """Initialize the decoder.""" + # Assume the dtype of the cell is the output_size structure + # containing the input_state's first component's dtype. + self._cell_dtype = nest.flatten(initial_state)[0].dtype + return self.sampler.initialize(inputs, **kwargs) + (initial_state,) + + @property + def batch_size(self): + return self.sampler.batch_size + + def _rnn_output_size(self): + size = tensor_shape.TensorShape(self.cell.output_size) + if self.output_layer is None: + return size + else: + # To use layer's compute_output_shape, we need to convert the + # RNNCell's output_size entries into shapes with an unknown + # batch size. We then pass this through the layer's + # compute_output_shape and read off all but the first (batch) + # dimensions to get the output size of the rnn with the layer + # applied to the top. + output_shape_with_unknown_batch = nest.map_structure( + lambda s: tensor_shape.TensorShape([None]).concatenate(s), size) + layer_output_shape = self.output_layer.compute_output_shape( + output_shape_with_unknown_batch) + return nest.map_structure(lambda s: s[1:], layer_output_shape) + + @property + def output_size(self): + # Return the cell output and the id + return BasicDecoderOutput( + rnn_output=self._rnn_output_size(), + sample_id=self.sampler.sample_ids_shape) + + @property + def output_dtype(self): + # Assume the dtype of the cell is the output_size structure + # containing the input_state's first component's dtype. 
+ # Return that structure and the sample_ids_dtype from the helper. + dtype = self._cell_dtype + return BasicDecoderOutput( + nest.map_structure(lambda _: dtype, self._rnn_output_size()), + self.sampler.sample_ids_dtype) + + def step(self, time, inputs, state): + """Perform a decoding step. + + Args: + time: scalar `int32` tensor. + inputs: A (structure of) input tensors. + state: A (structure of) state tensors and TensorArrays. + + Returns: + `(outputs, next_state, next_inputs, finished)`. + """ + cell_outputs, cell_state = self.cell(inputs, state) + if self.output_layer is not None: + cell_outputs = self.output_layer(cell_outputs) + sample_ids = self.sampler.sample( + time=time, outputs=cell_outputs, state=cell_state) + (finished, next_inputs, next_state) = self.sampler.next_inputs( + time=time, + outputs=cell_outputs, + state=cell_state, + sample_ids=sample_ids) + outputs = BasicDecoderOutput(cell_outputs, sample_ids) + return (outputs, next_state, next_inputs, finished) diff --git a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py index ab36848f13ab3078cd232c18f140188e12db703b..8f8f057702951094758b277ce060955f3dc6e99d 100644 --- a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py @@ -921,6 +921,7 @@ def _get_scores(log_probs, sequence_lengths, length_penalty_weight, """ length_penalty_ = _length_penalty( sequence_lengths=sequence_lengths, penalty_factor=length_penalty_weight) + length_penalty_ = math_ops.cast(length_penalty_, dtype=log_probs.dtype) scores = log_probs / length_penalty_ coverage_penalty_weight = ops.convert_to_tensor( diff --git a/tensorflow/contrib/seq2seq/python/ops/decoder.py b/tensorflow/contrib/seq2seq/python/ops/decoder.py index f58268eff525a4b592c79acb32207e1a3f62bdc7..33f7bac8159401175ce57c0463fff1398c1dd9bb 100644 --- a/tensorflow/contrib/seq2seq/python/ops/decoder.py +++ 
b/tensorflow/contrib/seq2seq/python/ops/decoder.py @@ -27,6 +27,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util +from tensorflow.python.keras import layers from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_util @@ -135,6 +136,127 @@ class Decoder(object): return False +class BaseDecoder(layers.Layer): + """An RNN Decoder that is based on a Keras layer. + + Concepts used by this interface: + - `inputs`: (structure of) tensors and TensorArrays that is passed as input to + the RNNCell composing the decoder, at each time step. + - `state`: (structure of) tensors and TensorArrays that is passed to the + RNNCell instance as the state. + - `memory`: (structure of) tensors that is usually the full output of the + encoder, which will be used for the attention wrapper for the RNNCell. + - `finished`: boolean tensor telling whether each sequence in the batch is + finished. + - `outputs`: Instance of BasicDecoderOutput. Result of the decoding, at each + time step. 
+ """ + + def __init__(self, + output_time_major=False, + impute_finished=False, + maximum_iterations=None, + parallel_iterations=32, + swap_memory=False, + **kwargs): + self.output_time_major = output_time_major + self.impute_finished = impute_finished + self.maximum_iterations = maximum_iterations + self.parallel_iterations = parallel_iterations + self.swap_memory = swap_memory + super(BaseDecoder, self).__init__(**kwargs) + + def call(self, inputs, initial_state=None, **kwargs): + init_kwargs = kwargs + init_kwargs["initial_state"] = initial_state + return dynamic_decode(self, + output_time_major=self.output_time_major, + impute_finished=self.impute_finished, + maximum_iterations=self.maximum_iterations, + parallel_iterations=self.parallel_iterations, + swap_memory=self.swap_memory, + decoder_init_input=inputs, + decoder_init_kwargs=init_kwargs) + + @property + def batch_size(self): + """The batch size of input values.""" + raise NotImplementedError + + @property + def output_size(self): + """A (possibly nested tuple of...) integer[s] or `TensorShape` object[s].""" + raise NotImplementedError + + @property + def output_dtype(self): + """A (possibly nested tuple of...) dtype[s].""" + raise NotImplementedError + + def initialize(self, inputs, initial_state=None, **kwargs): + """Called before any decoding iterations. + + This method must compute initial input values and initial state. + + Args: + inputs: (structure of) tensors that contains the input for the decoder. In + the normal case, it's a tensor with shape [batch, timestep, embedding]. + initial_state: (structure of) tensors that contains the initial state for + the RNNCell. + **kwargs: Other arguments that are passed in from layer.call() method. It + could contain items like input sequence_length, or masking for input. + + Returns: + `(finished, initial_inputs, initial_state)`: initial values of + 'finished' flags, inputs and state. 
+ """ + raise NotImplementedError + + def step(self, time, inputs, state): + """Called per step of decoding (but only once for dynamic decoding). + + Args: + time: Scalar `int32` tensor. Current step number. + inputs: RNNCell input (possibly nested tuple of) tensor[s] for this time + step. + state: RNNCell state (possibly nested tuple of) tensor[s] from previous + time step. + + Returns: + `(outputs, next_state, next_inputs, finished)`: `outputs` is an object + containing the decoder output, `next_state` is a (structure of) state + tensors and TensorArrays, `next_inputs` is the tensor that should be used + as input for the next step, `finished` is a boolean tensor telling whether + the sequence is complete, for each sequence in the batch. + """ + raise NotImplementedError + + def finalize(self, outputs, final_state, sequence_lengths): + raise NotImplementedError + + @property + def tracks_own_finished(self): + """Describes whether the Decoder keeps track of finished states. + + Most decoders will emit a true/false `finished` value independently + at each time step. In this case, the `dynamic_decode` function keeps track + of which batch entries are already finished, and performs a logical OR to + insert new batches to the finished set. + + Some decoders, however, shuffle batches / beams between time steps and + `dynamic_decode` will mix up the finished state across these entries because + it does not track the reshuffle across time steps. In this case, it is + up to the decoder to declare that it will keep track of its own finished + state by setting this property to `True`. + + Returns: + Python bool. + """ + return False + + # TODO(scottzhu): Add build/get_config/from_config and other layer methods. 
+ + def _create_zero_outputs(size, dtype, batch_size): """Create a zero outputs Tensor structure.""" def _create(s, d): @@ -149,7 +271,8 @@ def dynamic_decode(decoder, maximum_iterations=None, parallel_iterations=32, swap_memory=False, - scope=None): + scope=None, + **kwargs): """Perform dynamic decoding with `decoder`. Calls initialize() once and step() repeatedly on the Decoder object. @@ -171,6 +294,9 @@ def dynamic_decode(decoder, parallel_iterations: Argument passed to `tf.while_loop`. swap_memory: Argument passed to `tf.while_loop`. scope: Optional variable scope to use. + **kwargs: dict, other keyword arguments for dynamic_decode. It might contain + arguments for `BaseDecoder` to initialize, which takes all tensor inputs + during call(). Returns: `(final_outputs, final_state, final_sequence_lengths)`. @@ -179,7 +305,7 @@ def dynamic_decode(decoder, TypeError: if `decoder` is not an instance of `Decoder`. ValueError: if `maximum_iterations` is provided but is not a scalar. """ - if not isinstance(decoder, Decoder): + if not isinstance(decoder, (Decoder, BaseDecoder)): raise TypeError("Expected decoder to be type Decoder, but saw: %s" % type(decoder)) @@ -204,7 +330,14 @@ def dynamic_decode(decoder, if maximum_iterations.get_shape().ndims != 0: raise ValueError("maximum_iterations must be a scalar") - initial_finished, initial_inputs, initial_state = decoder.initialize() + if isinstance(decoder, Decoder): + initial_finished, initial_inputs, initial_state = decoder.initialize() + else: + # For BaseDecoder that takes tensor inputs during call. 
+ decoder_init_input = kwargs.pop("decoder_init_input", None) + decoder_init_kwargs = kwargs.pop("decoder_init_kwargs", {}) + initial_finished, initial_inputs, initial_state = decoder.initialize( + decoder_init_input, **decoder_init_kwargs) zero_outputs = _create_zero_outputs(decoder.output_size, decoder.output_dtype, @@ -222,7 +355,7 @@ def dynamic_decode(decoder, def _shape(batch_size, from_shape): if (not isinstance(from_shape, tensor_shape.TensorShape) or from_shape.ndims == 0): - return tensor_shape.TensorShape(None) + return None else: batch_size = tensor_util.constant_value( ops.convert_to_tensor( diff --git a/tensorflow/contrib/seq2seq/python/ops/helper.py b/tensorflow/contrib/seq2seq/python/ops/helper.py index 3245cc5e72154289ea3ba000b9a30586a7ad03a9..033c2eb0801d5a51ee937f5e960faa91a6f1ae54 100644 --- a/tensorflow/contrib/seq2seq/python/ops/helper.py +++ b/tensorflow/contrib/seq2seq/python/ops/helper.py @@ -32,9 +32,8 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import embedding_ops from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops from tensorflow.python.ops import tensor_array_ops -from tensorflow.python.ops.distributions import bernoulli -from tensorflow.python.ops.distributions import categorical from tensorflow.python.util import nest __all__ = [ @@ -51,6 +50,68 @@ __all__ = [ _transpose_batch_time = decoder._transpose_batch_time # pylint: disable=protected-access +# The following sample functions (_call_sampler, bernoulli_sample, +# categorical_sample) mimic TensorFlow Probability distribution semantics. + + +def _call_sampler(sample_n_fn, sample_shape, name=None): + """Reshapes vector of samples.""" + with ops.name_scope(name, "call_sampler", values=[sample_shape]): + sample_shape = ops.convert_to_tensor( + sample_shape, dtype=dtypes.int32, name="sample_shape") + # Ensure sample_shape is a vector (vs just a scalar). 
+ pad = math_ops.cast(math_ops.equal(array_ops.rank(sample_shape), 0), + dtypes.int32) + sample_shape = array_ops.reshape( + sample_shape, + array_ops.pad(array_ops.shape(sample_shape), + paddings=[[pad, 0]], + constant_values=1)) + samples = sample_n_fn(math_ops.reduce_prod(sample_shape)) + batch_event_shape = array_ops.shape(samples)[1:] + final_shape = array_ops.concat([sample_shape, batch_event_shape], 0) + return array_ops.reshape(samples, final_shape) + + +def bernoulli_sample(probs=None, logits=None, dtype=dtypes.int32, + sample_shape=(), seed=None): + """Samples from Bernoulli distribution.""" + if probs is None: + probs = math_ops.sigmoid(logits, name="probs") + else: + probs = ops.convert_to_tensor(probs, name="probs") + batch_shape_tensor = array_ops.shape(probs) + def _sample_n(n): + """Sample vector of Bernoullis.""" + new_shape = array_ops.concat([[n], batch_shape_tensor], 0) + uniform = random_ops.random_uniform( + new_shape, seed=seed, dtype=probs.dtype) + return math_ops.cast(math_ops.less(uniform, probs), dtype) + return _call_sampler(_sample_n, sample_shape) + + +def categorical_sample(logits, dtype=dtypes.int32, + sample_shape=(), seed=None): + """Samples from categorical distribution.""" + logits = ops.convert_to_tensor(logits, name="logits") + event_size = array_ops.shape(logits)[-1] + batch_shape_tensor = array_ops.shape(logits)[:-1] + def _sample_n(n): + """Sample vector of categoricals.""" + if logits.shape.ndims == 2: + logits_2d = logits + else: + logits_2d = array_ops.reshape(logits, [-1, event_size]) + sample_dtype = dtypes.int64 if logits.dtype.size > 4 else dtypes.int32 + draws = random_ops.multinomial( + logits_2d, n, seed=seed, output_dtype=sample_dtype) + draws = array_ops.reshape( + array_ops.transpose(draws), + array_ops.concat([[n], batch_shape_tensor], 0)) + return math_ops.cast(draws, dtype) + return _call_sampler(_sample_n, sample_shape) + + def _unstack_ta(inp): return tensor_array_ops.TensorArray( dtype=inp.dtype, 
size=array_ops.shape(inp)[0], @@ -307,14 +368,14 @@ class ScheduledEmbeddingTrainingHelper(TrainingHelper): with ops.name_scope(name, "ScheduledEmbeddingTrainingHelperSample", [time, outputs, state]): # Return -1s where we did not sample, and sample_ids elsewhere - select_sampler = bernoulli.Bernoulli( - probs=self._sampling_probability, dtype=dtypes.bool) - select_sample = select_sampler.sample( - sample_shape=self.batch_size, seed=self._scheduling_seed) - sample_id_sampler = categorical.Categorical(logits=outputs) + select_sample = bernoulli_sample( + probs=self._sampling_probability, + dtype=dtypes.bool, + sample_shape=self.batch_size, + seed=self._scheduling_seed) return array_ops.where( select_sample, - sample_id_sampler.sample(seed=self._seed), + categorical_sample(logits=outputs, seed=self._seed), gen_array_ops.fill([self.batch_size], -1)) def next_inputs(self, time, outputs, state, sample_ids, name=None): @@ -425,8 +486,10 @@ class ScheduledOutputTrainingHelper(TrainingHelper): def sample(self, time, outputs, state, name=None): with ops.name_scope(name, "ScheduledOutputTrainingHelperSample", [time, outputs, state]): - sampler = bernoulli.Bernoulli(probs=self._sampling_probability) - return sampler.sample(sample_shape=self.batch_size, seed=self._seed) + return bernoulli_sample( + probs=self._sampling_probability, + sample_shape=self.batch_size, + seed=self._seed) def next_inputs(self, time, outputs, state, sample_ids, name=None): with ops.name_scope(name, "ScheduledOutputTrainingHelperNextInputs", @@ -610,8 +673,7 @@ class SampleEmbeddingHelper(GreedyEmbeddingHelper): else: logits = outputs / self._softmax_temperature - sample_id_sampler = categorical.Categorical(logits=logits) - sample_ids = sample_id_sampler.sample(seed=self._seed) + sample_ids = categorical_sample(logits=logits, seed=self._seed) return sample_ids diff --git a/tensorflow/contrib/seq2seq/python/ops/loss.py b/tensorflow/contrib/seq2seq/python/ops/loss.py index 
39a6d2f58b140706a94d83273d3327edd1891368..0fbfd6187030f14ac105a18b3e09b7a42d4de32a 100644 --- a/tensorflow/contrib/seq2seq/python/ops/loss.py +++ b/tensorflow/contrib/seq2seq/python/ops/loss.py @@ -20,11 +20,12 @@ from __future__ import division from __future__ import print_function from tensorflow.python.framework import ops +from tensorflow.python.keras.losses import Loss from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops -__all__ = ["sequence_loss"] +__all__ = ["sequence_loss", "SequenceLoss"] def sequence_loss(logits, @@ -32,16 +33,26 @@ def sequence_loss(logits, weights, average_across_timesteps=True, average_across_batch=True, + sum_over_timesteps=False, + sum_over_batch=False, softmax_loss_function=None, name=None): """Weighted cross-entropy loss for a sequence of logits. - Depending on the values of `average_across_timesteps` and - `average_across_batch`, the return Tensor will have rank 0, 1, or 2 as these - arguments reduce the cross-entropy at each target, which has shape - `[batch_size, sequence_length]`, over their respective dimensions. For - example, if `average_across_timesteps` is `True` and `average_across_batch` - is `False`, then the return Tensor will have shape `[batch_size]`. + Depending on the values of `average_across_timesteps` / `sum_over_timesteps` + and `average_across_batch` / `sum_over_batch`, the return Tensor will have + rank 0, 1, or 2 as these arguments reduce the cross-entropy at each target, + which has shape `[batch_size, sequence_length]`, over their respective + dimensions. For example, if `average_across_timesteps` is `True` and + `average_across_batch` is `False`, then the return Tensor will have shape + `[batch_size]`. + + Note that `average_across_timesteps` and `sum_over_timesteps` cannot be True + at same time. Same for `average_across_batch` and `sum_over_batch`. 
+ + The recommended loss reduction in tf 2.0 has been changed to sum_over, instead + of weighted average. Users are recommended to use `sum_over_timesteps` and + `sum_over_batch` for reduction. Args: logits: A Tensor of shape @@ -58,6 +69,12 @@ def sequence_loss(logits, dimension and divide the cost by the total label weight across timesteps. average_across_batch: If set, sum the cost across the batch dimension and divide the returned cost by the batch size. + sum_over_timesteps: If set, sum the cost across the sequence dimension and + divide by the size of the sequence. Note that any element with 0 weights will + be excluded from size calculation. + sum_over_batch: if set, sum the cost across the batch dimension and divide + the total cost by the batch size. Note that any element with 0 weights will + be excluded from size calculation. softmax_loss_function: Function (labels, logits) -> loss-batch to be used instead of the standard softmax (the default if this is None). **Note that to avoid confusion, it is required for the function to accept @@ -78,11 +95,15 @@ def sequence_loss(logits, raise ValueError("Logits must be a " "[batch_size x sequence_length x logits] tensor") if len(targets.get_shape()) != 2: - raise ValueError("Targets must be a [batch_size x sequence_length] " - "tensor") + raise ValueError("Targets must be a [batch_size x sequence_length] tensor") if len(weights.get_shape()) != 2: - raise ValueError("Weights must be a [batch_size x sequence_length] " - "tensor") + raise ValueError("Weights must be a [batch_size x sequence_length] tensor") + if average_across_timesteps and sum_over_timesteps: + raise ValueError("average_across_timesteps and sum_over_timesteps cannot " + "be set to True at same time.") + if average_across_batch and sum_over_batch: + raise ValueError("average_across_batch and sum_over_batch cannot be set " + "to True at same time.") with ops.name_scope(name, "sequence_loss", [logits, targets, weights]): num_classes = 
array_ops.shape(logits)[2] logits_flat = array_ops.reshape(logits, [-1, num_classes]) @@ -96,20 +117,56 @@ def sequence_loss(logits, if average_across_timesteps and average_across_batch: crossent = math_ops.reduce_sum(crossent) total_size = math_ops.reduce_sum(weights) - total_size += 1e-12 # to avoid division by 0 for all-0 weights - crossent /= total_size + crossent = math_ops.div_no_nan(crossent, total_size) + elif sum_over_timesteps and sum_over_batch: + crossent = math_ops.reduce_sum(crossent) + total_count = math_ops.cast(math_ops.count_nonzero(weights), + crossent.dtype) + crossent = math_ops.div_no_nan(crossent, total_count) else: - batch_size = array_ops.shape(logits)[0] - sequence_length = array_ops.shape(logits)[1] - crossent = array_ops.reshape(crossent, [batch_size, sequence_length]) - if average_across_timesteps and not average_across_batch: - crossent = math_ops.reduce_sum(crossent, axis=[1]) - total_size = math_ops.reduce_sum(weights, axis=[1]) - total_size += 1e-12 # to avoid division by 0 for all-0 weights - crossent /= total_size - if not average_across_timesteps and average_across_batch: - crossent = math_ops.reduce_sum(crossent, axis=[0]) - total_size = math_ops.reduce_sum(weights, axis=[0]) - total_size += 1e-12 # to avoid division by 0 for all-0 weights - crossent /= total_size + crossent = array_ops.reshape(crossent, array_ops.shape(logits)[0:2]) + if average_across_timesteps or average_across_batch: + reduce_axis = [0] if average_across_batch else [1] + crossent = math_ops.reduce_sum(crossent, axis=reduce_axis) + total_size = math_ops.reduce_sum(weights, axis=reduce_axis) + crossent = math_ops.div_no_nan(crossent, total_size) + elif sum_over_timesteps or sum_over_batch: + reduce_axis = [0] if sum_over_batch else [1] + crossent = math_ops.reduce_sum(crossent, axis=reduce_axis) + total_count = math_ops.cast( + math_ops.count_nonzero(weights, axis=reduce_axis), + dtype=crossent.dtype) + crossent = math_ops.div_no_nan(crossent, total_count) 
return crossent + + +class SequenceLoss(Loss): + """Weighted cross-entropy loss for a sequence of logits.""" + + def __init__(self, + average_across_timesteps=False, + average_across_batch=False, + sum_over_timesteps=True, + sum_over_batch=True, + softmax_loss_function=None, + name=None): + super(SequenceLoss, self).__init__(name=name) + self.average_across_timesteps = average_across_timesteps + self.average_across_batch = average_across_batch + self.sum_over_timesteps = sum_over_timesteps + self.sum_over_batch = sum_over_batch + self.softmax_loss_function = softmax_loss_function + + def __call__(self, y_true, y_pred, sample_weight=None): + """Override the parent __call__ to have a customized reduce behavior.""" + return sequence_loss(y_pred, y_true, sample_weight, + average_across_timesteps=self.average_across_timesteps, + average_across_batch=self.average_across_batch, + sum_over_timesteps=self.sum_over_timesteps, + sum_over_batch=self.sum_over_batch, + softmax_loss_function=self.softmax_loss_function, + name=self.name) + + def call(self, y_true, y_pred): + # Skip this method since the __call__ contains real implementation. + pass diff --git a/tensorflow/contrib/seq2seq/python/ops/sampler.py b/tensorflow/contrib/seq2seq/python/ops/sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..9e3e48b3bc61c0ff94ae0a1794767c7ff6914969 --- /dev/null +++ b/tensorflow/contrib/seq2seq/python/ops/sampler.py @@ -0,0 +1,765 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A library of sampler for use with SamplingDecoders.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc + +import six + +from tensorflow.contrib.seq2seq.python.ops import decoder +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import embedding_ops +from tensorflow.python.ops import gen_array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import tensor_array_ops +from tensorflow.python.util import nest + +__all__ = [ + "Sampler", + "TrainingSampler", + "GreedyEmbeddingSampler", + "SampleEmbeddingSampler", + "CustomSampler", + "ScheduledEmbeddingTrainingSampler", + "ScheduledOutputTrainingSampler", + "InferenceSampler", +] + +_transpose_batch_time = decoder._transpose_batch_time # pylint: disable=protected-access + + +@six.add_metaclass(abc.ABCMeta) +class Sampler(object): + """Interface for implementing sampling in seq2seq decoders. + + Sampler instances are used by `BasicDecoder`. 
The normal usage of a sampler is + like below: + sampler = Sampler(init_args) + (initial_finished, initial_inputs) = sampler.initialize(input_tensors) + for time_step in range(time): + cell_output, cell_state = cell.call(cell_input, previous_state) + sample_ids = sampler.sample(time_step, cell_output, cell_state) + (finished, next_inputs, next_state) = sampler.next_inputs( + time_step,cell_output, cell_state) + + Note that all the tensor input should not be feed to Sampler as __init__() + parameters, instead, they should be feed by decoders via initialize(). + """ + + @abc.abstractmethod + def initialize(self, inputs, **kwargs): + """initialize the sampler with the input tensors. + + This method suppose to be only invoke once before the calling other methods + of the Sampler. + + Args: + inputs: A (structure of) input tensors, it could be a nested tuple or a + single tensor. + **kwargs: Other kwargs for initialization. It could contain tensors like + mask for inputs, or non tensor parameter. + + Returns: + `(initial_finished, initial_inputs)`. + """ + pass + + @abc.abstractmethod + def sample(self, time, outputs, state): + """Returns `sample_ids`.""" + pass + + @abc.abstractmethod + def next_inputs(self, time, outputs, state, sample_ids): + """Returns `(finished, next_inputs, next_state)`.""" + pass + + @abc.abstractproperty + def batch_size(self): + """Batch size of tensor returned by `sample`. + + Returns a scalar int32 tensor. The return value might not available before + the invocation of initialize(), in this case, ValueError is raised. + """ + raise NotImplementedError("batch_size has not been implemented") + + @abc.abstractproperty + def sample_ids_shape(self): + """Shape of tensor returned by `sample`, excluding the batch dimension. + + Returns a `TensorShape`. The return value might not available before the + invocation of initialize(). 
+ """ + raise NotImplementedError("sample_ids_shape has not been implemented") + + @abc.abstractproperty + def sample_ids_dtype(self): + """DType of tensor returned by `sample`. + + Returns a DType. The return value might not available before the + invocation of initialize(). + """ + raise NotImplementedError("sample_ids_dtype has not been implemented") + + +class CustomSampler(Sampler): + """Base abstract class that allows the user to customize sampling.""" + + def __init__(self, + initialize_fn, + sample_fn, + next_inputs_fn, + sample_ids_shape=None, + sample_ids_dtype=None): + """Initializer. + + Args: + initialize_fn: callable that returns `(finished, next_inputs)` for the + first iteration. + sample_fn: callable that takes `(time, outputs, state)` and emits tensor + `sample_ids`. + next_inputs_fn: callable that takes `(time, outputs, state, sample_ids)` + and emits `(finished, next_inputs, next_state)`. + sample_ids_shape: Either a list of integers, or a 1-D Tensor of type + `int32`, the shape of each value in the `sample_ids` batch. Defaults to + a scalar. + sample_ids_dtype: The dtype of the `sample_ids` tensor. Defaults to int32. 
+ """ + self._initialize_fn = initialize_fn + self._sample_fn = sample_fn + self._next_inputs_fn = next_inputs_fn + self._batch_size = None + self._sample_ids_shape = tensor_shape.TensorShape(sample_ids_shape or []) + self._sample_ids_dtype = sample_ids_dtype or dtypes.int32 + + @property + def batch_size(self): + if self._batch_size is None: + raise ValueError("batch_size accessed before initialize was called") + return self._batch_size + + @property + def sample_ids_shape(self): + return self._sample_ids_shape + + @property + def sample_ids_dtype(self): + return self._sample_ids_dtype + + def initialize(self, inputs, **kwargs): + (finished, next_inputs) = self._initialize_fn(inputs, **kwargs) + if self._batch_size is None: + self._batch_size = array_ops.size(finished) + return (finished, next_inputs) + + def sample(self, time, outputs, state): + return self._sample_fn(time=time, outputs=outputs, state=state) + + def next_inputs(self, time, outputs, state, sample_ids): + return self._next_inputs_fn( + time=time, outputs=outputs, state=state, sample_ids=sample_ids) + + +class TrainingSampler(Sampler): + """A Sampler for use during training. + + Only reads inputs. + + Returned sample_ids are the argmax of the RNN output logits. + """ + + def __init__(self, time_major=False): + """Initializer. + + Args: + time_major: Python bool. Whether the tensors in `inputs` are time major. + If `False` (default), they are assumed to be batch major. + + Raises: + ValueError: if `sequence_length` is not a 1D tensor. + """ + self.time_major = time_major + self._batch_size = None + + @property + def batch_size(self): + if self._batch_size is None: + raise ValueError("batch_size accessed before initialize was called") + return self._batch_size + + @property + def sample_ids_shape(self): + return tensor_shape.TensorShape([]) + + @property + def sample_ids_dtype(self): + return dtypes.int32 + + def initialize(self, inputs, sequence_length=None): + """Initialize the TrainSampler. 
+ + Args: + inputs: A (structure of) input tensors. + sequence_length: An int32 vector tensor. + + Returns: + (finished, next_inputs), a tuple of two items. The first item is a boolean + vector to indicate whether the item in the batch has finished. The + second item is the first slide of input data based on the timestep + dimension (usually the second dim of the input). + """ + self.inputs = ops.convert_to_tensor(inputs, name="inputs") + if not self.time_major: + inputs = nest.map_structure(_transpose_batch_time, inputs) + + self.input_tas = nest.map_structure(_unstack_ta, inputs) + if sequence_length is None: + raise ValueError("sequence_length is required for TrainingSampler") + self.sequence_length = ops.convert_to_tensor( + sequence_length, name="sequence_length") + if self.sequence_length.get_shape().ndims != 1: + raise ValueError( + "Expected sequence_length to be a vector, but received shape: %s" % + self._sequence_length.get_shape()) + + self.zero_inputs = nest.map_structure( + lambda inp: array_ops.zeros_like(inp[0, :]), inputs) + + self._batch_size = array_ops.size(self.sequence_length) + + finished = math_ops.equal(0, self.sequence_length) + all_finished = math_ops.reduce_all(finished) + next_inputs = control_flow_ops.cond( + all_finished, + lambda: self.zero_inputs, + lambda: nest.map_structure(lambda inp: inp.read(0), self.input_tas)) + return (finished, next_inputs) + + def sample(self, time, outputs, state): + del state + sample_ids = math_ops.cast(math_ops.argmax(outputs, axis=-1), dtypes.int32) + return sample_ids + + def next_inputs(self, time, outputs, state, sample_ids): + del sample_ids + next_time = time + 1 + finished = (next_time >= self.sequence_length) + all_finished = math_ops.reduce_all(finished) + + def read_from_ta(inp): + return inp.read(next_time) + + next_inputs = control_flow_ops.cond( + all_finished, + lambda: self.zero_inputs, + lambda: nest.map_structure(read_from_ta, self.input_tas)) + return (finished, next_inputs, state) + + 
+class ScheduledEmbeddingTrainingSampler(TrainingSampler): + """A training sampler that adds scheduled sampling. + + Returns -1s for sample_ids where no sampling took place; valid sample id + values elsewhere. + """ + + def __init__(self, + sampling_probability, + embedding_fn=None, + time_major=False, + seed=None, + scheduling_seed=None): + """Initializer. + + Args: + sampling_probability: A `float32` 0-D or 1-D tensor: the probability of + sampling categorically from the output ids instead of reading directly + from the inputs. + embedding_fn: A callable that takes a vector tensor of `ids` (argmax ids), + or the `params` argument for `embedding_lookup`. + time_major: Python bool. Whether the tensors in `inputs` are time major. + If `False` (default), they are assumed to be batch major. + seed: The sampling seed. + scheduling_seed: The schedule decision rule sampling seed. + + Raises: + ValueError: if `sampling_probability` is not a scalar or vector. + """ + if callable(embedding_fn) or embedding_fn is None: + self.embedding_fn = embedding_fn + else: + raise ValueError("embedding_fn is expected to be callable, got %s" + % type(embedding_fn)) + self.sampling_probability = ops.convert_to_tensor( + sampling_probability, name="sampling_probability") + if self.sampling_probability.get_shape().ndims not in (0, 1): + raise ValueError( + "sampling_probability must be either a scalar or a vector. 
" + "saw shape: %s" % (self.sampling_probability.get_shape())) + self.seed = seed + self.scheduling_seed = scheduling_seed + super(ScheduledEmbeddingTrainingSampler, + self).__init__(time_major=time_major) + + def initialize(self, inputs, sequence_length=None, embedding=None): + if self.embedding_fn is None: + if embedding is None: + raise ValueError("embedding is required as a keyword argument for " + "ScheduledEmbeddingTrainingSampler") + self.embedding_fn = ( + lambda ids: embedding_ops.embedding_lookup(embedding, ids)) + return super(ScheduledEmbeddingTrainingSampler, self).initialize( + inputs, sequence_length=sequence_length) + + def sample(self, time, outputs, state): + del state + # Return -1s where we did not sample, and sample_ids elsewhere + select_sample = bernoulli_sample( + probs=self.sampling_probability, + dtype=dtypes.bool, + sample_shape=self.batch_size, + seed=self.scheduling_seed) + return array_ops.where(select_sample, + categorical_sample(logits=outputs, seed=self.seed), + gen_array_ops.fill([self.batch_size], -1)) + + def next_inputs(self, time, outputs, state, sample_ids): + (finished, base_next_inputs, state) = ( + super(ScheduledEmbeddingTrainingSampler, self).next_inputs( + time=time, outputs=outputs, state=state, sample_ids=sample_ids)) + + def maybe_sample(): + """Perform scheduled sampling.""" + where_sampling = math_ops.cast( + array_ops.where(sample_ids > -1), dtypes.int32) + where_not_sampling = math_ops.cast( + array_ops.where(sample_ids <= -1), dtypes.int32) + sample_ids_sampling = array_ops.gather_nd(sample_ids, where_sampling) + inputs_not_sampling = array_ops.gather_nd(base_next_inputs, + where_not_sampling) + sampled_next_inputs = self.embedding_fn(sample_ids_sampling) + base_shape = array_ops.shape(base_next_inputs) + return (array_ops.scatter_nd( + indices=where_sampling, updates=sampled_next_inputs, shape=base_shape) + + array_ops.scatter_nd( + indices=where_not_sampling, + updates=inputs_not_sampling, + shape=base_shape)) 
+ + all_finished = math_ops.reduce_all(finished) + next_inputs = control_flow_ops.cond(all_finished, lambda: base_next_inputs, + maybe_sample) + return (finished, next_inputs, state) + + +class ScheduledOutputTrainingSampler(TrainingSampler): + """A training sampler that adds scheduled sampling directly to outputs. + + Returns False for sample_ids where no sampling took place; True elsewhere. + """ + + def __init__(self, + sampling_probability, + time_major=False, + seed=None, + next_inputs_fn=None): + """Initializer. + + Args: + sampling_probability: A `float32` scalar tensor: the probability of + sampling from the outputs instead of reading directly from the inputs. + time_major: Python bool. Whether the tensors in `inputs` are time major. + If `False` (default), they are assumed to be batch major. + seed: The sampling seed. + next_inputs_fn: (Optional) callable to apply to the RNN outputs to create + the next input when sampling. If `None` (default), the RNN outputs will + be used as the next inputs. + + Raises: + ValueError: if `sampling_probability` is not a scalar or vector. + """ + self.sampling_probability = ops.convert_to_tensor( + sampling_probability, name="sampling_probability") + if self.sampling_probability.get_shape().ndims not in (0, 1): + raise ValueError( + "sampling_probability must be either a scalar or a vector. 
" + "saw shape: %s" % (self._sampling_probability.get_shape())) + + self.seed = seed + self.next_inputs_fn = next_inputs_fn + + super(ScheduledOutputTrainingSampler, self).__init__(time_major=time_major) + + def initialize(self, inputs, sequence_length=None, auxiliary_inputs=None): + if auxiliary_inputs is None: + maybe_concatenated_inputs = inputs + else: + inputs = ops.convert_to_tensor(inputs) + auxiliary_inputs = ops.convert_to_tensor(auxiliary_inputs) + maybe_concatenated_inputs = nest.map_structure( + lambda x, y: array_ops.concat((x, y), -1), inputs, auxiliary_inputs) + if not self.time_major: + auxiliary_inputs = nest.map_structure(_transpose_batch_time, + auxiliary_inputs) + if auxiliary_inputs is not None: + self._auxiliary_input_tas = nest.map_structure(_unstack_ta, + auxiliary_inputs) + else: + self._auxiliary_input_tas = None + + return super(ScheduledOutputTrainingSampler, self).initialize( + maybe_concatenated_inputs, sequence_length=sequence_length) + + def sample(self, time, outputs, state): + del state + return bernoulli_sample( + probs=self.sampling_probability, + sample_shape=self.batch_size, + seed=self.seed) + + def next_inputs(self, time, outputs, state, sample_ids): + (finished, base_next_inputs, state) = ( + super(ScheduledOutputTrainingSampler, self).next_inputs( + time=time, outputs=outputs, state=state, sample_ids=sample_ids)) + sample_ids = math_ops.cast(sample_ids, dtypes.bool) + + def maybe_sample(): + """Perform scheduled sampling.""" + + def maybe_concatenate_auxiliary_inputs(outputs_, indices=None): + """Concatenate outputs with auxiliary inputs, if they exist.""" + if self._auxiliary_input_tas is None: + return outputs_ + + next_time = time + 1 + auxiliary_inputs = nest.map_structure(lambda ta: ta.read(next_time), + self._auxiliary_input_tas) + if indices is not None: + auxiliary_inputs = array_ops.gather_nd(auxiliary_inputs, indices) + return nest.map_structure(lambda x, y: array_ops.concat((x, y), -1), + outputs_, 
auxiliary_inputs) + + if self.next_inputs_fn is None: + return array_ops.where(sample_ids, + maybe_concatenate_auxiliary_inputs(outputs), + base_next_inputs) + + where_sampling = math_ops.cast(array_ops.where(sample_ids), dtypes.int32) + where_not_sampling = math_ops.cast( + array_ops.where(math_ops.logical_not(sample_ids)), dtypes.int32) + outputs_sampling = array_ops.gather_nd(outputs, where_sampling) + inputs_not_sampling = array_ops.gather_nd(base_next_inputs, + where_not_sampling) + sampled_next_inputs = maybe_concatenate_auxiliary_inputs( + self.next_inputs_fn(outputs_sampling), where_sampling) + + base_shape = array_ops.shape(base_next_inputs) + return (array_ops.scatter_nd( + indices=where_sampling, updates=sampled_next_inputs, shape=base_shape) + + array_ops.scatter_nd( + indices=where_not_sampling, + updates=inputs_not_sampling, + shape=base_shape)) + + all_finished = math_ops.reduce_all(finished) + no_samples = math_ops.logical_not(math_ops.reduce_any(sample_ids)) + next_inputs = control_flow_ops.cond( + math_ops.logical_or(all_finished, no_samples), lambda: base_next_inputs, + maybe_sample) + return (finished, next_inputs, state) + + +class GreedyEmbeddingSampler(Sampler): + """A sampler for use during inference. + + Uses the argmax of the output (treated as logits) and passes the + result through an embedding layer to get the next input. + """ + + def __init__(self, embedding_fn=None): + """Initializer. + + Args: + embedding_fn: A optional callable that takes a vector tensor of `ids` + (argmax ids), or the `params` argument for `embedding_lookup`. The + returned tensor will be passed to the decoder input. Default to use + `embedding_ops.embedding_lookup`. 
+ """ + if embedding_fn is None or callable(embedding_fn): + self.embedding_fn = embedding_fn + else: + raise ValueError("embedding_fn is expected to be a callable, got %s" % + type(embedding_fn)) + self._batch_size = None + + @property + def batch_size(self): + if self._batch_size is None: + raise ValueError("batch_size accessed before initialize was called") + return self._batch_size + + @property + def sample_ids_shape(self): + return tensor_shape.TensorShape([]) + + @property + def sample_ids_dtype(self): + return dtypes.int32 + + def initialize(self, embedding, start_tokens=None, end_token=None): + """Initialize the GreedyEmbeddingSampler. + + Args: + embedding: tensor that contains embedding states matrix. It will be used + to generate generate outputs with start_tokens and end_tokens. The + embedding will be ignored if the embedding_fn has been provided at + __init__(). + start_tokens: `int32` vector shaped `[batch_size]`, the start tokens. + end_token: `int32` scalar, the token that marks end of decoding. + + Returns: + Tuple of two items: `(finished, self.start_inputs)`. + Raises: + ValueError: if `start_tokens` is not a 1D tensor or `end_token` is not a + scalar. 
+ """ + if self.embedding_fn is None: + self.embedding_fn = ( + lambda ids: embedding_ops.embedding_lookup(embedding, ids)) + + self.start_tokens = ops.convert_to_tensor( + start_tokens, dtype=dtypes.int32, name="start_tokens") + self.end_token = ops.convert_to_tensor( + end_token, dtype=dtypes.int32, name="end_token") + if self.start_tokens.get_shape().ndims != 1: + raise ValueError("start_tokens must be a vector") + self._batch_size = array_ops.size(start_tokens) + if self.end_token.get_shape().ndims != 0: + raise ValueError("end_token must be a scalar") + self.start_inputs = self.embedding_fn(self.start_tokens) + + finished = array_ops.tile([False], [self._batch_size]) + return (finished, self.start_inputs) + + def sample(self, time, outputs, state): + """sample for GreedyEmbeddingHelper.""" + del time, state # unused by sample_fn + # Outputs are logits, use argmax to get the most probable id + if not isinstance(outputs, ops.Tensor): + raise TypeError( + "Expected outputs to be a single Tensor, got: %s" % type(outputs)) + sample_ids = math_ops.argmax(outputs, axis=-1, output_type=dtypes.int32) + return sample_ids + + def next_inputs(self, time, outputs, state, sample_ids): + """next_inputs_fn for GreedyEmbeddingHelper.""" + del time, outputs # unused by next_inputs_fn + finished = math_ops.equal(sample_ids, self.end_token) + all_finished = math_ops.reduce_all(finished) + next_inputs = control_flow_ops.cond( + all_finished, + # If we're finished, the next_inputs value doesn't matter + lambda: self.start_inputs, + lambda: self.embedding_fn(sample_ids)) + return (finished, next_inputs, state) + + +class SampleEmbeddingSampler(GreedyEmbeddingSampler): + """A sampler for use during inference. + + Uses sampling (from a distribution) instead of argmax and passes the + result through an embedding layer to get the next input. + """ + + def __init__(self, embedding_fn=None, softmax_temperature=None, seed=None): + """Initializer. 
+ + Args: + embedding_fn: (Optional) A callable that takes a vector tensor of `ids` + (argmax ids), or the `params` argument for `embedding_lookup`. The + returned tensor will be passed to the decoder input. + softmax_temperature: (Optional) `float32` scalar, value to divide the + logits by before computing the softmax. Larger values (above 1.0) result + in more random samples, while smaller values push the sampling + distribution towards the argmax. Must be strictly greater than 0. + Defaults to 1.0. + seed: (Optional) The sampling seed. + + Raises: + ValueError: if `start_tokens` is not a 1D tensor or `end_token` is not a + scalar. + """ + super(SampleEmbeddingSampler, self).__init__(embedding_fn) + self.softmax_temperature = softmax_temperature + self.seed = seed + + def sample(self, time, outputs, state): + """sample for SampleEmbeddingHelper.""" + del time, state # unused by sample_fn + # Outputs are logits, we sample instead of argmax (greedy). + if not isinstance(outputs, ops.Tensor): + raise TypeError( + "Expected outputs to be a single Tensor, got: %s" % type(outputs)) + if self.softmax_temperature is None: + logits = outputs + else: + logits = outputs / self.softmax_temperature + + return categorical_sample(logits=logits, seed=self.seed) + + +class InferenceSampler(Sampler): + """A helper to use during inference with a custom sampling function.""" + + def __init__(self, + sample_fn, + sample_shape, + sample_dtype, + end_fn, + next_inputs_fn=None): + """Initializer. + + Args: + sample_fn: A callable that takes `outputs` and emits tensor `sample_ids`. + sample_shape: Either a list of integers, or a 1-D Tensor of type `int32`, + the shape of the each sample in the batch returned by `sample_fn`. + sample_dtype: the dtype of the sample returned by `sample_fn`. + end_fn: A callable that takes `sample_ids` and emits a `bool` vector + shaped `[batch_size]` indicating whether each sample is an end token. 
+ next_inputs_fn: (Optional) A callable that takes `sample_ids` and returns + the next batch of inputs. If not provided, `sample_ids` is used as the + next batch of inputs. + """ + self.sample_fn = sample_fn + self.sample_shape = tensor_shape.TensorShape(sample_shape) + self.sample_dtype = sample_dtype + self.end_fn = end_fn + self.next_inputs_fn = next_inputs_fn + self._batch_size = None + + @property + def batch_size(self): + if self._batch_size is None: + raise ValueError("batch_size accessed before initialize was called") + return self._batch_size + + @property + def sample_ids_shape(self): + return self.sample_shape + + @property + def sample_ids_dtype(self): + return self.sample_dtype + + def initialize(self, start_inputs): + self.start_inputs = ops.convert_to_tensor(start_inputs, name="start_inputs") + self._batch_size = array_ops.shape(start_inputs)[0] + finished = array_ops.tile([False], [self._batch_size]) + return (finished, self.start_inputs) + + def sample(self, time, outputs, state): + del time, state # unused by sample + return self.sample_fn(outputs) + + def next_inputs(self, time, outputs, state, sample_ids): + del time, outputs # unused by next_inputs + if self.next_inputs_fn is None: + next_inputs = sample_ids + else: + next_inputs = self.next_inputs_fn(sample_ids) + finished = self.end_fn(sample_ids) + return (finished, next_inputs, state) + + +# The following sample functions (_call_sampler, bernoulli_sample, +# categorical_sample) mimic TensorFlow Probability distribution semantics. +def _call_sampler(sample_n_fn, sample_shape, name=None): + """Reshapes vector of samples.""" + with ops.name_scope(name, "call_sampler", values=[sample_shape]): + sample_shape = ops.convert_to_tensor( + sample_shape, dtype=dtypes.int32, name="sample_shape") + # Ensure sample_shape is a vector (vs just a scalar). 
+ pad = math_ops.cast( + math_ops.equal(array_ops.rank(sample_shape), 0), dtypes.int32) + sample_shape = array_ops.reshape( + sample_shape, + array_ops.pad( + array_ops.shape(sample_shape), + paddings=[[pad, 0]], + constant_values=1)) + samples = sample_n_fn(math_ops.reduce_prod(sample_shape)) + batch_event_shape = array_ops.shape(samples)[1:] + final_shape = array_ops.concat([sample_shape, batch_event_shape], 0) + return array_ops.reshape(samples, final_shape) + + +def bernoulli_sample(probs=None, + logits=None, + dtype=dtypes.int32, + sample_shape=(), + seed=None): + """Samples from Bernoulli distribution.""" + if probs is None: + probs = math_ops.sigmoid(logits, name="probs") + else: + probs = ops.convert_to_tensor(probs, name="probs") + batch_shape_tensor = array_ops.shape(probs) + + def _sample_n(n): + """Sample vector of Bernoullis.""" + new_shape = array_ops.concat([[n], batch_shape_tensor], 0) + uniform = random_ops.random_uniform(new_shape, seed=seed, dtype=probs.dtype) + return math_ops.cast(math_ops.less(uniform, probs), dtype) + + return _call_sampler(_sample_n, sample_shape) + + +def categorical_sample(logits, dtype=dtypes.int32, sample_shape=(), seed=None): + """Samples from categorical distribution.""" + logits = ops.convert_to_tensor(logits, name="logits") + event_size = array_ops.shape(logits)[-1] + batch_shape_tensor = array_ops.shape(logits)[:-1] + + def _sample_n(n): + """Sample vector of categoricals.""" + if logits.shape.ndims == 2: + logits_2d = logits + else: + logits_2d = array_ops.reshape(logits, [-1, event_size]) + sample_dtype = dtypes.int64 if logits.dtype.size > 4 else dtypes.int32 + draws = random_ops.multinomial( + logits_2d, n, seed=seed, output_dtype=sample_dtype) + draws = array_ops.reshape( + array_ops.transpose(draws), + array_ops.concat([[n], batch_shape_tensor], 0)) + return math_ops.cast(draws, dtype) + + return _call_sampler(_sample_n, sample_shape) + + +def _unstack_ta(inp): + return tensor_array_ops.TensorArray( + 
dtype=inp.dtype, + size=array_ops.shape(inp)[0], + element_shape=inp.get_shape()[1:]).unstack(inp) diff --git a/tensorflow/contrib/session_bundle/exporter.py b/tensorflow/contrib/session_bundle/exporter.py index 08983337fccc138d40eb959cecc5bf9e47cf6cac..f3efd292cf5acba4319c8a5545a7f70fae4b5ce1 100644 --- a/tensorflow/contrib/session_bundle/exporter.py +++ b/tensorflow/contrib/session_bundle/exporter.py @@ -304,10 +304,10 @@ class Exporter(object): def parser(path): if os.name == "nt": match = re.match( - "^" + export_dir_base.replace("\\", "/") + "/(\\d{8})$", + r"^" + export_dir_base.replace("\\", "/") + r"/(\d{8})$", path.path.replace("\\", "/")) else: - match = re.match("^" + export_dir_base + "/(\\d{8})$", path.path) + match = re.match(r"^" + export_dir_base + r"/(\d{8})$", path.path) if not match: return None return path._replace(export_version=int(match.group(1))) diff --git a/tensorflow/contrib/session_bundle/gc_test.py b/tensorflow/contrib/session_bundle/gc_test.py index 8faf3ef3d4cd7ee0096265283070e25d06782254..02725bb1cbb4ef9ace29dcc58f6d23fb241d96b2 100644 --- a/tensorflow/contrib/session_bundle/gc_test.py +++ b/tensorflow/contrib/session_bundle/gc_test.py @@ -104,7 +104,7 @@ class GcTest(test_util.TensorFlowTestCase): # create a simple parser that pulls the export_version from the directory. 
def parser(path): - match = re.match("^" + base_dir + "/(\\d+)$", path.path) + match = re.match(r"^" + base_dir + r"/(\d+)$", path.path) if not match: return None return path._replace(export_version=int(match.group(1))) diff --git a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py index 1b2b6acacca838f95cb758ae88f79263993ca69e..c63a3ca19b6a70cf7776c7fce4e0291ee94b775c 100644 --- a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py +++ b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py @@ -32,7 +32,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import functional_ops +from tensorflow.python.ops import map_fn from tensorflow.python.ops import image_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import parsing_ops @@ -396,8 +396,8 @@ class Image(ItemHandler): image_format = keys_to_tensors[self._format_key] if self._repeated: - return functional_ops.map_fn(lambda x: self._decode(x, image_format), - image_buffer, dtype=self._dtype) + return map_fn.map_fn(lambda x: self._decode(x, image_format), + image_buffer, dtype=self._dtype) else: return self._decode(image_buffer, image_format) diff --git a/tensorflow/contrib/slim/python/slim/nets/BUILD b/tensorflow/contrib/slim/python/slim/nets/BUILD index 8bbdf96384683c68648367c6433eeb89c64c22bf..e9595d1b324dbd3d570d2407a6620c5295b15548 100644 --- a/tensorflow/contrib/slim/python/slim/nets/BUILD +++ b/tensorflow/contrib/slim/python/slim/nets/BUILD @@ -115,9 +115,9 @@ py_library( py_test( name = "inception_v1_test", - size = "large", + size = "medium", srcs = ["inception_v1_test.py"], - shard_count = 3, + shard_count = 8, srcs_version = "PY2AND3", deps = [ ":inception_v1", @@ -135,9 +135,9 @@ py_test( py_test( name = 
"inception_v2_test", - size = "large", + size = "medium", srcs = ["inception_v2_test.py"], - shard_count = 3, + shard_count = 8, srcs_version = "PY2AND3", deps = [ ":inception_v2", @@ -155,9 +155,9 @@ py_test( py_test( name = "inception_v3_test", - size = "large", + size = "medium", srcs = ["inception_v3_test.py"], - shard_count = 3, + shard_count = 8, srcs_version = "PY2AND3", deps = [ ":inception_v3", @@ -233,8 +233,9 @@ py_library( py_test( name = "resnet_v1_test", - size = "large", + size = "medium", srcs = ["resnet_v1_test.py"], + shard_count = 4, srcs_version = "PY2AND3", deps = [ ":resnet_utils", @@ -268,8 +269,9 @@ py_library( py_test( name = "resnet_v2_test", - size = "large", + size = "medium", srcs = ["resnet_v2_test.py"], + shard_count = 4, srcs_version = "PY2AND3", deps = [ ":resnet_utils", diff --git a/tensorflow/contrib/solvers/python/kernel_tests/lanczos_test.py b/tensorflow/contrib/solvers/python/kernel_tests/lanczos_test.py index 8fcd7aeef6a6964902666a4f3c17e05b0c7b52ee..f31bdbd399c9de4f2f5d557b75b1ece6d64a765e 100644 --- a/tensorflow/contrib/solvers/python/kernel_tests/lanczos_test.py +++ b/tensorflow/contrib/solvers/python/kernel_tests/lanczos_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import numpy as np +from tensorflow.python import tf2 from tensorflow.contrib.solvers.python.ops import lanczos from tensorflow.contrib.solvers.python.ops import util from tensorflow.python.framework import constant_op @@ -80,7 +81,8 @@ if __name__ == "__main__": for shape in [[4, 4], [7, 4], [5, 8]]: for orthogonalize in True, False: for steps in range(1, min(shape) + 1): - for use_static_shape in True, False: + # TF2 does not support placeholders so we skip it + for use_static_shape in set([True, tf2.enabled()]): arg_string = "%s_%s_%s_%s_staticshape_%s" % ( dtype.__name__, "_".join(map(str, shape)), orthogonalize, steps, use_static_shape) diff --git a/tensorflow/contrib/solvers/python/kernel_tests/least_squares_test.py 
b/tensorflow/contrib/solvers/python/kernel_tests/least_squares_test.py index 2a9100903aae5689919a6b25fcb18ff192f250b3..841a41a2339824ab8ca15f4bdd74be697cd6fe9f 100644 --- a/tensorflow/contrib/solvers/python/kernel_tests/least_squares_test.py +++ b/tensorflow/contrib/solvers/python/kernel_tests/least_squares_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import numpy as np +from tensorflow.python import tf2 from tensorflow.contrib.solvers.python.ops import least_squares from tensorflow.contrib.solvers.python.ops import util from tensorflow.python.framework import constant_op @@ -76,7 +77,8 @@ def _get_least_squares_tests(dtype_, use_static_shape_, shape_): if __name__ == "__main__": for dtype in np.float32, np.float64: for shape in [[4, 4], [8, 5], [3, 7]]: - for use_static_shape in True, False: + # TF2 does not support placeholders under eager so we skip it + for use_static_shape in set([True, tf2.enabled()]): arg_string = "%s_%s_staticshape_%s" % (dtype.__name__, "_".join(map(str, shape)), use_static_shape) diff --git a/tensorflow/contrib/solvers/python/kernel_tests/linear_equations_test.py b/tensorflow/contrib/solvers/python/kernel_tests/linear_equations_test.py index a0e6eb87bc06fb1303a7eb86fa6760458f20a9b9..10807f7a80617e56abeb6d13ce419a49a2269aac 100644 --- a/tensorflow/contrib/solvers/python/kernel_tests/linear_equations_test.py +++ b/tensorflow/contrib/solvers/python/kernel_tests/linear_equations_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import numpy as np +from tensorflow.python import tf2 from tensorflow.contrib.solvers.python.ops import linear_equations from tensorflow.contrib.solvers.python.ops import util from tensorflow.python.framework import constant_op @@ -113,7 +114,8 @@ def _get_linear_equations_tests(dtype_, use_static_shape_, shape_): if __name__ == "__main__": for dtype in np.float32, np.float64: for size in 1, 4, 10: - for use_static_shape in True, False: + # TF2 does not support placeholders under eager 
so we skip it + for use_static_shape in set([True, tf2.enabled()]): shape = [size, size] arg_string = "%s_%s_staticshape_%s" % (dtype.__name__, size, use_static_shape) diff --git a/tensorflow/contrib/sparsemax/BUILD b/tensorflow/contrib/sparsemax/BUILD index d7ba754f701d4b433e35ad8396eae7ee6132b97f..ed4eca1a60a6f0ccf629d8aa7906c02092e25ba0 100644 --- a/tensorflow/contrib/sparsemax/BUILD +++ b/tensorflow/contrib/sparsemax/BUILD @@ -49,6 +49,9 @@ cuda_py_tests( "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], + tags = [ + "oss_serial", + ], ) cuda_py_tests( @@ -64,4 +67,7 @@ cuda_py_tests( "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], + tags = [ + "oss_serial", + ], ) diff --git a/tensorflow/contrib/tensor_forest/client/eval_metrics.py b/tensorflow/contrib/tensor_forest/client/eval_metrics.py index d8236a0a6fa6d0d0e383e454eb0146bb10b6f49d..0d87cea9fbaa8fe28b55ec996414a568d39efee3 100644 --- a/tensorflow/contrib/tensor_forest/client/eval_metrics.py +++ b/tensorflow/contrib/tensor_forest/client/eval_metrics.py @@ -50,9 +50,10 @@ def _accuracy(predictions, targets, weights=None): def _r2(probabilities, targets, weights=None): targets = math_ops.to_float(targets) y_mean = math_ops.reduce_mean(targets, 0) - squares_total = math_ops.reduce_sum(math_ops.square(targets - y_mean), 0) + squares_total = math_ops.reduce_sum( + math_ops.squared_difference(targets, y_mean), 0) squares_residuals = math_ops.reduce_sum( - math_ops.square(targets - probabilities), 0) + math_ops.squared_difference(targets, probabilities), 0) score = 1 - math_ops.reduce_sum(squares_residuals / squares_total) return metrics.mean(score, weights=weights) diff --git a/tensorflow/contrib/tensor_forest/kernels/tree_utils.h b/tensorflow/contrib/tensor_forest/kernels/tree_utils.h index e04eb60f9b27cfd8b6b4e1502594d4d310ae55cc..774da472f1543f938d1b607ebdef008f7b540211 100644 --- a/tensorflow/contrib/tensor_forest/kernels/tree_utils.h +++ 
b/tensorflow/contrib/tensor_forest/kernels/tree_utils.h @@ -18,10 +18,10 @@ #include #include "tensorflow/contrib/tensor_forest/kernels/data_spec.h" +#include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_types.h" -#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/lib/random/distribution_sampler.h" #include "tensorflow/core/lib/random/simple_philox.h" #include "tensorflow/core/lib/strings/strcat.h" diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h b/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h index d3edb43733761a906c6e5bf8b65f76e3e1ae56fc..3100a5a0e5da1103b61bd089cd433721686b9e72 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h +++ b/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h @@ -32,7 +32,7 @@ class DecisionTreeResource : public ResourceBase { // Constructor. 
explicit DecisionTreeResource(const TensorForestParams& params); - string DebugString() override { + string DebugString() const override { return strings::StrCat("DecisionTree[size=", decision_tree_->decision_tree().nodes_size(), "]"); } diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.h b/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.h index eea0be27caf0a022ba7acaacd359c75a2df4eedb..44f2b3f473b9eced06bd800b9cf0a5a0825ec3eb 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.h +++ b/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.h @@ -40,7 +40,7 @@ class FertileStatsResource : public ResourceBase { model_op_ = LeafModelOperatorFactory::CreateLeafModelOperator(params_); } - string DebugString() override { return "FertileStats"; } + string DebugString() const override { return "FertileStats"; } void ExtractFromProto(const FertileStats& stats); diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD index 784acce444a8d0c066f1b7ae6c1b5d7d65405549..91b6d2614a8963c21e35c385411dc4c9956e3146 100644 --- a/tensorflow/contrib/tensorrt/BUILD +++ b/tensorflow/contrib/tensorrt/BUILD @@ -11,567 +11,54 @@ exports_files(["LICENSE"]) load( "//tensorflow:tensorflow.bzl", - "tf_cc_test", - "tf_copts", "tf_cuda_library", - "tf_custom_op_library", "tf_custom_op_library_additional_deps", - "tf_gen_op_libs", - "tf_gen_op_wrapper_py", ) -load("//tensorflow:tensorflow.bzl", "cuda_py_test") -load("//tensorflow:tensorflow.bzl", "cuda_py_tests") -load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test") -load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") -load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc") load( "@local_config_tensorrt//:build_defs.bzl", "if_tensorrt", ) -exports_files(glob([ - "test/testdata/*", -])) - -tf_cuda_cc_test( - name = "tensorrt_test_cc", - size = "small", - srcs = ["tensorrt_test.cc"], - tags = [ - "no_windows", - "nomac", - ], - 
deps = [ - "//tensorflow/core:gpu_init", - "//tensorflow/core:lib", - "//tensorflow/core:stream_executor", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - ] + if_tensorrt([ - "@local_config_cuda//cuda:cuda_headers", - "@local_config_tensorrt//:nv_infer", - ]), -) - -tf_custom_op_library( - name = "python/ops/_trt_engine_op.so", - srcs = [ - "ops/trt_engine_op.cc", - ], - deps = [ - ":trt_shape_function", - "//tensorflow/core:lib_proto_parsing", - ] + if_tensorrt([ - "@local_config_tensorrt//:nv_infer", - ]), -) - tf_cuda_library( name = "trt_shape_function", srcs = ["shape_fn/trt_shfn.cc"], hdrs = ["shape_fn/trt_shfn.h"], visibility = ["//visibility:public"], deps = [ - ":trt_logging", - ":trt_plugins", + "//tensorflow/compiler/tf2tensorrt:trt_logging", + "//tensorflow/compiler/tf2tensorrt:trt_plugins", ] + if_tensorrt([ - "@local_config_tensorrt//:nv_infer", + "@local_config_tensorrt//:tensorrt", ]) + tf_custom_op_library_additional_deps(), ) -cc_library( - name = "trt_engine_op_kernel", - srcs = [ - "kernels/trt_engine_op.cc", - ], - hdrs = [ - "kernels/trt_engine_op.h", - ], - copts = tf_copts(), - visibility = ["//visibility:public"], - deps = [ - ":test_utils", - ":trt_allocator", - ":trt_conversion", - ":trt_logging", - ":trt_plugins", - ":trt_resources", - ":utils", - "//tensorflow/core:gpu_headers_lib", - "//tensorflow/core:lib_proto_parsing", - "//tensorflow/core:stream_executor_headers_lib", - "//tensorflow/core/grappler/costs:graph_properties", - ] + if_tensorrt([ - "@local_config_tensorrt//:nv_infer", - ]) + tf_custom_op_library_additional_deps(), - # TODO(laigd): fix this by merging header file in cc file. 
- alwayslink = 1, # buildozer: disable=alwayslink-with-hdrs -) - -tf_gen_op_libs( - op_lib_names = [ - "trt_engine_op", - ], -) - -tf_cuda_library( - name = "trt_logging", - srcs = ["log/trt_logger.cc"], - hdrs = ["log/trt_logger.h"], - visibility = ["//visibility:public"], - deps = [ - "//tensorflow/core:lib_proto_parsing", - ] + if_tensorrt([ - "@local_config_tensorrt//:nv_infer", - ]), -) - -tf_gen_op_wrapper_py( - name = "trt_engine_op", - deps = [ - ":trt_engine_op_op_lib", - ":trt_logging", - ":trt_shape_function", - ], -) - -tf_custom_op_py_library( - name = "trt_engine_op_loader", - srcs = ["python/ops/trt_engine_op.py"], - dso = [ - ":python/ops/_trt_engine_op.so", - ] + if_tensorrt([ - "@local_config_tensorrt//:nv_infer", - ]), - kernels = [ - ":trt_engine_op_kernel", - ":trt_engine_op_op_lib", - ":trt_shape_function", - ], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/contrib/util:util_py", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:resources", - ], -) - py_library( name = "init_py", srcs = [ "__init__.py", "python/__init__.py", + "python/trt_convert.py", ], srcs_version = "PY2AND3", deps = [ - ":tf_trt_integration_test_base", - ":trt_convert_py", - ":trt_ops_py", - "//tensorflow/python:errors", - ], -) - -py_library( - name = "trt_ops_py", - srcs_version = "PY2AND3", - deps = [ - ":trt_engine_op", - ":trt_engine_op_loader", - ], -) - -py_library( - name = "trt_convert_py", - srcs = ["python/trt_convert.py"], - srcs_version = "PY2AND3", - deps = [ - ":wrap_conversion", - "//tensorflow/python:graph_util", - "//tensorflow/python:session", - "//tensorflow/python:tf_optimizer", - "//tensorflow/python/saved_model:builder", - "//tensorflow/python/saved_model:loader", - "//tensorflow/python/saved_model:tag_constants", - ], -) - -# TODO(aaroey): this wrapper has been causing troubles of double linking, so -# either get rid of it, or split to make it contain minimum dependencies. 
-tf_py_wrap_cc( - name = "wrap_conversion", - srcs = ["trt_conversion.i"], - copts = tf_copts(), - swig_includes = [ - "//tensorflow/python:platform/base.i", - ], - deps = [ - ":test_utils", - ":trt_conversion", - ":trt_engine_op_kernel", - "//third_party/python_runtime:headers", - ], -) - -tf_cuda_library( - name = "trt_resources", - srcs = [ - "resources/trt_int8_calibrator.cc", - "resources/trt_resource_manager.cc", - ], - hdrs = [ - "resources/trt_int8_calibrator.h", - "resources/trt_resource_manager.h", - "resources/trt_resources.h", + "//tensorflow/python/compiler/tensorrt:init_py", ], - deps = [ - ":trt_allocator", - ":trt_logging", - ":utils", - "//tensorflow/core:framework_headers_lib", - "//tensorflow/core:framework_lite", - "//tensorflow/core:lib_proto_parsing", - ] + if_tensorrt([ - "@local_config_tensorrt//:nv_infer", - ]), ) -tf_cuda_library( - name = "trt_allocator", - srcs = ["resources/trt_allocator.cc"], - hdrs = ["resources/trt_allocator.h"], - deps = [ - "//tensorflow/core:framework_headers_lib", - "//tensorflow/core:framework_lite", - "//tensorflow/core:lib_proto_parsing", - ] + if_tensorrt([ - "@local_config_tensorrt//:nv_infer", - ]), -) +# The following rules forward the libraries that were moved in order to not +# break other internal targets. 
-tf_cc_test( - name = "trt_allocator_test", - size = "small", - srcs = ["resources/trt_allocator_test.cc"], - tags = [ - "no_windows", - "nomac", - ], - deps = [ - ":trt_allocator", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - ], -) - -# Library for the node-level conversion portion of TensorRT operation creation -tf_cuda_library( +alias( name = "trt_conversion", - srcs = [ - "convert/convert_graph.cc", - "convert/convert_nodes.cc", - "convert/trt_optimization_pass.cc", - ], - hdrs = [ - "convert/convert_graph.h", - "convert/convert_nodes.h", - "convert/trt_optimization_pass.h", - ], - deps = [ - ":segment", - ":test_utils", - ":trt_allocator", - ":trt_plugins", - ":trt_logging", - ":trt_resources", - ":utils", - "//tensorflow/core/grappler/clusters:cluster", - "//tensorflow/core/grappler/optimizers:custom_graph_optimizer", - "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", - "//tensorflow/core/grappler:grappler_item", - "//tensorflow/core/grappler:utils", - "//tensorflow/core:framework", - "//tensorflow/core:framework_lite", - "//tensorflow/core:gpu_runtime", - "//tensorflow/core:graph", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core/grappler:devices", - "//tensorflow/core/grappler/clusters:virtual_cluster", - "//tensorflow/core/grappler/costs:graph_properties", - "//tensorflow/core/grappler/optimizers:meta_optimizer", - ] + if_tensorrt([ - "@local_config_tensorrt//:nv_infer", - ]) + tf_custom_op_library_additional_deps(), -) - -tf_cuda_cc_test( - name = "convert_graph_test", - size = "medium", - srcs = ["convert/convert_graph_test.cc"], - tags = [ - "no_cuda_on_cpu_tap", - "no_windows", - "nomac", - ], - deps = [ - ":trt_conversion", - "@com_google_googletest//:gtest", - "//tensorflow/cc:cc_ops", - "//tensorflow/cc:scope", - "//tensorflow/core/grappler:grappler_item", - "//tensorflow/core/grappler/clusters:cluster", - "//tensorflow/core:core_cpu", - 
"//tensorflow/core:core_cpu_base", - "//tensorflow/core:direct_session", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core:testlib", - ] + if_tensorrt([ - "@local_config_tensorrt//:nv_infer", - ]), -) - -tf_cuda_cc_test( - name = "convert_nodes_test", - size = "medium", - srcs = ["convert/convert_nodes_test.cc"], - tags = [ - "no_cuda_on_cpu_tap", - "no_windows", - "nomac", - ], - deps = [ - ":trt_logging", - ":trt_conversion", - ":trt_plugins", - "@com_google_googletest//:gtest", - "//tensorflow/cc:cc_ops", - "//tensorflow/cc:ops", - "//tensorflow/cc:scope", - "//tensorflow/core/grappler/costs:graph_properties", - "//tensorflow/core:core_cpu", - "//tensorflow/core:core_cpu_base", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core:testlib", - ] + if_tensorrt([ - "@local_config_cuda//cuda:cuda_headers", - "@local_config_tensorrt//:nv_infer", - ]), -) - -# Library for the segmenting portion of TensorRT operation creation -cc_library( - name = "segment", - srcs = ["segment/segment.cc"], - hdrs = [ - "segment/segment.h", - "segment/union_find.h", - ], - deps = [ - "//tensorflow/core:graph", - "//tensorflow/core:lib_proto_parsing", - "//tensorflow/core:protos_all_cc", - "@protobuf_archive//:protobuf_headers", - ], -) - -tf_cc_test( - name = "segment_test", - size = "small", - srcs = ["segment/segment_test.cc"], - tags = [ - "no_windows", - "nomac", - ], - deps = [ - ":segment", - "//tensorflow/cc:cc_ops", - "//tensorflow/cc:scope", - "//tensorflow/core:core_cpu", - "//tensorflow/core:lib", - "//tensorflow/core:ops", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core:testlib", - ], -) - -# Library for the plugin factory -tf_cuda_library( - name = 
"trt_plugins", - srcs = [ - "plugin/trt_plugin.cc", - "plugin/trt_plugin_factory.cc", - "plugin/trt_plugin_utils.cc", - ], - hdrs = [ - "plugin/trt_plugin.h", - "plugin/trt_plugin_factory.h", - "plugin/trt_plugin_utils.h", - ], - deps = [ - "//tensorflow/core:framework_lite", - "//tensorflow/core:lib_proto_parsing", - ] + if_tensorrt([ - "@local_config_tensorrt//:nv_infer", - ]), -) - -tf_cuda_cc_test( - name = "trt_plugin_factory_test", - size = "small", - srcs = ["plugin/trt_plugin_factory_test.cc"], - tags = [ - "no_cuda_on_cpu_tap", - "no_windows", - "nomac", - ], - deps = [ - ":trt_plugins", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - ] + if_tensorrt([ - "@local_config_cuda//cuda:cuda_headers", - "@local_config_tensorrt//:nv_infer", - ]), -) - -py_library( - name = "tf_trt_integration_test_base", - srcs = ["test/tf_trt_integration_test_base.py"], - deps = [ - ":trt_convert_py", - ":trt_ops_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_test_lib", - ], -) - -cuda_py_test( - name = "trt_convert_test", - srcs = ["python/trt_convert_test.py"], - additional_deps = [ - ":trt_convert_py", - ":trt_ops_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:graph_util", - "//tensorflow/python/saved_model:builder", - "//tensorflow/python/saved_model:loader", - "//tensorflow/python/saved_model:signature_constants", - "//tensorflow/python/saved_model:signature_def_utils", - "//tensorflow/python/saved_model:tag_constants", - "//tensorflow/python/saved_model:utils", - "//tensorflow/python/tools:freeze_graph_lib", - "//tensorflow/python/tools:saved_model_utils", - ], - tags = [ - "no_cuda_on_cpu_tap", - "no_windows", - "nomac", - ], -) - -cuda_py_tests( - name = "tf_trt_integration_test", - srcs = [ - "test/base_test.py", - "test/batch_matmul_test.py", - "test/biasadd_matmul_test.py", - "test/binary_tensor_weight_broadcast_test.py", - 
"test/concatenation_test.py", - "test/const_broadcast_test.py", - "test/manual_test.py", - "test/memory_alignment_test.py", - "test/multi_connection_neighbor_engine_test.py", - "test/neighboring_engine_test.py", - "test/quantization_test.py", - "test/rank_two_test.py", - "test/reshape_transpose_test.py", - "test/vgg_block_nchw_test.py", - "test/vgg_block_test.py", - ], - additional_deps = [ - ":tf_trt_integration_test_base", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_test_lib", - ], - tags = [ - "no_cuda_on_cpu_tap", - "no_windows", - "nomac", - ], -) - -cuda_py_tests( - name = "tf_trt_integration_test_no_oss", - srcs = [ - "test/unary_test.py", - ], - additional_deps = [ - ":tf_trt_integration_test_base", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_test_lib", - ], - tags = [ - "no_cuda_on_cpu_tap", - "no_oss", # TODO(b/117274186): re-enable in OSS after crash fixed - "no_pip", # TODO(b/117274186): re-enable in OSS after crash fixed - "no_windows", - "nomac", - ], + actual = "//tensorflow/compiler/tf2tensorrt:trt_conversion", ) -cuda_py_test( - name = "quantization_mnist_test", - srcs = ["test/quantization_mnist_test.py"], - additional_deps = [ - ":tf_trt_integration_test_base", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python/keras:keras", - "//tensorflow/python/estimator:estimator", - ], - data = [ - "test/testdata/checkpoint", - "test/testdata/model.ckpt-46900.data-00000-of-00001", - "test/testdata/model.ckpt-46900.index", - ], - tags = [ - "no_cuda_on_cpu_tap", - "no_pip", - "no_tap", # It is not able to download the mnist data. 
- "no_windows", - "nomac", - ], +alias( + name = "trt_op_kernels", + actual = "//tensorflow/compiler/tf2tensorrt:trt_op_kernels", ) -cc_library( - name = "utils", - srcs = ["convert/utils.cc"], - hdrs = ["convert/utils.h"], - copts = tf_copts(), - deps = [ - "//tensorflow/core:lib", - ], -) - -cc_library( - name = "test_utils", - srcs = ["test/utils.cc"], - hdrs = ["test/utils.h"], - deps = [ - "//tensorflow/core:lib", - "@com_googlesource_code_re2//:re2", - ], +alias( + name = "trt_engine_op_op_lib", + actual = "//tensorflow/compiler/tf2tensorrt:trt_engine_op_op_lib", ) diff --git a/tensorflow/contrib/tensorrt/README.md b/tensorflow/contrib/tensorrt/README.md deleted file mode 100644 index caf8b6db0dc0a220d593f9c0afc9464ca51a1e05..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/tensorrt/README.md +++ /dev/null @@ -1,29 +0,0 @@ -# Using TensorRT in TensorFlow - -This module provides necessary bindings and introduces TRT_engine_op operator -that wraps a subgraph in TensorRT. This is still a work in progress but should -be useable with most common graphs. - -## Compilation - -In order to compile the module, you need to have a local TensorRT installation -(libnvinfer.so and respective include files). During the configuration step, -TensorRT should be enabled and installation path should be set. If installed -through package managers (deb,rpm), configure script should find the necessary -components from the system automatically. If installed from tar packages, user -has to set path to location where the library is installed during configuration. - -```shell -bazel build --config=cuda --config=opt //tensorflow/tools/pip_package:build_pip_package -bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/ -``` - -After the installation of tensorflow package, TensorRT transformation will be -available. 
An example use can be found in test/test_tftrt.py script - -## Installing TensorRT 3.0.4 - -In order to make use of TensorRT integration, you will need a local installation -of TensorRT 3.0.4 from the [NVIDIA Developer website](https://developer.nvidia.com/tensorrt). -Installation instructions for compatibility with TensorFlow are provided on the -[TensorFlow GPU support](https://www.tensorflow.org/install/gpu) guide. diff --git a/tensorflow/contrib/tensorrt/__init__.py b/tensorflow/contrib/tensorrt/__init__.py index 140ad4828208ae4844a49bf664955b50cd9e51cd..fd551d70b4385b14b84b7b98a6d16b0c03733d38 100644 --- a/tensorflow/contrib/tensorrt/__init__.py +++ b/tensorflow/contrib/tensorrt/__init__.py @@ -18,18 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.framework import errors - -# pylint: disable=unused-import,wildcard-import,g-import-not-at-top -try: - from tensorflow.contrib.tensorrt.python import * -except errors.NotFoundError as e: - no_trt_message = ( - '**** Failed to initialize TensorRT. This is either because the TensorRT' - ' installation path is not in LD_LIBRARY_PATH, or because you do not have' - ' it installed. 
If not installed, please go to' - ' https://developer.nvidia.com/tensorrt to download and install' - ' TensorRT ****') - print(no_trt_message) - raise e -# pylint: enable=unused-import,wildcard-import,g-import-not-at-top +# pylint: disable=unused-import,wildcard-import +from tensorflow.contrib.tensorrt.python import * +# pylint: enable=unused-import,wildcard-import diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD b/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD index 69058c5826822c519a69d50860c06b8ab3ec6578..0a2cf105baf5efb62d0c535c1f2d081973ec0ea3 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD @@ -45,10 +45,10 @@ tf_custom_op_library( "inc_op_kernel.cu.cc", ], deps = [ - "//tensorflow/contrib/tensorrt:trt_plugins", + "//tensorflow/compiler/tf2tensorrt:trt_plugins", "//tensorflow/core:framework_lite", ] + if_tensorrt([ - "@local_config_tensorrt//:nv_infer", + "@local_config_tensorrt//:tensorrt", ]), ) @@ -64,10 +64,10 @@ tf_kernel_library( "inc_op_kernel.cu.cc", ], deps = [ - "//tensorflow/contrib/tensorrt:trt_plugins", + "//tensorflow/compiler/tf2tensorrt:trt_plugins", "//tensorflow/core:stream_executor_headers_lib", ] + if_tensorrt([ - "@local_config_tensorrt//:nv_infer", + "@local_config_tensorrt//:tensorrt", ]) + tf_custom_op_library_additional_deps(), ) diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc index 8d4c893af56689185da72398919e2241d451594b..7c9075142a02546ddd580e861ac87cb86badd739 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc @@ -15,8 +15,8 @@ limitations under the License. 
#include "tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h" +#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h" #include "tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h" -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" #if GOOGLE_CUDA #if GOOGLE_TENSORRT diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h index 189e9c939b9ffd4450f7ba95fe1abdbbc049b430..fb048d7b19da0f010ed918b147013b20d37ed0dd 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h @@ -19,7 +19,7 @@ limitations under the License. #include #include -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" +#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.h" #if GOOGLE_CUDA #if GOOGLE_TENSORRT diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h deleted file mode 100644 index b545f497f32d5a1a6960b748467ca189b7debf6c..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h +++ /dev/null @@ -1,145 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef TENSORFLOW_CONTRIB_TENSORRT_KERNELS_TRT_ENGINE_OP_H_ -#define TENSORFLOW_CONTRIB_TENSORRT_KERNELS_TRT_ENGINE_OP_H_ - -#include -#include - -#include "tensorflow/contrib/tensorrt/convert/utils.h" -#include "tensorflow/contrib/tensorrt/log/trt_logger.h" -#include "tensorflow/contrib/tensorrt/resources/trt_allocator.h" -#include "tensorflow/core/framework/function.h" -#include "tensorflow/core/framework/graph.pb.h" -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/platform/mutex.h" - -#if GOOGLE_CUDA -#if GOOGLE_TENSORRT -#include "cuda/include/cuda_runtime_api.h" -#include "tensorrt/include/NvInfer.h" - -namespace tensorflow { -namespace tensorrt { -struct TRTInt8Calibrator; -class TRTCalibrationResource; -class AsyncHelper; -// TODO(Sami): Remove this file? - -// This OP can construct TRTEngine on the fly and if construction of engine -// fails, executes equivalent subgraph as a TensorFlow function. -class TRTEngineOp : public AsyncOpKernel { - public: - explicit TRTEngineOp(OpKernelConstruction* context); - - void ComputeAsync(OpKernelContext* context, - AsyncOpKernel::DoneCallback done) override; - ~TRTEngineOp(); - - private: - // Execute calibration - void ExecuteCalibration(OpKernelContext* ctx, AsyncHelper* helper); - - // Construct a function handle for executing native funcdef graph - Status ConstructFunctionHandle(OpKernelContext* ctx); - - // Execute replaced native segment as function Op. - void ExecuteNativeSegment(OpKernelContext* ctx, AsyncHelper* helper); - - // Execute the tensorrt engine. Returns whether we need to retry by running - // the native segment. 
- bool ExecuteTrtEngine(OpKernelContext* ctx, const int num_batch, - nvinfer1::ICudaEngine* trt_engine_ptr, - nvinfer1::IExecutionContext* trt_execution_context_ptr); - - // Allocate necessary resources for calibration - Status AllocateCalibrationResources(OpKernelContext* ctx, - TRTCalibrationResource** cr); - - // TODO(samikama): context should go to a resource manager! - typedef std::pair, - TrtUniquePtrType> - EngineCtxPair; - EngineCtxPair& GetEngine(int batch_size, OpKernelContext* ctx); - - // Return engine batch closest to input batch. - int GetEngineBatch(OpKernelContext* ctx); - - nvinfer1::IGpuAllocator* GetAllocator(OpKernelContext* ctx); - - // map to keep engines and their execution context for given batch size. - std::unordered_map engine_map_; - std::vector input_nodes_; - std::vector output_nodes_; - - // keep device allocator for TRT. - std::unique_ptr allocator_; - - // serialized protobuf segment or trt engine depending on static_engine_ flag. - string serialized_segment_; - - // Name of the function for TF native execution of the segment. - string funcdef_name_; - - // GraphDef representation of the segment. - GraphDef segment_graph_; - - // Lookup table for temporary staging areas of input tensors for calibration. - std::unordered_map> device_buffers_; - - // Temporary staging areas for calibration inputs. - std::vector dev_tensors_; - - // Engine Precision mode. - int precision_mode_; - - // Whether engine is constructed during the conversion or needs to be - // constructed from protobuf segment. - bool static_engine_; - - // Whether to calibrate INT8 engine. - bool calibration_mode_; - - // Whether non-batch ranks of the inputs are assumed to be fixed or not for - // engine construction. 
- bool fixed_input_size_; - - // Batches of the cached engines - std::vector cached_engine_batches_; - - // Maximum number of cached engines - int max_cached_engines_; - - int64 workspace_size_; - mutex engine_mutex_; - FunctionLibraryRuntime::Handle native_func_; - - // The finalized calibrator for inference. - std::unique_ptr calibrator_; - - // If true, create calibration graph for INT8 mode. Otherwise, we are using - // user-provided quantization ranges. - bool use_calibration_; -}; - -} // namespace tensorrt -} // namespace tensorflow - -#endif // GOOGLE_TENSORRT -#endif // GOOGLE_CUDA - -#endif // TENSORFLOW_CONTRIB_TENSORRT_KERNELS_TRT_ENGINE_OP_H_ diff --git a/tensorflow/contrib/tensorrt/python/__init__.py b/tensorflow/contrib/tensorrt/python/__init__.py index 7cdfe2b1a612be2eec473d806d0eb44b611ca68a..0cae401023e7d3e3780b9dd2e2a92c9fd0e92db8 100644 --- a/tensorflow/contrib/tensorrt/python/__init__.py +++ b/tensorflow/contrib/tensorrt/python/__init__.py @@ -19,12 +19,6 @@ from __future__ import division from __future__ import print_function # pylint: disable=unused-import,line-too-long -from tensorflow.contrib.tensorrt.python.ops import trt_engine_op -from tensorflow.contrib.tensorrt.python.trt_convert import add_test_value from tensorflow.contrib.tensorrt.python.trt_convert import calib_graph_to_infer_graph -from tensorflow.contrib.tensorrt.python.trt_convert import clear_test_values from tensorflow.contrib.tensorrt.python.trt_convert import create_inference_graph -from tensorflow.contrib.tensorrt.python.trt_convert import enable_test_value -from tensorflow.contrib.tensorrt.python.trt_convert import get_test_value -from tensorflow.contrib.tensorrt.python.trt_convert import is_tensorrt_enabled # pylint: enable=unused-import,line-too-long diff --git a/tensorflow/contrib/tensorrt/python/trt_convert.py b/tensorflow/contrib/tensorrt/python/trt_convert.py index 203b2697babe32b45523109708cbf062dceee33b..42bf1eda5179d0f72f4fd8432e6b5684f8e46917 100644 --- 
a/tensorflow/contrib/tensorrt/python/trt_convert.py +++ b/tensorflow/contrib/tensorrt/python/trt_convert.py @@ -18,404 +18,40 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import six as _six -# pylint: disable=unused-import,line-too-long -from tensorflow.contrib.tensorrt.wrap_conversion import add_test_value -from tensorflow.contrib.tensorrt.wrap_conversion import calib_convert -from tensorflow.contrib.tensorrt.wrap_conversion import clear_test_values -from tensorflow.contrib.tensorrt.wrap_conversion import enable_test_value -from tensorflow.contrib.tensorrt.wrap_conversion import get_linked_tensorrt_version -from tensorflow.contrib.tensorrt.wrap_conversion import get_loaded_tensorrt_version -from tensorflow.contrib.tensorrt.wrap_conversion import get_test_value -from tensorflow.contrib.tensorrt.wrap_conversion import is_tensorrt_enabled -# pylint: enable=unused-import,line-too-long -from tensorflow.core.framework import graph_pb2 -from tensorflow.core.protobuf import config_pb2 -from tensorflow.core.protobuf import meta_graph_pb2 -from tensorflow.core.protobuf import rewriter_config_pb2 -from tensorflow.python.client import session -from tensorflow.python.framework import errors_impl as _impl -from tensorflow.python.framework import graph_util -from tensorflow.python.framework import importer -from tensorflow.python.framework import ops -from tensorflow.python.grappler import tf_optimizer -from tensorflow.python.platform import tf_logging -from tensorflow.python.saved_model import builder -from tensorflow.python.saved_model import loader_impl -from tensorflow.python.saved_model import tag_constants -from tensorflow.python.training import saver - -if _six.PY2: - _to_bytes = lambda s: s - _to_string = lambda s: s -else: - _to_bytes = lambda s: s.encode("utf-8", errors="surrogateescape") - _to_string = lambda s: s.decode("utf-8") - - -class TrtPrecisionMode(object): - FP32 = "FP32" - FP16 = "FP16" - 
INT8 = "INT8" - - @staticmethod - def supported_precision_modes(): - return [TrtPrecisionMode.FP32, TrtPrecisionMode.FP16, TrtPrecisionMode.INT8] - - -def get_tensorrt_rewriter_config(rewriter_config=None, - max_batch_size=1, - max_workspace_size_bytes=2 << 20, - precision_mode=TrtPrecisionMode.FP32, - minimum_segment_size=3, - is_dynamic_op=False, - maximum_cached_engines=1, - cached_engine_batch_sizes=None, - use_calibration=True): - """Returns a RewriterConfig proto for TRT transformation. - - Args: - rewriter_config: a template RewriterConfig proto used to create a - TRT-enabled RewriterConfig. If None, it will use a default one. - max_batch_size: max size for the input batch - max_workspace_size_bytes: the maximum GPU temporary memory which the TRT - engine can use at execution time. This corresponds to the 'workspaceSize' - parameter of nvinfer1::IBuilder::setMaxWorkspaceSize(). - precision_mode: one of TrtPrecisionMode.supported_precision_modes(). - minimum_segment_size: the minimum number of nodes required for a subgraph to - be replaced by TRTEngineOp. - is_dynamic_op: whether to generate dynamic TRT ops which will build the TRT - network and engine at run time. - maximum_cached_engines: max number of cached TRT engines in dynamic TRT ops. - If the number of cached engines is already at max but none of them can - serve the input, the TRTEngineOp will fall back to run the TF function - based on which the TRTEngineOp is created. - cached_engine_batch_sizes: a list of batch sizes used to create cached - engines, only used when is_dynamic_op is True. The length of the list - should be smaller than maximum_cached_engines, and the dynamic TRT op will - use this list to determine the batch sizes of the cached engines, instead - of making the decision on the fly. This is useful when we know the most - common batch size(s) the application is going to generate. - use_calibration: this argument is ignored if precision_mode is not INT8. 
If - set to True, a calibration graph will be created to calibrate the missing - ranges. The calibration graph must be converted to an inference graph - using calib_graph_to_infer_graph() after running calibration. if set to - False, quantization nodes will be expected for every tensor in the graph - (exlcuding those which will be fused). If a range is missing, an error - will occur. Please note that accuracy may be negatively affected if there - is a mismatch between which tensors TRT quantizes and which tensors were - trained with fake quantization. - - Returns: - A RewriterConfig proto which sets a TensorRTOptimizer to run Grappler. - - Raises: - TypeError: if any of the parameters are of unexpected type. - ValueError: if any of the parameters are of unexpected value. - """ - if rewriter_config is not None and not isinstance( - rewriter_config, rewriter_config_pb2.RewriterConfig): - raise TypeError("rewriter_config should be a RewriterConfig proto.") - - rewriter_config_with_trt = rewriter_config_pb2.RewriterConfig() - if rewriter_config is None: - # Layout optimizer may add Const nodes followed by Reshape nodes, thus we - # need to run constant folding again. - rewriter_config_with_trt.optimizers.extend( - ["constfold", "layout", "constfold"]) - rewriter_config_with_trt.meta_optimizer_iterations = ( - rewriter_config_pb2.RewriterConfig.ONE) - else: - rewriter_config_with_trt.CopyFrom(rewriter_config) - - if precision_mode.upper() not in TrtPrecisionMode.supported_precision_modes(): - raise ValueError(("precision mode '{}' is not supported." 
- "It should be one of {}").format( - precision_mode, - TrtPrecisionMode.supported_precision_modes)) - - optimizer = rewriter_config_with_trt.custom_optimizers.add() - optimizer.name = "TensorRTOptimizer" - optimizer.parameter_map["minimum_segment_size"].i = minimum_segment_size - optimizer.parameter_map["max_batch_size"].i = max_batch_size - optimizer.parameter_map["is_dynamic_op"].b = is_dynamic_op - optimizer.parameter_map[ - "max_workspace_size_bytes"].i = max_workspace_size_bytes - optimizer.parameter_map["precision_mode"].s = _to_bytes(precision_mode) - optimizer.parameter_map["maximum_cached_engines"].i = maximum_cached_engines - if cached_engine_batch_sizes: - if not isinstance(cached_engine_batch_sizes, list): - raise TypeError("cached_engine_batch_sizes should be a list.") - if len(cached_engine_batch_sizes) > maximum_cached_engines: - raise ValueError("cached_engine_batch_sizes should not contain more than " - "maximum_cached_engines items.") - optimizer.parameter_map["cached_engine_batches"].list.i.extend( - cached_engine_batch_sizes) - optimizer.parameter_map["use_calibration"].b = use_calibration - return rewriter_config_with_trt +from tensorflow.python.compiler.tensorrt import trt_convert def create_inference_graph(input_graph_def, outputs, max_batch_size=1, max_workspace_size_bytes=2 << 20, - precision_mode=TrtPrecisionMode.FP32, + precision_mode=trt_convert.TrtPrecisionMode.FP32, minimum_segment_size=3, is_dynamic_op=False, maximum_cached_engines=1, - cached_engine_batch_sizes=None, + cached_engine_batches=None, use_calibration=True, input_saved_model_dir=None, input_saved_model_tags=None, output_saved_model_dir=None, session_config=None): - """Python wrapper for the TRT transformation. - - Args: - input_graph_def: a GraphDef object containing a model to be transformed. If - set to None, the graph will be read from the SavedModel loaded from - input_saved_model_dir. - outputs: list of tensors or node names for the model outputs. 
Only used when - input_graph_def is not None. - max_batch_size: max size for the input batch. - max_workspace_size_bytes: the maximum GPU temporary memory which the TRT - engine can use at execution time. This corresponds to the 'workspaceSize' - parameter of nvinfer1::IBuilder::setMaxWorkspaceSize(). - precision_mode: one of TrtPrecisionMode.supported_precision_modes(). - minimum_segment_size: the minimum number of nodes required for a subgraph to - be replaced by TRTEngineOp. - is_dynamic_op: whether to generate dynamic TRT ops which will build the TRT - network and engine at run time. - maximum_cached_engines: max number of cached TRT engines in dynamic TRT ops. - If the number of cached engines is already at max but none of them can - serve the input, the TRTEngineOp will fall back to run the TF function - based on which the TRTEngineOp is created. - cached_engine_batch_sizes: a list of batch sizes used to create cached - engines, only used when is_dynamic_op is True. The length of the list - should be smaller than maximum_cached_engines, and the dynamic TRT op will - use this list to determine the batch sizes of the cached engines, instead - of making the decision on the fly. This is useful when we know the most - common batch size(s) the application is going to generate. - use_calibration: this argument is ignored if precision_mode is not INT8. If - set to True, a calibration graph will be created to calibrate the missing - ranges. The calibration graph must be converted to an inference graph - using calib_graph_to_infer_graph() after running calibration. if set to - False, quantization nodes will be expected for every tensor in the graph - (exlcuding those which will be fused). If a range is missing, an error - will occur. Please note that accuracy may be negatively affected if there - is a mismatch between which tensors TRT quantizes and which tensors were - trained with fake quantization. 
- input_saved_model_dir: the directory to load the SavedModel which contains - the input graph to transforms. Used only when input_graph_def is None. - input_saved_model_tags: list of tags to load the SavedModel. - output_saved_model_dir: if not None, construct a SavedModel using the - returned GraphDef and save it to the specified directory. This option only - works when the input graph is loaded from a SavedModel, i.e. when - input_saved_model_dir is specified and input_graph_def is None. - session_config: the ConfigProto used to create a Session. It's also used as - a template to create a TRT-enabled ConfigProto for conversion. If not - specified, a default ConfigProto will be used. - - Returns: - A GraphDef transformed from input_graph_def (or the SavedModel graph def - loaded from input_saved_model_dir, if input_graph_def is not present), where - all TRT compatible subgraphs are replaced with TRTEngineOps, and a TF - function is added for each of the subgraphs. - - If is_dynamic_op is True, each TRTEngineOp will contain a serialized - subgraph GraphDef, which will be converted to a TRT engine at execution time - and the TRT engine will be cached for future usage. A new TRT engine will be - created each time when none of the cached engines match the input shapes. If - it fails to execute the TRT engine or the number of cached engines reaches - maximum_cached_engines, the op will fall back to call the corresponding TF - function. - - If is_dynamic_op is False, each TRTEngineOp will contain a serialized TRT - engine created from the corresponding subgraph. No more engines will be - created on the fly, and the op will fall back to call the corresponding TF - function when it fails to execute the engine. - - Raises: - ValueError: if the combination of the parameters is invalid. - RuntimeError: if the TensorRT library version is incompatible. 
- """ - compiled_version = get_linked_tensorrt_version() - loaded_version = get_loaded_tensorrt_version() - version_mismatch = False - if loaded_version[0] < compiled_version[0]: - tf_logging.error( - "TensorRT version mismatch. Tensorflow was compiled against " + - "TensorRT %s but library loaded from environment is TensorRT %s" % - (".".join([str(x) for x in compiled_version]), - ".".join([str(x) for x in loaded_version])) + - ". Please make sure that correct version of TensorRT " + - "is available in the system and added to ldconfig or LD_LIBRARY_PATH") - raise RuntimeError("Incompatible TensorRT library version") - for i in zip(loaded_version, compiled_version): - if i[0] != i[1]: - tf_logging.warn("TensorRT mismatch. Compiled against version " + - "%s, but loaded %s. Things may not work" % - (".".join([str(x) for x in compiled_version]), - ".".join([str(x) for x in loaded_version]))) - version_mismatch = True - break - if not version_mismatch: - tf_logging.info("Running against TensorRT version %s" % ".".join( - [str(x) for x in loaded_version])) - - if session_config is None: - session_config = config_pb2.ConfigProto() - - if input_saved_model_tags is None: - input_saved_model_tags = [tag_constants.SERVING] - saved_model_loader = None - grappler_meta_graph_def = None - - if input_graph_def is None: - # Read from SavedModel and freeze the graph if necessary. - if input_saved_model_dir is None: - raise ValueError("input_graph_def and input_saved_model_dir cannot be " - "both None") - with ops.Graph().as_default(): - with session.Session(config=session_config) as sess: - saved_model_loader = loader_impl.SavedModelLoader(input_saved_model_dir) - input_meta_graph_def = saved_model_loader.load(sess, - input_saved_model_tags) - output_node_names = set() - - def _gather_names(tensor_info): - """Get the node names from a TensorInfo.""" - return set( - [tensor_info[key].name.split(":")[0] for key in tensor_info]) - - # Get input and outputs from all SignatureDef. 
- for key in input_meta_graph_def.signature_def: - signature_def = input_meta_graph_def.signature_def[key] - output_node_names.update(_gather_names(signature_def.inputs)) - output_node_names.update(_gather_names(signature_def.outputs)) - - # Freeze the variables in the SavedModel graph and copy the frozen - # graph over. - frozen_graph_def = graph_util.convert_variables_to_constants( - sess, sess.graph.as_graph_def(add_shapes=True), - list(output_node_names)) - grappler_meta_graph_def = meta_graph_pb2.MetaGraphDef() - grappler_meta_graph_def.graph_def.CopyFrom(frozen_graph_def) - - # Copy the collections that are not variables. - for key in input_meta_graph_def.collection_def: - # TODO(laigd): currently we use the collection key to filter out - # collections that depend on variable ops, but this may miss some - # other user-defined collections. A better way would be to use - # CollectionDef::NodeList for the filtering. - if key not in [ - "variables", "local_variables", "model_variables", - "trainable_variables", "train_op", "table_initializer" - ]: - grappler_meta_graph_def.collection_def[key].CopyFrom( - input_meta_graph_def.collection_def[key]) - - # Copy other information. - grappler_meta_graph_def.meta_info_def.CopyFrom( - input_meta_graph_def.meta_info_def) - for key in input_meta_graph_def.signature_def: - grappler_meta_graph_def.signature_def[key].CopyFrom( - input_meta_graph_def.signature_def[key]) - # TODO(laigd): maybe add back AssetFileDef. - else: - if output_saved_model_dir is not None: - raise ValueError("output_saved_model_dir cannot be set when " - "input_graph_def is set") - # Create MetaGraphDef from input graph. 
- graph = ops.Graph() - with graph.as_default(): - importer.import_graph_def(input_graph_def, name="") - grappler_meta_graph_def = saver.export_meta_graph( - graph_def=graph.as_graph_def(add_shapes=True), graph=graph) - if outputs: - output_collection = meta_graph_pb2.CollectionDef() - output_list = output_collection.node_list.value - for i in outputs: - if isinstance(i, ops.Tensor): - output_list.append(_to_bytes(i.name)) - else: - output_list.append(_to_bytes(i)) - # TODO(laigd): use another key as the outputs are really not train_op. - grappler_meta_graph_def.collection_def["train_op"].CopyFrom( - output_collection) - - # Create TRT-enabled ConfigProto. - session_config_with_trt = config_pb2.ConfigProto() - session_config_with_trt.CopyFrom(session_config) - rewriter_config = None - if (session_config_with_trt.HasField("graph_options") and - session_config_with_trt.graph_options.HasField("rewrite_options")): - rewriter_config = session_config_with_trt.graph_options.rewrite_options - rewriter_config_with_trt = get_tensorrt_rewriter_config( - rewriter_config, max_batch_size, max_workspace_size_bytes, precision_mode, - minimum_segment_size, is_dynamic_op, maximum_cached_engines, - cached_engine_batch_sizes, use_calibration) - session_config_with_trt.graph_options.rewrite_options.CopyFrom( - rewriter_config_with_trt) - - # Run Grappler. - transformed_graph_def = tf_optimizer.OptimizeGraph( - session_config_with_trt, grappler_meta_graph_def, graph_id=b"tf_graph") - - # Optionally write the transformed graphdef as SavedModel. - if output_saved_model_dir is not None: - saved_model_builder = builder.SavedModelBuilder(output_saved_model_dir) - with ops.Graph().as_default(): - importer.import_graph_def(transformed_graph_def, name="") - # We don't use TRT here. 
- with session.Session(config=session_config) as sess: - saved_model_builder.add_meta_graph_and_variables( - sess, - input_saved_model_tags, - signature_def_map=grappler_meta_graph_def.signature_def) - # Ignore other meta graphs from the input SavedModel. - saved_model_builder.save() - - return transformed_graph_def + return trt_convert.create_inference_graph( + input_graph_def=input_graph_def, + outputs=outputs, + max_batch_size=max_batch_size, + max_workspace_size_bytes=max_workspace_size_bytes, + precision_mode=precision_mode, + minimum_segment_size=minimum_segment_size, + is_dynamic_op=is_dynamic_op, + maximum_cached_engines=maximum_cached_engines, + cached_engine_batches=cached_engine_batches, + use_calibration=use_calibration, + input_saved_model_dir=input_saved_model_dir, + input_saved_model_tags=input_saved_model_tags, + output_saved_model_dir=output_saved_model_dir, + session_config=session_config) def calib_graph_to_infer_graph(calibration_graph_def, is_dynamic_op=False): - """Convert an existing calibration graph to inference graph. - - Args: - calibration_graph_def: the calibration GraphDef object with calibration data - is_dynamic_op: whether to create dynamic static engines from calibration - - Returns: - New GraphDef with TRTEngineOps placed in graph replacing calibration nodes. - Raises: - RuntimeError: if the returned status message is malformed. - """ - - is_calib_graph = False - for n in calibration_graph_def.node: - if n.op == "TRTEngineOp": - is_calib_graph = is_calib_graph or not n.attr["calibration_data"].s - if not is_calib_graph: - tf_logging.error( - "Not a calib graph. 
Doesn't seem to contain any calibration nodes.") - return None - graph_str = calibration_graph_def.SerializeToString() - out = calib_convert(graph_str, is_dynamic_op) - status = _to_string(out[0]) - output_graph_def_string = out[1] - del graph_str # Save some memory - if len(status) < 2: - raise _impl.UnknownError(None, None, status) - if status[:2] != "OK": - msg = status.split(";") - if len(msg) == 1: - raise RuntimeError("Status message is malformed {}".format(status)) - # pylint: disable=protected-access - raise _impl._make_specific_exception(None, None, ";".join(msg[1:]), - int(msg[0])) - # pylint: enable=protected-access - output_graph_def = graph_pb2.GraphDef() - output_graph_def.ParseFromString(output_graph_def_string) - del output_graph_def_string # Save some memory - return output_graph_def + return trt_convert.calib_graph_to_infer_graph( + calibration_graph_def=calibration_graph_def, is_dynamic_op=is_dynamic_op) diff --git a/tensorflow/contrib/tensorrt/resources/trt_resources.h b/tensorflow/contrib/tensorrt/resources/trt_resources.h deleted file mode 100644 index aac9e5c7bd725fc10bcaa04536ebc7be071b4d4c..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/tensorrt/resources/trt_resources.h +++ /dev/null @@ -1,79 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_RESOURCES_H_ -#define TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_RESOURCES_H_ - -#include -#include -#include -#include -#include - -#include "tensorflow/contrib/tensorrt/convert/utils.h" -#include "tensorflow/contrib/tensorrt/log/trt_logger.h" -#include "tensorflow/contrib/tensorrt/resources/trt_allocator.h" -#include "tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h" -#include "tensorflow/core/framework/resource_mgr.h" - -#if GOOGLE_CUDA -#if GOOGLE_TENSORRT - -#include "tensorrt/include/NvInfer.h" - -namespace tensorflow { -namespace tensorrt { - -class TRTCalibrationResource : public tensorflow::ResourceBase { - public: - ~TRTCalibrationResource() { - LOG(INFO) << "Destroying Calibration Resource " << std::endl - << DebugString(); - builder_.reset(); - engine_.reset(); - // We need to manually destroy the builder and engine before the allocator - // is destroyed. - allocator_.reset(); - } - - string DebugString() override { - std::stringstream oss; - using std::dec; - using std::endl; - using std::hex; - oss << " Calibrator = " << hex << calibrator_.get() << dec << endl - << " Builder = " << hex << builder_.get() << dec << endl - << " Engine = " << hex << engine_.get() << dec << endl - << " Logger = " << hex << &logger_ << dec << endl - << " Allocator = " << hex << allocator_.get() << dec << endl - << " Thread = " << hex << thr_.get() << dec << endl; - return oss.str(); - } - - std::unique_ptr calibrator_; - TrtUniquePtrType builder_; - TrtUniquePtrType engine_; - std::unique_ptr allocator_; - tensorflow::tensorrt::Logger logger_; - // TODO(sami): Use threadpool threads! 
- std::unique_ptr thr_; -}; - -} // namespace tensorrt -} // namespace tensorflow - -#endif -#endif -#endif // TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_RESOURCES_H_ diff --git a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc index f30dba59ad55317d7ad7730e4dc66c9aba4e6a6b..5c60d6b589ed6a16276226726d989e949bcbf9d7 100644 --- a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc +++ b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc @@ -14,14 +14,14 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/tensorrt/shape_fn/trt_shfn.h" -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" #include #include #if GOOGLE_CUDA #if GOOGLE_TENSORRT -#include "tensorflow/contrib/tensorrt/log/trt_logger.h" +#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorrt/include/NvInfer.h" diff --git a/tensorflow/contrib/tensorrt/test/manual_test.py b/tensorflow/contrib/tensorrt/test/manual_test.py deleted file mode 100644 index 1187c759b4b5483cbf5afe136401abe86d6ef989..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/tensorrt/test/manual_test.py +++ /dev/null @@ -1,114 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Basic tests for TF-TensorRT integration.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import ast -import os - -from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test -from tensorflow.core.framework import graph_pb2 -from tensorflow.python.platform import gfile -from tensorflow.python.platform import test - - -class ManualTest(trt_test.TfTrtIntegrationTestBase): - - def __init__(self, methodName='runTest'): # pylint: disable=invalid-name - super(ManualTest, self).__init__(methodName) - self._params_map = None - - def _GetEnv(self): - """Get an environment variable specifying the manual test parameters. - - The value of the environment variable is the string representation of a dict - which should contain the following keys: - - 'graph_path': the file path to the serialized frozen graphdef - - 'input_names': TfTrtIntegrationTestParams.input_names - - 'input_dims': TfTrtIntegrationTestParams.input_dims - - 'expected_output_dims': TfTrtIntegrationTestParams.expected_output_dims - - 'output_name': the name of op to fetch - - 'expected_engines_to_run': ExpectedEnginesToRun() will return this - - 'expected_engines_to_build': ExpectedEnginesToBuild() will return this - - 'max_batch_size': ConversionParams.max_batch_size - - Returns: - The value of the environment variable. 
- """ - return os.getenv('TRT_MANUAL_TEST_PARAMS', '') - - def _GetParamsMap(self): - """Parse the environment variable as a dict and return it.""" - if self._params_map is None: - self._params_map = ast.literal_eval(self._GetEnv()) - return self._params_map - - def GetParams(self): - """Testing conversion of manually provided frozen graph.""" - params_map = self._GetParamsMap() - gdef = graph_pb2.GraphDef() - with gfile.Open(params_map['graph_path'], 'rb') as f: - gdef.ParseFromString(f.read()) - return trt_test.TfTrtIntegrationTestParams( - gdef=gdef, - input_names=params_map['input_names'], - input_dims=params_map['input_dims'], - output_names=params_map['output_names'], - expected_output_dims=params_map['expected_output_dims']) - - def GetConversionParams(self, run_params): - """Return a ConversionParams for test.""" - conversion_params = super(ManualTest, self).GetConversionParams(run_params) - params_map = self._GetParamsMap() - if 'max_batch_size' in params_map: - conversion_params = conversion_params._replace( - max_batch_size=params_map['max_batch_size']) - return conversion_params - - def ExpectedEnginesToBuild(self, run_params): - """Return the expected engines to build.""" - return self._GetParamsMap()['expected_engines_to_build'] - - def ExpectedEnginesToRun(self, run_params): - """Return the expected engines to run.""" - params_map = self._GetParamsMap() - if 'expected_engines_to_run' in params_map: - return params_map['expected_engines_to_run'] - return self.ExpectedEnginesToBuild(run_params) - - def ExpectedAbsoluteTolerance(self, run_params): - """The absolute tolerance to compare floating point results.""" - params_map = self._GetParamsMap() - if 'atol' in params_map: - return params_map['atol'] - return 1.e-3 - - def ExpectedRelativeTolerance(self, run_params): - """The relative tolerance to compare floating point results.""" - params_map = self._GetParamsMap() - if 'rtol' in params_map: - return params_map['rtol'] - return 1.e-3 - - def 
ShouldRunTest(self, run_params): - """Whether to run the test.""" - return len(self._GetEnv()) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/tensorrt/test/test_tftrt.py b/tensorflow/contrib/tensorrt/test/test_tftrt.py deleted file mode 100644 index d26f26008635733c6c364a98b72b88c1e552f5fe..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/tensorrt/test/test_tftrt.py +++ /dev/null @@ -1,287 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Script to test TF-TensorRT integration.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import argparse -import numpy as np -import six as _six - -# normally we should do import tensorflow as tf and then -# tf.placeholder, tf.constant, tf.nn.conv2d etc but -# it looks like internal builds don't like it so -# importing every module individually - -from tensorflow.contrib import tensorrt as trt -from tensorflow.core.protobuf import config_pb2 as cpb2 -from tensorflow.core.protobuf import rewriter_config_pb2 as rwpb2 -from tensorflow.python.client import session as csess -from tensorflow.python.framework import constant_op as cop -from tensorflow.python.framework import dtypes as dtypes -from tensorflow.python.framework import importer as importer -from tensorflow.python.framework import ops as ops -from tensorflow.python.ops import array_ops as aops -from tensorflow.python.ops import math_ops as mops -from tensorflow.python.ops import nn as nn -from tensorflow.python.ops import nn_ops as nn_ops - - -def py2bytes(inp): - return inp - - -def py3bytes(inp): - return inp.encode("utf-8", errors="surrogateescape") - - -def py2string(inp): - return inp - - -def py3string(inp): - return inp.decode("utf-8") - - -if _six.PY2: - to_bytes = py2bytes - to_string = py2string -else: - to_bytes = py3bytes - to_string = py3string - - -def get_multi_engine_graph_def(mode="FP32"): - """Create a simple graph and return its graph_def.""" - dtype = dtypes.float32 - if mode.upper() == "FP16": - dtype = dtypes.float16 - else: - pass - - g = ops.Graph() - with g.as_default(): - x = aops.placeholder(shape=[None, 3, 7, 5], name="input", dtype=dtype) - with g.name_scope("Global_scope"): - with g.name_scope("first_scope"): - e = cop.constant( - np.random.randn(3, 2, 3, 4), name="weights", dtype=dtype) - conv = nn.conv2d( - input=x, - 
filter=e, - data_format="NCHW", - strides=[1, 1, 1, 1], - padding="VALID", - name="conv") - b = cop.constant(np.random.randn(1, 4, 1, 1), name="bias1", dtype=dtype) - t = conv * b - - b = cop.constant(np.random.randn(1, 4, 1, 1), name="bias2", dtype=dtype) - q = conv / b - edge = mops.sin(q) - edge1 = mops.cos(conv) - with g.name_scope("test_scope"): - de = edge + edge1 - t -= edge1 - q *= edge - t += q - t -= de - k = aops.squeeze(t, name="output") - print(k.dtype) - return g.as_graph_def() - - -def get_simple_graph_def(): - """Create a simple graph and return its graph_def.""" - g = ops.Graph() - with g.as_default(): - a = aops.placeholder( - dtype=dtypes.float32, shape=(None, 24, 24, 2), name="input") - e = cop.constant( - [[[[1., 0.5, 4., 6., 0.5, 1.], [1., 0.5, 1., 1., 0.5, 1.]]]], - name="weights", - dtype=dtypes.float32) - conv = nn.conv2d( - input=a, filter=e, strides=[1, 2, 2, 1], padding="SAME", name="conv") - b = cop.constant( - [4., 1.5, 2., 3., 5., 7.], name="bias", dtype=dtypes.float32) - t = nn.bias_add(conv, b, name="biasAdd") - relu = nn.relu(t, "relu") - idty = aops.identity(relu, "ID") - v = nn_ops.max_pool( - idty, [1, 2, 2, 1], [1, 2, 2, 1], "VALID", name="max_pool") - aops.squeeze(v, name="output") - return g.as_graph_def() - - -def execute_graph(gdef, dumm_inp): - """Run given graphdef once.""" - print("executing") - gpu_options = None - if trt.trt_convert.get_linked_tensorrt_version()[0] == 3: - gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50) - sessconfig = cpb2.ConfigProto(gpu_options=gpu_options) - ops.reset_default_graph() - g = ops.Graph() - with g.as_default(): - inp, out = importer.import_graph_def( - graph_def=gdef, return_elements=["input", "output"]) - inp = inp.outputs[0] - out = out.outputs[0] - with csess.Session(config=sessconfig, graph=g) as sess: - val = sess.run(out, {inp: dumm_inp}) - return val - - -# Use real data that is representative of the inference dataset -# for calibration. 
For this test script it is random data. -def execute_calibration(gdef, dumm_inp): - """Run given calibration graph multiple times.""" - gpu_options = None - if trt.trt_convert.get_linked_tensorrt_version()[0] == 3: - gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50) - ops.reset_default_graph() - g = ops.Graph() - with g.as_default(): - inp, out = importer.import_graph_def( - graph_def=gdef, return_elements=["input", "output"]) - inp = inp.outputs[0] - out = out.outputs[0] - with csess.Session( - config=cpb2.ConfigProto(gpu_options=gpu_options), graph=g) as sess: - # run over real calibration data here, we are mimicking a calibration set of - # 30 different batches. Use as much calibration data as you want - for _ in range(30): - val = sess.run(out, {inp: dumm_inp}) - return val - - -def user(multi_engine, - run_graph=execute_graph, - run_calibration=execute_calibration): - """Example function that converts a graph to TFTRT graph.""" - if multi_engine: - inp_dims = (2, 3, 7, 5) - orig_graph = get_multi_engine_graph_def() - else: - inp_dims = (100, 24, 24, 2) - orig_graph = get_simple_graph_def() # use a frozen graph for inference - dummy_input = np.random.random_sample(inp_dims) - # Get optimized graph - trt_graph = trt.create_inference_graph( - input_graph_def=orig_graph, - outputs=["output"], - max_batch_size=inp_dims[0], - max_workspace_size_bytes=1 << 25, - precision_mode="FP32", # TRT Engine precision "FP32","FP16" or "INT8" - minimum_segment_size=2, # minimum number of nodes in an engine - is_dynamic_op=False, - maximum_cached_engines=1, - cached_engine_batch_sizes=[]) - o1 = run_graph(orig_graph, dummy_input) - o2 = run_graph(trt_graph, dummy_input) - o3 = run_graph(trt_graph, dummy_input) - assert np.array_equal(o1, o2) - assert np.array_equal(o3, o2) # sanity check - fp16_graph = trt.create_inference_graph( - input_graph_def=orig_graph, - outputs=["output"], - max_batch_size=inp_dims[0], - max_workspace_size_bytes=1 << 25, - 
precision_mode="FP16", # TRT Engine precision "FP32","FP16" or "INT8" - minimum_segment_size=2, # minimum number of nodes in an engine - is_dynamic_op=False, - maximum_cached_engines=1, - cached_engine_batch_sizes=[]) - int8_calib_gdef = trt.create_inference_graph( - input_graph_def=orig_graph, - outputs=["output"], - max_batch_size=inp_dims[0], - max_workspace_size_bytes=1 << 25, - precision_mode="INT8", # TRT Engine precision "FP32","FP16" or "INT8" - minimum_segment_size=2, # minimum number of nodes in an engine - is_dynamic_op=False, - maximum_cached_engines=1, - cached_engine_batch_sizes=[]) - o4 = run_graph(fp16_graph, dummy_input) - _ = run_calibration(int8_calib_gdef, dummy_input) - int8_graph = trt.calib_graph_to_infer_graph(int8_calib_gdef) - o5 = run_graph(int8_graph, dummy_input) - print("Is FP32 == FP16? %s (False is possible)" % np.allclose(o1, o4)) - print("Is FP32 == INT8? %s (False is possible)" % np.allclose(o1, o5)) - print("Pass") - - -def auto(multi_engine): - """Run the conversion as an optimization pass.""" - if multi_engine: - inp_dims = (2, 3, 7, 5) - orig_graph = get_multi_engine_graph_def() - else: - inp_dims = (100, 24, 24, 2) - orig_graph = get_simple_graph_def() # use a frozen graph for inference - dummy_input = np.random.random_sample(inp_dims) - opt_config = rwpb2.RewriterConfig() - opt_config.meta_optimizer_iterations = opt_config.ONE - opt_config.optimizers.extend(["constfold", "layout"]) - custom_op = opt_config.custom_optimizers.add() - custom_op.name = "TensorRTOptimizer" - custom_op.parameter_map["minimum_segment_size"].i = 3 - custom_op.parameter_map["precision_mode"].s = to_bytes("FP32") - custom_op.parameter_map["max_batch_size"].i = inp_dims[0] - custom_op.parameter_map["max_workspace_size_bytes"].i = 1 << 25 - print(custom_op) - gpu_options = None - if trt.trt_convert.get_linked_tensorrt_version()[0] == 3: - gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50) - graph_options = 
cpb2.GraphOptions(rewrite_options=opt_config) - sessconfig = cpb2.ConfigProto( - gpu_options=gpu_options, graph_options=graph_options) - print(sessconfig) - g = ops.Graph() - ops.reset_default_graph() - with g.as_default(): - inp, out = importer.import_graph_def( - graph_def=orig_graph, return_elements=["input", "output"], name="") - inp = inp.outputs[0] - out = out.outputs[0] - with csess.Session(config=sessconfig, graph=g) as sess: - val = sess.run(out, {inp: dummy_input}) - print(val.shape) - - -if "__main__" in __name__: - P = argparse.ArgumentParser( - prog="tftrt_test", - description="Example utilization of TensorFlow-TensorRT integration") - P.add_argument( - "--automatic", - "-a", - action="store_true", - help="Do TRT conversion automatically", - default=False) - P.add_argument( - "--multi-engine", - "-m", - action="store_true", - help="Use a graph that will result in 2 engines", - default=False) - flags, unparsed = P.parse_known_args() - if flags.automatic: - auto(flags.multi_engine) - else: - user(flags.multi_engine) diff --git a/tensorflow/contrib/tfprof/README.md b/tensorflow/contrib/tfprof/README.md index b29d1acacf17b57549558be45c853566817c1729..f40e76f554e8815aac96344d8cb0b911bafdd712 100644 --- a/tensorflow/contrib/tfprof/README.md +++ b/tensorflow/contrib/tfprof/README.md @@ -1,7 +1,5 @@ # tfprof: TensorFlow Profiler and Beyond -

Please use `tf.profiler.xxx` instead of `tf.contrib.tfprof.xxx`

-

Full Document in tensorflow/core/profiler/README.md

diff --git a/tensorflow/contrib/timeseries/examples/BUILD b/tensorflow/contrib/timeseries/examples/BUILD index 57797214d1684550aa7ad2664b71d22b504f70ed..e10be88ece8ebba9635af955b3c3410f29e5503c 100644 --- a/tensorflow/contrib/timeseries/examples/BUILD +++ b/tensorflow/contrib/timeseries/examples/BUILD @@ -105,6 +105,7 @@ py_binary( data = ["data/multivariate_periods.csv"], srcs_version = "PY2AND3", tags = ["no_pip"], + visibility = ["//visibility:public"], deps = select({ ":empty_condition": [], "//conditions:default": [], @@ -113,6 +114,7 @@ py_binary( "//tensorflow:tensorflow_py", "//tensorflow/contrib/timeseries/python/timeseries:estimators", "//tensorflow/contrib/timeseries/python/timeseries:model", + "//tensorflow/contrib/timeseries/python/timeseries:state_management", ], ) diff --git a/tensorflow/contrib/timeseries/examples/predict_test.py b/tensorflow/contrib/timeseries/examples/predict_test.py index 678fd71cd8b94ee0be46e10a9a673de55bd44215..b353f85cb5df0cf961d1900b241e4fa1a84a24b4 100644 --- a/tensorflow/contrib/timeseries/examples/predict_test.py +++ b/tensorflow/contrib/timeseries/examples/predict_test.py @@ -43,10 +43,6 @@ class PeriodTrendExampleTest(test.TestCase): self.assertAllEqual([700], mean.shape) self.assertAllEqual([700], upper_limit.shape) self.assertAllEqual([700], lower_limit.shape) - # Check that variance hasn't blown up too much. This is a relatively good - # indication that training was successful. - self.assertLess(upper_limit[-1] - lower_limit[-1], - 1.5 * (upper_limit[0] - lower_limit[0])) def test_ar(self): (times, observed, all_times, mean, @@ -55,7 +51,6 @@ class PeriodTrendExampleTest(test.TestCase): self.assertAllEqual(all_times.shape, mean.shape) self.assertAllEqual(all_times.shape, upper_limit.shape) self.assertAllEqual(all_times.shape, lower_limit.shape) - self.assertLess((upper_limit - lower_limit).mean(), 4.) 
if __name__ == "__main__": diff --git a/tensorflow/contrib/timeseries/python/timeseries/BUILD b/tensorflow/contrib/timeseries/python/timeseries/BUILD index ae7db35b47b326272dd2c7bc76e18047cec59865..d1be31ddc799ce4c4ef9baa15729fde7925f2f6c 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/BUILD +++ b/tensorflow/contrib/timeseries/python/timeseries/BUILD @@ -104,6 +104,7 @@ py_test( srcs = [ "estimators_test.py", ], + shard_count = 3, srcs_version = "PY2AND3", tags = [ "no_mac", @@ -154,11 +155,11 @@ py_library( py_test( name = "head_test", - size = "large", + size = "medium", srcs = [ "head_test.py", ], - shard_count = 4, + shard_count = 10, srcs_version = "PY2AND3", tags = ["no_pip_gpu"], # b/63391119 deps = [ @@ -280,6 +281,7 @@ py_library( "input_pipeline.py", ], srcs_version = "PY2AND3", + visibility = ["//visibility:public"], deps = [ ":feature_keys", ":model_utils", @@ -360,9 +362,10 @@ py_library( srcs_version = "PY2AND3", deps = [ ":feature_keys", + ":math_utils", ":model", ":model_utils", - "//tensorflow/contrib/distributions:distributions_py", + "//tensorflow/contrib/rnn:rnn_py", "//tensorflow/python:array_ops", "//tensorflow/python:check_ops", "//tensorflow/python:constant_op", diff --git a/tensorflow/contrib/timeseries/python/timeseries/ar_model.py b/tensorflow/contrib/timeseries/python/timeseries/ar_model.py index bcadf4094e1e79fff1685515f2bde0b88f717cac..3626701d24163ef52564b42d8a630bd9c5a788eb 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/ar_model.py +++ b/tensorflow/contrib/timeseries/python/timeseries/ar_model.py @@ -18,9 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib import distributions - from tensorflow.contrib.rnn.python.ops import lstm_ops +from tensorflow.contrib.timeseries.python.timeseries import math_utils from tensorflow.contrib.timeseries.python.timeseries import model from 
tensorflow.contrib.timeseries.python.timeseries import model_utils from tensorflow.contrib.timeseries.python.timeseries.feature_keys import PredictionFeatures @@ -462,11 +461,12 @@ class ARModel(model.TimeSeriesModel): if self.loss == ARModel.NORMAL_LIKELIHOOD_LOSS: covariance = prediction_ops["covariance"] sigma = math_ops.sqrt(gen_math_ops.maximum(covariance, 1e-5)) - normal = distributions.Normal(loc=targets, scale=sigma) - loss_op = -math_ops.reduce_sum(normal.log_prob(prediction)) + loss_op = -math_ops.reduce_sum( + math_utils.normal_log_prob(targets, sigma, prediction)) else: assert self.loss == ARModel.SQUARED_LOSS, self.loss - loss_op = math_ops.reduce_sum(math_ops.square(prediction - targets)) + loss_op = math_ops.reduce_sum( + math_ops.squared_difference(prediction, targets)) loss_op /= math_ops.cast( math_ops.reduce_prod(array_ops.shape(targets)), loss_op.dtype) return loss_op @@ -965,16 +965,11 @@ class AnomalyMixtureARModel(ARModel): anomaly_variance = prediction_ops["anomaly_params"] anomaly_sigma = math_ops.sqrt( gen_math_ops.maximum(anomaly_variance, 1e-5)) - normal = distributions.Normal(loc=targets, scale=anomaly_sigma) - log_prob = normal.log_prob(prediction) + log_prob = math_utils.normal_log_prob(targets, anomaly_sigma, prediction) else: assert self._anomaly_distribution == AnomalyMixtureARModel.CAUCHY_ANOMALY anomaly_scale = prediction_ops["anomaly_params"] - cauchy = distributions.StudentT( - df=array_ops.ones([], dtype=anomaly_scale.dtype), - loc=targets, - scale=anomaly_scale) - log_prob = cauchy.log_prob(prediction) + log_prob = math_utils.cauchy_log_prob(targets, anomaly_scale, prediction) return log_prob def loss_op(self, targets, prediction_ops): @@ -983,8 +978,7 @@ class AnomalyMixtureARModel(ARModel): covariance = prediction_ops["covariance"] # Normal data log probability. 
sigma = math_ops.sqrt(gen_math_ops.maximum(covariance, 1e-5)) - normal1 = distributions.Normal(loc=targets, scale=sigma) - log_prob1 = normal1.log_prob(prediction) + log_prob1 = math_utils.normal_log_prob(targets, sigma, prediction) log_prob1 += math_ops.log(1 - self._anomaly_prior_probability) # Anomaly log probability. log_prob2 = self._anomaly_log_prob(targets, prediction_ops) diff --git a/tensorflow/contrib/timeseries/python/timeseries/math_utils.py b/tensorflow/contrib/timeseries/python/timeseries/math_utils.py index aab330643862c1ccf073d2a0e34e1c475b1ec15f..b7375e5055e29efea3f23c3b9b9f3af59f45495b 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/math_utils.py +++ b/tensorflow/contrib/timeseries/python/timeseries/math_utils.py @@ -21,6 +21,8 @@ from __future__ import print_function import collections import math +import numpy as np + from tensorflow.contrib import lookup from tensorflow.contrib.layers.python.layers import layers @@ -43,6 +45,32 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.util import nest +def normal_log_prob(loc, scale, x): + """Computes the Normal log pdf.""" + z = (x - loc) / scale + return -0.5 * (math_ops.square(z) + + np.log(2. * np.pi) + math_ops.log(scale)) + + +def cauchy_log_prob(loc, scale, x): + """Computes the Cauchy log pdf.""" + z = (x - loc) / scale + return (-np.log(np.pi) - math_ops.log(scale) - + math_ops.log1p(math_ops.square(z))) + + +def mvn_tril_log_prob(loc, scale_tril, x): + """Computes the MVN log pdf under tril scale. Doesn't handle batches.""" + x0 = x - loc + z = linalg_ops.matrix_triangular_solve( + scale_tril, x0[..., array_ops.newaxis])[..., 0] + log_det_cov = 2. * math_ops.reduce_sum(math_ops.log( + array_ops.matrix_diag_part(scale_tril)), axis=-1) + d = math_ops.cast(array_ops.shape(scale_tril)[-1], log_det_cov.dtype) + return -0.5 * (math_ops.reduce_sum(math_ops.square(z), axis=-1) + + d * np.log(2. 
* np.pi) + log_det_cov) + + def clip_covariance( covariance_matrix, maximum_variance_ratio, minimum_variance): """Enforce constraints on a covariance matrix to improve numerical stability. diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/BUILD b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/BUILD index 125750e7639ad40c481472a93353e6fb7055be96..cf5e749042afd83f927a3d22edfd3a9538ab2ffd 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/BUILD +++ b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/BUILD @@ -78,7 +78,6 @@ py_library( srcs = ["kalman_filter.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/contrib/distributions:distributions_py", "//tensorflow/contrib/timeseries/python/timeseries:math_utils", "//tensorflow/python:array_ops", "//tensorflow/python:control_flow_ops", @@ -235,7 +234,6 @@ py_library( srcs = ["filtering_postprocessor.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/contrib/distributions:distributions_py", "//tensorflow/contrib/timeseries/python/timeseries:math_utils", "//tensorflow/python:array_ops", "//tensorflow/python:check_ops", diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/filtering_postprocessor.py b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/filtering_postprocessor.py index e9e2ac0aaf4c4d6c41f5007662f261af3de9bbd1..3fa2fbd9f77cb887c30fde264815728ca345f45a 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/filtering_postprocessor.py +++ b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/filtering_postprocessor.py @@ -22,8 +22,6 @@ import abc import six -from tensorflow.contrib import distributions - from tensorflow.contrib.timeseries.python.timeseries import math_utils from tensorflow.python.framework import dtypes @@ -91,10 +89,10 @@ def cauchy_alternative_to_gaussian(current_times, current_values, outputs): """ del 
current_times # unused cauchy_scale = math_utils.entropy_matched_cauchy_scale(outputs["covariance"]) - individual_log_pdfs = distributions.StudentT( - df=array_ops.ones([], dtype=current_values.dtype), + individual_log_pdfs = math_utils.cauchy_log_prob( loc=outputs["mean"], - scale=cauchy_scale).log_prob(current_values) + scale=cauchy_scale, + x=current_values) return math_ops.reduce_sum(individual_log_pdfs, axis=1) diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/kalman_filter.py b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/kalman_filter.py index a614386121e000961bf8b32625a28e1251654320..c0ec797bc5b7c41ca996c807840ce38311201f87 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/kalman_filter.py +++ b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/kalman_filter.py @@ -18,8 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib import distributions - from tensorflow.contrib.timeseries.python.timeseries import math_utils from tensorflow.python.framework import dtypes @@ -137,9 +135,10 @@ class KalmanFilter(object): with ops.control_dependencies([non_negative_assert]): observation_covariance_cholesky = linalg_ops.cholesky( symmetrized_observation_covariance) - log_prediction_prob = distributions.MultivariateNormalTriL( - predicted_observation, observation_covariance_cholesky).log_prob( - observation) + log_prediction_prob = math_utils.mvn_tril_log_prob( + loc=predicted_observation, + scale_tril=observation_covariance_cholesky, + x=observation) (posterior_state, posterior_state_var) = self.posterior_from_prior_state( prior_state=estimated_state, diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD index 05d2ebd2e8a3292a95df0e2f976df0e2871063f8..1859dee9d08ac4a8f3f496222d537b622c65621e 100644 --- a/tensorflow/contrib/tpu/BUILD +++ b/tensorflow/contrib/tpu/BUILD 
@@ -1,15 +1,15 @@ # Description: Operations defined for Cloud TPUs -licenses(["notice"]) # Apache 2.0 - load( "//tensorflow:tensorflow.bzl", "tf_custom_op_library", "tf_gen_op_libs", "tf_gen_op_wrapper_py", + "tf_py_test", ) load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") -load("//tensorflow:tensorflow.bzl", "tf_py_test") + +licenses(["notice"]) # Apache 2.0 package( default_visibility = [ @@ -61,6 +61,7 @@ py_library( py_library( name = "tpu_estimator", srcs = [ + "python/tpu/_tpu_estimator_embedding.py", "python/tpu/error_handling.py", "python/tpu/tpu_config.py", "python/tpu/tpu_context.py", @@ -70,15 +71,21 @@ py_library( srcs_version = "PY2AND3", deps = [ ":async_checkpoint", + ":feature_column", + ":functional", + ":tpu_embedding", ":tpu_lib", + ":tpu_ordinal_selector_py", "//tensorflow/contrib/training:training_py", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:function", "//tensorflow/python:init_ops", "//tensorflow/python:math_ops", "//tensorflow/python:platform", + "//tensorflow/python:session", "//tensorflow/python:state_ops", "//tensorflow/python:summary", "//tensorflow/python:summary_ops_v2", @@ -101,6 +108,8 @@ tf_gen_op_libs( "replication_ops", "tpu_configuration_ops", "tpu_embedding_ops", + "tpu_ordinal_selector_op", + "functional_ops", ], deps = [ "//tensorflow/contrib/tpu/proto:tpu_embedding_configuration_proto_cc", @@ -152,6 +161,62 @@ tf_gen_op_wrapper_py( ], ) +tf_custom_op_library( + name = "python/ops/_tpu_ordinal_selector_op.so", + srcs = ["ops/tpu_ordinal_selector_op.cc"], +) + +tf_custom_op_py_library( + name = "tpu_ordinal_selector_py", + srcs = ["python/ops/tpu_ordinal_selector_op.py"], + dso = [":python/ops/_tpu_ordinal_selector_op.so"], + kernels = [ + ":tpu_ordinal_selector_op_op_lib", + ], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + 
":tpu_ordinal_selector_op", + ], +) + +tf_gen_op_wrapper_py( + name = "tpu_ordinal_selector_op", + deps = [ + ":tpu_ordinal_selector_op_op_lib", + ], +) + +tf_custom_op_library( + name = "python/ops/_functional_ops.so", + srcs = ["ops/functional_ops.cc"], +) + +tf_gen_op_wrapper_py( + name = "gen_functional_ops", + out = "python/tpu/gen_functional_ops.py", + hidden = [ + "TPUPartitionedCall", + ], + deps = [":functional_ops_op_lib"], +) + +tf_custom_op_py_library( + name = "functional", + srcs = ["python/tpu/functional.py"], + dso = [":python/ops/_functional_ops.so"], + kernels = [ + ":functional_ops_op_lib", + ], + srcs_version = "PY2AND3", + visibility = [ + "//visibility:public", + ], + deps = [ + ":gen_functional_ops", + ], +) + py_library( name = "profiler", srcs = ["python/profiler/__init__.py"], @@ -166,7 +231,7 @@ py_library( tf_custom_op_py_library( name = "tpu_py", - srcs = glob(["python/ops/*.py"]), + srcs = ["python/ops/tpu_ops.py"], dso = [":python/ops/_tpu_ops.so"], kernels = [ ":all_ops", @@ -192,6 +257,7 @@ py_library( ], srcs_version = "PY2AND3", deps = [ + ":feature_column", ":keras_support", # split out to avoid cycle with tpu_strategy ":tpu_embedding", ":tpu_estimator", @@ -211,7 +277,6 @@ py_library( "//learning/brain:__subpackages__", "//tensorflow:__subpackages__", "//third_party/cloud_tpu/models/keras_colab:__subpackages__", - "//third_party/cloud_tpu/models/mnist_keras:__subpackages__", "//third_party/cloud_tpu/models/resnet50_keras:__subpackages__", ], deps = [ @@ -267,6 +332,7 @@ py_library( "//tensorflow/contrib/cluster_resolver:cluster_resolver_py", "//tensorflow/contrib/compiler:xla", "//tensorflow/contrib/tpu/proto:compilation_result_proto_py", + "//tensorflow/contrib/tpu/proto:dynamic_padding_proto_py", "//tensorflow/contrib/tpu/proto:optimization_parameters_proto_py", "//tensorflow/contrib/tpu/proto:topology_proto_py", "//tensorflow/contrib/tpu/proto:tpu_embedding_configuration_proto_py", @@ -306,13 +372,15 @@ py_library( 
tf_py_test( name = "datasets_test", + size = "medium", srcs = ["python/tpu/datasets_test.py"], additional_deps = [ "//tensorflow/python:client_testlib", ":datasets", ], - flaky = 1, # TODO(b/117363808): fails 1/1000 OSS runs grpc_enabled = True, + shard_count = 4, + tags = ["no_oss"], ) tf_py_test( @@ -399,7 +467,8 @@ py_library( srcs = ["python/tpu/tpu_embedding.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/contrib/tpu:tpu_ops", + ":tpu_lib", + ":tpu_ops", "//tensorflow/contrib/tpu/proto:tpu_embedding_configuration_proto_py", "//tensorflow/python:array_ops", "//tensorflow/python:framework_for_generated_wrappers", @@ -411,3 +480,37 @@ py_library( "@six_archive//:six", ], ) + +py_library( + name = "feature_column", + srcs = ["python/tpu/feature_column.py"], + deps = [ + ":tpu_lib", + "//tensorflow/python:framework_ops", + "//tensorflow/python:init_ops", + "//tensorflow/python:variable_scope", + "//tensorflow/python/feature_column", + "//tensorflow/python/feature_column:feature_column_py", + ], +) + +tf_py_test( + name = "feature_column_test", + srcs = [ + "python/tpu/feature_column_test.py", + ], + additional_deps = [ + ":feature_column", + "//third_party/py/numpy", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:lookup_ops", + "//tensorflow/python:parsing_ops", + "//tensorflow/python:session", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python:variables", + "//tensorflow/python/feature_column", + "//tensorflow/python/feature_column:feature_column_py", + ], + main = "python/tpu/feature_column_test.py", +) diff --git a/tensorflow/contrib/tpu/ops/functional_ops.cc b/tensorflow/contrib/tpu/ops/functional_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..aa81e8b24b5e303f5de5d2938b9474fc6b7af6c9 --- /dev/null +++ b/tensorflow/contrib/tpu/ops/functional_ops.cc @@ -0,0 +1,31 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +REGISTER_OP("TPUPartitionedCall") + .Input("args: Tin") + .Input("device_ordinal: int32") + .Output("output: Tout") + .Attr("Tin: list(type) >= 0") + .Attr("Tout: list(type) >= 0") + .Attr("f: func") + .SetShapeFn(shape_inference::UnknownShape); + +} // namespace tensorflow diff --git a/tensorflow/contrib/tpu/ops/infeed_ops.cc b/tensorflow/contrib/tpu/ops/infeed_ops.cc index efc546f9a6077de9cac5a5acefa3fc7206547fc6..2ed16c2a2270a5399059d7e07f5903e11098bbf9 100644 --- a/tensorflow/contrib/tpu/ops/infeed_ops.cc +++ b/tensorflow/contrib/tpu/ops/infeed_ops.cc @@ -40,6 +40,7 @@ REGISTER_OP("InfeedEnqueue") .Input("input: dtype") .Attr("dtype: type") .Attr("shape: shape = {}") + .Attr("layout: list(int) = []") .Attr("device_ordinal: int = -1") .SetShapeFn(shape_inference::NoOutputs) .SetIsStateful() @@ -49,6 +50,9 @@ An op which feeds a single Tensor value into the computation. input: A tensor that will be provided using the infeed mechanism. dtype: The type of elements in the tensor. shape: The shape of the tensor. +layout: A vector holding the requested layout in minor-to-major sequence. 
+If a layout attribute is passed, but its values are all -1, the layout will +be computed by the infeed operation. device_ordinal: The TPU device to use. This should be -1 when the Op is running on a TPU device, and >= 0 when the Op is running on the CPU device. @@ -58,6 +62,7 @@ REGISTER_OP("InfeedEnqueueTuple") .Input("inputs: dtypes") .Attr("dtypes: list(type)") .Attr("shapes: list(shape)") + .Attr("layouts: list(int) = []") .Attr("device_ordinal: int = -1") .SetShapeFn(shape_inference::NoOutputs) .SetIsStateful() @@ -67,6 +72,10 @@ An op which feeds multiple Tensor values into the computation as an XLA tuple. inputs: A list of tensors that will be provided using the infeed mechanism. dtypes: The element types of each element in `inputs`. shapes: The shapes of each tensor in `inputs`. +layouts: A vector holding the requested layout in minor-to-major sequence for +all the tuple shapes, in the order the shapes appear in the "shapes" input. +The layout elements for a sub-shape can be set to -1, in which case the +corresponding layout will be computed by the infeed operation. device_ordinal: The TPU device to use. This should be -1 when the Op is running on a TPU device, and >= 0 when the Op is running on the CPU device. diff --git a/tensorflow/contrib/tpu/ops/replication_ops.cc b/tensorflow/contrib/tpu/ops/replication_ops.cc index 285e11d92de7a684ed87974414ec73c274cc7aa5..d4180d1a20bc59f3fbb37b2dbc67790ded9d2d90 100644 --- a/tensorflow/contrib/tpu/ops/replication_ops.cc +++ b/tensorflow/contrib/tpu/ops/replication_ops.cc @@ -31,6 +31,7 @@ REGISTER_OP("TPUReplicateMetadata") // Deprecated. Use num_cores_per_replica instead. 
.Attr("computation_shape: list(int) = []") .Attr("host_compute_core: list(string) = []") + .Attr("padding_map: list(string) = []") .SetShapeFn(shape_inference::UnknownShape); REGISTER_OP("TPUReplicatedInput") @@ -105,6 +106,7 @@ REGISTER_OP("TPUReplicate") .Attr("NumVariables: int >= 0") .Attr("Tguaranteed_constants: list(type) >= 0") .Attr("output_types: list(type) >= 0") + .Attr("padding_map: list(string) = []") .Input("inputs: Tinputs") .Input("broadcast_inputs: Tbroadcast_inputs") .Input("variables: NumVariables * resource") diff --git a/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc b/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc index 0ef29bdf734467aa9dee5c157bc8d8a7e0a85f13..676aed0b7b651494eda80ff2d7c7c31097529590 100644 --- a/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc +++ b/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc @@ -37,18 +37,18 @@ namespace tensorflow { // pieces of the TF Graph. // 1. Pass this TPUEmbeddingConfiguration to tpu.initialize_system() as the // tpu_embedding_config parameter. -// 2. Use the TPUEmbeddingLoad Op to initialize the embedding tables in TPU +// 2. Use the LoadTPUEmbedding Ops to initialize the embedding tables in TPU // memories, sharded across the memories attached to each Host. -// 3. Use TPUEmbeddingEnqueueSparseBatch to provide the TPU with embedding +// 3. Use EnqueueTPUEmbeddingSparseBatch to provide the TPU with embedding // indices and aggregation weights. -// 4. TPUEmbeddingReceiveActivations returns a list of Tensors, containing the +// 4. RecvTPUEmbeddingActivations returns a list of Tensors, containing the // activations from each table specified in the configuration. // 5. TPUEmbeddingActivations, when used with appropriate Python libraries, // enables the automatic differentiation of models that use embeddings. -// 6. TPUEmbeddingSendGradients takes a list of Tensors (of the same shapes +// 6. 
SendTPUEmbeddingGradients takes a list of Tensors (of the same shapes // as those returned by TPUEmbeddingReceiveActivations) containing gradients // to use in updating the embedding tables. -// 7. Before saving a checkpoint, use the TPUEmbeddingRetrieve Op to update +// 7. Before saving a checkpoint, use the RetrieveTPUEmbedding Ops to update // the Graph's embedding table Variables from the updated tables in the // TPU memories. // @@ -455,20 +455,21 @@ REGISTER_OP("SendTPUEmbeddingGradients") return Status::OK(); }) .Doc(R"doc( -An op that performs gradient updates of embedding tables. - -The TensorList argument has the same length and shapes as the return value of -TPUEmbeddingReceiveActivations, but contains gradients of the model's loss -with respect to the embedding activations. The embedding tables are updated -from these gradients via the optimizer specified in the configuration given -to tpu.initialize_system. +An op that performs gradient updates of embedding tables using the specified +learning rates. inputs: A TensorList of gradients with which to update embedding tables. - It contains one tensor per embedding table in the model. -learning_rates: A list of float32 scalars, one for each embedding table, - containing the learning rates for each table when dynamic learning rate is - enabled through the OptimizationParameters in TPUEmbeddingConfiguration. - When the learning rate is constant, the list should be empty. + This argument has the same length and shapes as the return value of + RecvTPUEmbeddingActivations, but contains gradients of the model's loss + with respect to the embedding activations. The embedding tables are updated + from these gradients via the optimizer specified in the TPU embedding + configuration given to tpu.initialize_system. +learning_rates: A TensorList of float32 scalars, one for each dynamic learning + rate tag: see the comments in + //third_party/tensorflow/contrib/tpu/proto/optimization_parameters.proto. 
+ Multiple tables can share the same dynamic learning rate tag as specified + in the configuration. If the learning rates for all tables are constant, + this list should be empty. config: Serialized TPUEmbeddingConfiguration proto. )doc"); diff --git a/tensorflow/contrib/tpu/ops/tpu_ordinal_selector_op.cc b/tensorflow/contrib/tpu/ops/tpu_ordinal_selector_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..54e6b20f7f388b67a96ac8acfe814a4202b56a18 --- /dev/null +++ b/tensorflow/contrib/tpu/ops/tpu_ordinal_selector_op.cc @@ -0,0 +1,39 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +REGISTER_OP("TPUOrdinalSelector") + .Output("device_ordinals: int32") + .SetIsStateful() + .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + c->set_output(0, + c->Vector(shape_inference::InferenceContext::kUnknownDim)); + return Status::OK(); + }) + .Doc(R"doc( +A TPU core selector Op. + +This Op produces a set of TPU cores (for warm-up) or a single TPU core +(for regular inference) to execute the TPU program on. The output is +consumed by TPUPartitionedCall. + +device_ordinals: A vector 1 or more TPU cores. 
+)doc"); + +} // namespace tensorflow diff --git a/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py b/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py index 63641e00c5dbf4b4e635ecfea8bef98c7d0b7075..a081c4354a779d37140338793e66844c3fcf7a12 100644 --- a/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py +++ b/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py @@ -90,12 +90,12 @@ def main(unused_argv=None): tf_version = tf.__version__ print('TensorFlow version %s detected' % tf_version) - if FLAGS.service_addr is None and FLAGS.tpu is None: + if not FLAGS.service_addr and not FLAGS.tpu: sys.exit('You must specify either --service_addr or --tpu.') tpu_cluster_resolver = None - if FLAGS.service_addr is not None: - if FLAGS.tpu is not None: + if FLAGS.service_addr: + if FLAGS.tpu: tf.logging.warn('Both --service_addr and --tpu are set. Ignoring ' '--tpu and using --service_addr.') service_addr = FLAGS.service_addr diff --git a/tensorflow/contrib/tpu/profiler/pip_package/setup.py b/tensorflow/contrib/tpu/profiler/pip_package/setup.py index f27ae38e0434991da7475e631be1c6cb4a463118..807cf26fe983b4ebe17695d6f4f90ecfc0e0cbf5 100644 --- a/tensorflow/contrib/tpu/profiler/pip_package/setup.py +++ b/tensorflow/contrib/tpu/profiler/pip_package/setup.py @@ -33,7 +33,7 @@ setup( long_description='Tools for capture TPU profile', url='https://www.tensorflow.org/tfrc/', author='Google Inc.', - author_email='opensource@google.com', + author_email='packages@tensorflow.org', packages=['cloud_tpu_profiler'], package_data={ 'cloud_tpu_profiler': ['data/*'], diff --git a/tensorflow/contrib/tpu/profiler/trace_events.proto b/tensorflow/contrib/tpu/profiler/trace_events.proto index cb2b9162677a0ebe8240a98671b1cabc1cee0c9f..96c4784c691d8f34cf8715cdc0ed9886412f5f90 100644 --- a/tensorflow/contrib/tpu/profiler/trace_events.proto +++ b/tensorflow/contrib/tpu/profiler/trace_events.proto @@ -56,4 +56,7 @@ message 
TraceEvent { // The duration of the event in picoseconds if applicable. // Events without duration are called instant events. uint64 duration_ps = 10; + + // Extra arguments that will be displayed in trace view. + map<string, string> args = 11; } diff --git a/tensorflow/contrib/tpu/proto/BUILD b/tensorflow/contrib/tpu/proto/BUILD index c20cab844cfaf083be2702a29ac2a152c7b72c2a..ea98ee25c89e1b7bef39276bae5c98bf382dbd7f 100644 --- a/tensorflow/contrib/tpu/proto/BUILD +++ b/tensorflow/contrib/tpu/proto/BUILD @@ -49,6 +49,15 @@ tf_proto_library( visibility = ["//visibility:public"], ) +tf_proto_library( + name = "dynamic_padding_proto", + srcs = [ + "dynamic_padding.proto", + ], + cc_api_version = 2, + visibility = ["//visibility:public"], +) + tf_proto_library_py( name = "compilation_result_proto", srcs = [ diff --git a/tensorflow/contrib/tpu/proto/dynamic_padding.proto b/tensorflow/contrib/tpu/proto/dynamic_padding.proto new file mode 100644 index 0000000000000000000000000000000000000000..c9ebf181169a583d774ef77ca0b8c243ce733615 --- /dev/null +++ b/tensorflow/contrib/tpu/proto/dynamic_padding.proto @@ -0,0 +1,19 @@ +syntax = "proto3"; + +option cc_enable_arenas = true; + +package tensorflow.tpu; + +// A mapping between the dynamic shape dimension of an input and the arg that +// represents the real shape. +message PaddingMap { + // Input arg index with dynamic shapes. + int32 arg_index = 1; + + // The dynamic shape dimension index. + int32 shape_index = 2; + + // The arg index that dynamic dimension maps to, which represents the value + // of the real shape. 
+ int32 padding_arg_index = 3; +} diff --git a/tensorflow/contrib/tpu/proto/optimization_parameters.proto b/tensorflow/contrib/tpu/proto/optimization_parameters.proto index aae1ab1d37a166303883e3a07a7a01efe2feab51..bc50c613f3d2a09f9e51353fab4938055549a4cd 100644 --- a/tensorflow/contrib/tpu/proto/optimization_parameters.proto +++ b/tensorflow/contrib/tpu/proto/optimization_parameters.proto @@ -9,9 +9,38 @@ message ClippingLimits { google.protobuf.FloatValue upper = 2; // +inf if not set } -// Get the learning rate from the parameters of the SendTPUEmbeddingGradients -// op. +// Dynamic learning rate specification in the TPUEmbeddingConfiguration. The +// actual learning rates are provided as a scalar input list to the +// SendTPUEmbeddingGradients Op indexed by their tag specified through the +// following proto. message DynamicLearningRate { + // For tables where learning rates are dynamically computed and communicated + // to the TPU embedding program, a tag must be specified for the learning + // rate. + // + // The tag must be a non-negative integer. The total number of unique tags + // must be less than or equal to the number of tables in the TPU embedding + // configuration (a table does not specify any tag if it uses a constant + // learning rate, and specifies exactly one tag if it uses dynamic learning + // rates). + // + // All tags in the range [0, number_of_unique_tags) must be present in the TPU + // embedding configuration, i.e. a tag cannot be skipped if a different tag + // numerically greater than it is used in the configuration. + // + // If multiple tables specify the same tag, they *MUST* have + // the same dynamic learning rate, for example, their dynamic learning rate + // could be computed by the same TensorFlow sub-graph. The partitioning of the + // embedding layer would be more optimal if the number_of_unique_tags is as + // *LOW* as possible, i.e., if many tables share the same tag. 
+ // + // The learning_rate input of the SendTPUEmbeddingGradients op is used to + // communicate dynamic learning rates to the TPU embedding program. + // The learning_rate input is a list of scalars where the size of the list is + // equal to the number of unique tags. The learning rate associated with a + // particular tag is specified by populating its corresponding index in the + // list of learning_rate scalars. + int32 tag = 1; } // Source of learning rate to use. @@ -186,7 +215,8 @@ message OptimizationParameters { } // Specification of an optimization algorithm's state variables (both the main -// value vector and any extra accumulators, etc.). +// value vector and any extra accumulators, etc.). This proto is only used +// internally by the TPU software and is not exposed directly to the TF model. message StateVariableSpecification { // Parameter name for the state variable. string name = 1; @@ -194,6 +224,20 @@ message StateVariableSpecification { // A normal state variable that should be saved and restored in checkpoints // and used as an input or output to non-debug TensorFlow ops. message UserDefined { + // For padding embedding rows, this field specifies the initial value to be + // used. Separate initial values need to be specified for the embeddings and + // any extra accumulators. The initial values should be specified so as to + // maintain two invariants during model training: + // (1) The embedding vector multiplied by zero returns a vector containing + // all zeros. To maintain this invariant, the embedding values should + // never be NaNs or +-infinity. + // (2) Repeatedly applying the optimizer using a gradient vector of all + // zeros does not cause the embeddings or slot variables to become NaNs + // or +-infinity. + // The padding row is looked up when no embedding IDs are present for a + // feature. The semantics of embedding lookup dictate that the output must + // be zero under this scenario. 
+ double padding_initial_value = 1; } // A state variable that should be filled with a constant and normally hidden diff --git a/tensorflow/contrib/tpu/python/ops/tpu_ops.py b/tensorflow/contrib/tpu/python/ops/tpu_ops.py index 6a6eba282a12d68cc3cd4e46a46a1b4190fb737b..500dd2cd39d6b8747cebb95d0a01d8c5680427fe 100644 --- a/tensorflow/contrib/tpu/python/ops/tpu_ops.py +++ b/tensorflow/contrib/tpu/python/ops/tpu_ops.py @@ -157,7 +157,7 @@ if platform.system() != "Windows": _SUPPORTED_INFEED_DTYPES = set([ dtypes.bool, dtypes.int32, dtypes.int64, dtypes.bfloat16, dtypes.float32, - dtypes.complex64 + dtypes.complex64, dtypes.uint32 ]) def infeed_dequeue(dtype, shape, name=None): @@ -217,13 +217,19 @@ if platform.system() != "Windows": Args: inputs: A TensorList of gradients with which to update embedding tables. - Contains one tensor per embedding table in the model. + This argument has the same length and shapes as the return value of + RecvTPUEmbeddingActivations, but contains gradients of the model's + loss with respect to the embedding activations. The embedding tables + are updated from these gradients via the optimizers specified in the + TPU embedding configuration given to tpu.initialize_system. config: Serialized TPUEmbeddingConfiguration proto. - learning_rates: A TensorList of float32 scalars, one for each embedding - table, containing the learning rates for each table when dynamic - learning rate is enabled through the OptimizationParameters in - TPUEmbeddingConfiguration. When the learning rate is constant, the list - should be empty (optional). + learning_rates: A TensorList of float32 scalars, one for each dynamic + learning rate tag: see the comments in + //third_party/tensorflow/contrib/tpu/proto/ + optimization_parameters.proto. + Multiple tables can share the same dynamic learning rate tag as + specified in the configuration. If the learning rates for all tables + are constant, this list should be empty. name: A name for the operation (optional). 
Returns: @@ -337,9 +343,8 @@ if platform.system() != "Windows": Args: sample_indices: A list of rank 1 Tensors specifying the training example to which the corresponding embedding_indices and aggregation_weights - values - belong. It corresponds to sp_ids.indices[:,0] in - embedding_lookup_sparse(). + values belong. It corresponds to sp_ids.indices[:,0] in + embedding_lookup_sparse(). embedding_indices: A list of rank 1 Tensors, indices into the embedding tables. It corresponds to sp_ids.values in embedding_lookup_sparse(). aggregation_weights: A list of rank 1 Tensors containing per training diff --git a/tensorflow/contrib/tpu/python/ops/tpu_ordinal_selector_op.py b/tensorflow/contrib/tpu/python/ops/tpu_ordinal_selector_op.py new file mode 100644 index 0000000000000000000000000000000000000000..5ca38cd1bae5753a7398834bd96d3b26e66b4941 --- /dev/null +++ b/tensorflow/contrib/tpu/python/ops/tpu_ordinal_selector_op.py @@ -0,0 +1,38 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================= + +"""Operations to select TPU core to run.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import platform + +if platform.system() != "Windows": + # pylint: disable=wildcard-import,unused-import,g-import-not-at-top + from tensorflow.contrib.tpu.ops.gen_tpu_ordinal_selector_op import * + + from tensorflow.contrib.util import loader + from tensorflow.python.platform import resource_loader + # pylint: enable=wildcard-import,unused-import,g-import-not-at-top + + _tpu_ordinal_selector_op = loader.load_op_library( + resource_loader.get_path_to_datafile("_tpu_ordinal_selector_op.so")) + +else: + # We have already built the appropriate libraries into the binary via CMake + # if we have built contrib, so we don't need this + pass diff --git a/tensorflow/contrib/tpu/python/tpu/_tpu_estimator_embedding.py b/tensorflow/contrib/tpu/python/tpu/_tpu_estimator_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..6ce96e5bcdbe5777f68eb969be46423b5b3410cb --- /dev/null +++ b/tensorflow/contrib/tpu/python/tpu/_tpu_estimator_embedding.py @@ -0,0 +1,273 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# =================================================================== +"""Tooling to support TPU embedding in TPUEstimator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +from tensorflow.contrib.tpu.python.tpu import feature_column as tpu_fc +from tensorflow.contrib.tpu.python.tpu import tpu_embedding +from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.feature_column import feature_column as core_fc +from tensorflow.python.feature_column import feature_column_lib as core_fc_lib + +# pylint: disable=protected-access +_TPU_EMBEDDING_COLUMN_CLASSES = (tpu_fc._TPUEmbeddingColumn, + tpu_fc._TPUSharedEmbeddingColumn) +_EMBEDDING_COLUMN_CLASSES = (core_fc._EmbeddingColumn, + core_fc_lib.EmbeddingColumn, + core_fc._SharedEmbeddingColumn) +_SUPPORTED_FEATURE_COLUMNS = (core_fc._NumericColumn, core_fc_lib.NumericColumn) + +# pylint: enable=protected-access + + +def get_tpu_embedding_config_from_feature_columns(feature_columns): + """Create configs for TPUEmbedding from a list of feature columns. + + This function will place one embedding tensor per table and the return is + intended to be used as input to TPUEmbedding. + + Args: + feature_columns: a list of supported feature columns. + + Returns: + A pair of dicts, the first maps tables to their config, the second maps + features to tables. + """ + + allowed = (tpu_fc._TPUEmbeddingColumn, tpu_fc._TPUSharedEmbeddingColumn) # pylint: disable=protected-access + + for column in feature_columns: + if not isinstance(column, allowed): + raise TypeError( + 'Unsupported feature column {}. 
Supported types are {}.'.format( + type(column), allowed)) + + table_to_config = {} + feature_to_table = {} + for column in feature_columns: + feature_name = column.get_feature_key_name() + table_name = 'tbl_{}'.format(column.get_embedding_var_name()) + if feature_name in feature_to_table: + raise ValueError( + 'Feature column {} is used with multiple embeddings and this is ' + 'not supported.'.format(feature_name)) + feature_to_table[feature_name] = table_name + vocabulary_size, dimension = column.get_embedding_table_size() + table_to_config[table_name] = tpu_embedding.TableConfig( + vocabulary_size=vocabulary_size, + dimension=dimension, + initializer=column.get_initializer(), + combiner=column.get_combiner()) + + return table_to_config, feature_to_table + + +def _get_tpu_embedding_optimization_parameters(embedding_config_spec): + """Get tpu_embedding._OptimizationParameters from EmbeddingConfigSpec.""" + if embedding_config_spec.optimizer_type == 'adagrad': + return tpu_embedding.AdagradParameters( + embedding_config_spec.learning_rate, + embedding_config_spec.adagrad_initial_accumulator, + embedding_config_spec.use_gradient_accumulation) + elif embedding_config_spec.optimizer_type == 'sgd': + return tpu_embedding.StochasticGradientDescentParameters( + embedding_config_spec.learning_rate, + embedding_config_spec.use_gradient_accumulation) + elif embedding_config_spec.optimizer_type == 'adam': + return tpu_embedding.AdamParameters( + embedding_config_spec.learning_rate, + embedding_config_spec.adam_parameters.beta1, + embedding_config_spec.adam_parameters.beta2, + embedding_config_spec.adam_parameters.epsilon, + use_gradient_accumulation=embedding_config_spec + .use_gradient_accumulation) + else: + raise ValueError('optimizer_type must be adagrad or sgd or adam for now.') + + +AdamParameters = collections.namedtuple('AdamParameters', + ['beta1', 'beta2', 'epsilon']) + + +# TODO(shizhiw): Improve the API to support more optimizer parameters in API. 
+class EmbeddingConfigSpec( + collections.namedtuple('EmbeddingConfigSpec', [ + 'feature_columns', 'learning_rate', 'optimizer_type', + 'adagrad_initial_accumulator', 'clipping_limit', + 'use_gradient_accumulation', 'adam_parameters' + ])): + """Class to keep track of embedding config specification.""" + + def __new__(cls, + feature_columns, + learning_rate, + optimizer_type='adagrad', + adagrad_initial_accumulator=None, + clipping_limit=None, + use_gradient_accumulation=False, + adam_parameters=None): + """Creates an EmbeddingConfigSpec instance. + + Args: + feature_columns: All `FeatureColumn`s used by model. + learning_rate: embedding optimizer learning rate. + optimizer_type: (String) Name of the optimizer for embedding gradients + updates. Must be either 'adagrad' ( `tf.train.AdagradOptimizer`, default + value), 'sgd' (`tf.train.GradientDescentOptimizer`), or 'adam' + (`tf.contrib.opt.LazyAdamOptimizer`) for lazy Adam. This optimizer will + be applied to all embedding variables specified by `feature_columns`. + adagrad_initial_accumulator: Initial accumulator for Adagrad. Used when + optimizer_type is 'adagrad'. Default is `0.1`. + clipping_limit: (Optional) Clipping limit (absolute value). + use_gradient_accumulation: (Experimental) Whether to accumulate the + gradients across TPU embedding mini-batches. Gradient accumulation does + not affect SGD and therefore this is applicable only for Adagrad. + adam_parameters: AdamParameters. Used when optimizer_type is 'adam'. + Default is 0.9 for beta1, 0.999 for beta2 and 1e-8 for epsilon. + + Returns: + An EmbeddingConfigSpec instance. + + Raises: + ValueError: If the feature_columns are not specified. + TypeError: If the feature columns are not of the correct type (one of + _SUPPORTED_FEATURE_COLUMNS, _TPU_EMBEDDING_COLUMN_CLASSES or + _EMBEDDING_COLUMN_CLASSES). + ValueError: If use_gradient_accumulation is True for SGD. + ValueError: If `optimizer_type` is not one of "adagrad" or "sgd" or + "adam". 
+ """ + if not feature_columns: + raise ValueError('`feature_columns` cannot be `None` or empty.') + + # It is unknown at this moment, whether the TPUEstimator is running in CPU + # or TPU mode. So allow non-TPU embedding columns also. + supported_classes = tuple( + list(_SUPPORTED_FEATURE_COLUMNS) + list(_TPU_EMBEDDING_COLUMN_CLASSES) + + list(_EMBEDDING_COLUMN_CLASSES)) + + for column in feature_columns: + if not isinstance(column, supported_classes): + raise TypeError( + 'All feature columns must be supported types in {}. Got {}'.format( + supported_classes, type(column))) + + if optimizer_type == 'adagrad': + if adagrad_initial_accumulator is None: + adagrad_initial_accumulator = 0.1 + if adagrad_initial_accumulator <= 0: + raise ValueError('Adagrad initial_accumulator must be positive') + elif optimizer_type == 'sgd': + if use_gradient_accumulation: + raise ValueError('Gradient accumulation makes sense for Adagrad only.') + elif optimizer_type == 'adam': + if adam_parameters is None: + adam_parameters = AdamParameters(0.9, 0.999, 1e-8) + if adam_parameters.beta1 < 0. or adam_parameters.beta1 >= 1.: + raise ValueError('beta1 must be between 0. and 1; got {}.'.format( + adam_parameters.beta1)) + if adam_parameters.beta2 < 0. or adam_parameters.beta2 >= 1.: + raise ValueError('beta2 must be between 0. 
and 1; got {}.'.format( + adam_parameters.beta2)) + if adam_parameters.epsilon <= 0.: + raise ValueError('epsilon must be positive; got {}.'.format( + adam_parameters.epsilon)) + else: + raise ValueError('optimizer_type must be adagrad or sgd or adam for now.') + + return super(EmbeddingConfigSpec, cls).__new__( + cls, + feature_columns=feature_columns, + learning_rate=learning_rate, + optimizer_type=optimizer_type, + adagrad_initial_accumulator=adagrad_initial_accumulator, + clipping_limit=clipping_limit, + use_gradient_accumulation=use_gradient_accumulation, + adam_parameters=adam_parameters) + + +class EmbeddingConfig(object): + """This is the internal immutable object for embedding config. + + `_EmbeddingConfig` is responsible to _translate_ user provided + `EmbeddingConfigSpec` to internal data structures, mostly constructor + arguments of `TPUEmbedding`. + """ + + def __init__(self, embedding_config_spec, train_batch_size, eval_batch_size, + num_hosts, num_cores, master): + self._embedding_config_spec = embedding_config_spec + self._train_batch_size = train_batch_size + self._eval_batch_size = eval_batch_size + self._num_hosts = num_hosts + self._num_cores = num_cores + self._master = master + + self._table_to_config_dict, self._feature_to_table_dict = ( + get_tpu_embedding_config_from_feature_columns( + embedding_config_spec.feature_columns)) + self._optimization_parameters = _get_tpu_embedding_optimization_parameters( + self._embedding_config_spec) + self._mode_to_tpu_embedding_dict = {} + + def has_embedding_tables(self): + return bool(self._table_to_config_dict) + + def _create_tpu_embedding(self, mode): + """Create tpu_embedding.TPUEmbedding based on mode.""" + if mode == model_fn_lib.ModeKeys.TRAIN: + batch_size = self._train_batch_size + else: + batch_size = self._eval_batch_size + + if mode == model_fn_lib.ModeKeys.TRAIN: + tpu_embedding_mode = tpu_embedding.TRAINING + elif (mode == model_fn_lib.ModeKeys.EVAL or + mode == 
model_fn_lib.ModeKeys.PREDICT): + tpu_embedding_mode = tpu_embedding.INFERENCE + else: + raise ValueError('Mode {} is not supported.'.format(mode)) + + tpu_embedding_ = tpu_embedding.TPUEmbedding( + self._table_to_config_dict, + self._feature_to_table_dict, + batch_size, + tpu_embedding_mode, + self._master, + self._optimization_parameters, + ) + return tpu_embedding_ + + def get_tpu_embedding(self, mode): + if mode not in self._mode_to_tpu_embedding_dict: + self._mode_to_tpu_embedding_dict[mode] = ( + self._create_tpu_embedding(mode)) + return self._mode_to_tpu_embedding_dict[mode] + + +def split_inputs(ctx, features, labels): + """Splits the dense and sparse tensors inside the features and labels.""" + sparse_features = collections.OrderedDict() + if ctx.embedding_config: + tpu_embedding_ = ctx.embedding_config.tpu_embedding + for feature_key in tpu_embedding_.feature_to_table_dict: + sparse_features[feature_key] = features.pop(feature_key) + + return features, labels, sparse_features diff --git a/tensorflow/contrib/tpu/python/tpu/datasets.py b/tensorflow/contrib/tpu/python/tpu/datasets.py index d61c824eab5337a7cd08cfa52a7e8f8b8d73b455..bc0cd41d210ac6f8de1b20ebf744ee1e1dd04137 100644 --- a/tensorflow/contrib/tpu/python/tpu/datasets.py +++ b/tensorflow/contrib/tpu/python/tpu/datasets.py @@ -142,21 +142,18 @@ def StreamingFilesDataset(files, source_dataset = source_dataset.shuffle( buffer_size=filename_shuffle_buffer_size) - # NOTE: We perform the `repeat` on the source dataset, because the output - # dataset does not currently have enough information to recreate an iterator - # over the source dataset when it reaches the end. 
- source_dataset = source_dataset.repeat(num_epochs) - source_dataset = source_dataset.apply( interleave_ops.parallel_interleave( reader_fn, cycle_length=num_parallel_reads, sloppy=sloppy)) + source_dataset = source_dataset.repeat(num_epochs) + if batch_transfer_size: source_dataset = source_dataset.batch(batch_transfer_size) source_dataset = source_dataset.prefetch(1) - source_iterator = source_dataset.make_one_shot_iterator() + source_iterator = dataset_ops.make_one_shot_iterator(source_dataset) source_handle = source_iterator.string_handle() @function.Defun(dtypes.string) diff --git a/tensorflow/contrib/tpu/python/tpu/datasets_test.py b/tensorflow/contrib/tpu/python/tpu/datasets_test.py index b58d05eac56f3586e183333f7c1a3867ee57456c..8a94f527bb6dffa48e71e6500ae5e9e9589fbf5c 100644 --- a/tensorflow/contrib/tpu/python/tpu/datasets_test.py +++ b/tensorflow/contrib/tpu/python/tpu/datasets_test.py @@ -27,6 +27,7 @@ from tensorflow.python.client import session from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import readers from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.lib.io import python_io from tensorflow.python.platform import test @@ -55,6 +56,7 @@ class DatasetsTest(test.TestCase): session_config = config_pb2.ConfigProto(cluster_def=self._cluster_def) self._sess = session.Session(self._worker.target, config=session_config) + self._worker_device = '/job:' + worker_job.name def testTextLineDataset(self): all_contents = [] @@ -70,7 +72,8 @@ class DatasetsTest(test.TestCase): dataset = datasets.StreamingFilesDataset( os.path.join(self.get_temp_dir(), 'text_line.*.txt'), filetype='text') - iterator = dataset.make_initializable_iterator() + with ops.device(self._worker_device): + iterator = dataset_ops.make_initializable_iterator(dataset) self._sess.run(iterator.initializer) get_next = iterator.get_next() @@ -94,7 
+97,8 @@ class DatasetsTest(test.TestCase): dataset = datasets.StreamingFilesDataset( os.path.join(self.get_temp_dir(), 'tf_record*'), filetype='tfrecord') - iterator = dataset.make_initializable_iterator() + with ops.device(self._worker_device): + iterator = dataset_ops.make_initializable_iterator(dataset) self._sess.run(iterator.initializer) get_next = iterator.get_next() @@ -121,7 +125,8 @@ class DatasetsTest(test.TestCase): dataset = datasets.StreamingFilesDataset(filenames, filetype='tfrecord') - iterator = dataset.make_initializable_iterator() + with ops.device(self._worker_device): + iterator = dataset_ops.make_initializable_iterator(dataset) self._sess.run(iterator.initializer) get_next = iterator.get_next() @@ -154,7 +159,8 @@ class DatasetsTest(test.TestCase): os.path.join(self.get_temp_dir(), 'fixed_length*'), filetype=FixedLengthFile) - iterator = dataset.make_initializable_iterator() + with ops.device(self._worker_device): + iterator = dataset_ops.make_initializable_iterator(dataset) self._sess.run(iterator.initializer) get_next = iterator.get_next() @@ -177,7 +183,8 @@ class DatasetsTest(test.TestCase): dataset = datasets.StreamingFilesDataset( dataset_ops.Dataset.range(10), filetype=gen_dataset) - iterator = dataset.make_initializable_iterator() + with ops.device(self._worker_device): + iterator = dataset_ops.make_initializable_iterator(dataset) self._sess.run(iterator.initializer) get_next = iterator.get_next() diff --git a/tensorflow/contrib/tpu/python/tpu/device_assignment.py b/tensorflow/contrib/tpu/python/tpu/device_assignment.py index 6906501ecf90c8e577aa0becf2dba818deb19df4..3313dc749c2c7606101b2dc96614df2d052dfed1 100644 --- a/tensorflow/contrib/tpu/python/tpu/device_assignment.py +++ b/tensorflow/contrib/tpu/python/tpu/device_assignment.py @@ -25,6 +25,9 @@ from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.contrib.tpu.python.tpu.topology import Topology +SINGLE_CORE_ASSIGNMENT = [[[0, 0, 0]]] + + def 
_compute_task_and_cores_to_replicas(core_assignment, topology): """Computes a nested dict which maps task and logical core to replicas.""" task_and_cores_to_replicas = {} diff --git a/tensorflow/contrib/tpu/python/tpu/feature_column.py b/tensorflow/contrib/tpu/python/tpu/feature_column.py new file mode 100644 index 0000000000000000000000000000000000000000..8edf131bc24fd003806263570b63ee8514c49896 --- /dev/null +++ b/tensorflow/contrib/tpu/python/tpu/feature_column.py @@ -0,0 +1,429 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# =================================================================== +"""TPU Feature Column Library.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math + +from tensorflow.contrib.tpu.python.tpu import tpu +from tensorflow.contrib.tpu.python.tpu import tpu_function +from tensorflow.python.feature_column import feature_column as fc +from tensorflow.python.feature_column import feature_column_lib as fc_lib +from tensorflow.python.framework import ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import variable_scope +# pylint: disable=protected-access + + +_TPU_FC_TO_SCOPE = '_tpu_feature_column_scope' +_SUPPORTED_CATEGORICAL_COLUMNS = (fc._IdentityCategoricalColumn, + fc._VocabularyFileCategoricalColumn, + fc._VocabularyListCategoricalColumn, + fc._WeightedCategoricalColumn, + fc_lib.IdentityCategoricalColumn, + fc_lib.VocabularyFileCategoricalColumn, + fc_lib.VocabularyListCategoricalColumn, + fc_lib.WeightedCategoricalColumn) + + +def embedding_column(categorical_column, + dimension, + combiner='mean', + initializer=None): + """TPU embedding_column for `tf.feature_column.embedding_column`. + + Note that the interface for TPU embedding_column is different from the non-TPU + version. The following args available for the non-TPU version are NOT + supported: ckpt_to_load_from, tensor_name_in_ckp, max_norm and trainable. + + Args: + categorical_column: A categorical_column returned from + categorical_column_with_identity, weighted_categorical_column, + categorical_column_with_vocabulary_list or + categorical_column_with_vocabulary_file. + dimension: An integer specifying dimension of the embedding, must be > 0. + combiner: A string specifying how to reduce if there are multiple entries + in a single row. For more information, see + `tf.feature_column.embedding_column`. 
+ initializer: A variable initializer function to be used in embedding + variable initialization. If not specified, defaults to + `tf.truncated_normal_initializer` with mean `0.0` and standard deviation + `1/sqrt(dimension)`. + + Returns: + A _TPUEmbeddingColumn. + + Raises: + ValueError: if `dimension` not > 0. + ValueError: if `initializer` is specified but not callable. + """ + if not isinstance(categorical_column, _SUPPORTED_CATEGORICAL_COLUMNS): + raise TypeError( + 'categorical_column for tpu ' + ' embedding_column must be type %s, got %s.' % (' or '.join([ + cc.__name__ for cc in _SUPPORTED_CATEGORICAL_COLUMNS + ]), type(categorical_column))) + if (dimension is None) or (dimension < 1): + raise ValueError('Invalid dimension {}.'.format(dimension)) + + if (initializer is not None) and (not callable(initializer)): + raise ValueError('initializer must be callable if specified. ' + 'Embedding of column_name: {}'.format( + categorical_column.name)) + if initializer is None: + initializer = init_ops.truncated_normal_initializer( + mean=0.0, stddev=1 / math.sqrt(dimension)) + + embedding_shape = categorical_column._num_buckets, dimension # pylint: disable=protected-access + + def _creator(weight_collections, scope): + embedding_column_layer = fc._EmbeddingColumnLayer( + embedding_shape=embedding_shape, + initializer=initializer, + weight_collections=weight_collections, + trainable=True, + name='embedding_column_layer') + return embedding_column_layer(None, scope=scope) # pylint: disable=not-callable + + column = _TPUEmbeddingColumn( + categorical_column=categorical_column, + dimension=dimension, + combiner=combiner, + layer_creator=_creator, + ckpt_to_load_from=None, + tensor_name_in_ckpt=None, + max_norm=None, + trainable=True) + # For Embedding column, the initializer is hidden inside the creator Fn, which + # is not accessiable later. So, we attach it to a speicial field. 
Also note + # that non-TPU Embedding column and non-TPU shared Embedding column handle the + # initializer differently. See shared_embedding_columns for details. + column._tpu_initializer = initializer + return column + + +def shared_embedding_columns(categorical_columns, + dimension, + combiner='mean', + initializer=None, + shared_embedding_collection_name=None): + """List of dense columns that convert from sparse, categorical input.""" + for categorical_column in categorical_columns: + if not isinstance(categorical_column, _SUPPORTED_CATEGORICAL_COLUMNS): + raise TypeError( + 'categorical_column for tpu ' + ' shared_embedding_columns must be type %s, got %s.' % (' or '.join([ + cc.__name__ for cc in _SUPPORTED_CATEGORICAL_COLUMNS + ]), type(categorical_column))) + columns = fc_lib.shared_embedding_columns( + categorical_columns, + dimension, + combiner=combiner, + initializer=initializer, + shared_embedding_collection_name=shared_embedding_collection_name, + ckpt_to_load_from=None, + tensor_name_in_ckpt=None, + max_norm=None, + trainable=True) + + # Use the initializer and shared_embedding_collection_name to create TPU + # version + initializer = columns[0].initializer + shared_embedding_collection_name = columns[0].shared_embedding_collection_name + tpu_columns = [] + + # Create the state (_SharedEmbeddingColumnLayer) here. 
+ for categorical_column in categorical_columns: + column = _TPUSharedEmbeddingColumn( + categorical_column=categorical_column, + dimension=dimension, + combiner=combiner, + initializer=initializer, + shared_embedding_collection_name=shared_embedding_collection_name, + ckpt_to_load_from=None, + tensor_name_in_ckpt=None, + max_norm=None, + trainable=True) + tpu_columns.append(column) + + return tpu_columns + + +class _TPUBaseEmbeddingColumn(object): + """Base class for TPU Embedding Column.""" + + def __init__(self, categorical_column): + self._tpu_categorical_column = categorical_column + + def get_combiner(self): + """Returns the embedding combiner.""" + raise NotImplementedError('not implemented') + + def get_embedding_table_size(self): + """Returns the embedding table size, tuple of vocab size and dimension.""" + raise NotImplementedError('not implemented') + + def get_feature_key_name(self): + """Returns the feature key name in the features dict.""" + raise NotImplementedError('not impl') + + def get_weight_key_name(self): + """Return the key name for weights.""" + raise NotImplementedError('not impl') + + def get_embedding_var_name(self): + """Returns the embedding variable name. + + Feature key name and embedding variable name are usually one-to-one mapping. + But for shared embedding columns, it is many-to-one mapping. 
+ """ + raise NotImplementedError('not impl') + + def get_initializer(self): + """Returns the initializer.""" + raise NotImplementedError('not impl') + + def is_categorical_column_weighted(self): + """Check if the categorical column of the embedding column is weighted.""" + raise NotImplementedError('not impl') + + +class _TPUEmbeddingColumn(_TPUBaseEmbeddingColumn, fc._EmbeddingColumn): + """Core Embedding Column.""" + + def __new__(cls, + categorical_column, + dimension, + combiner='mean', + layer_creator=None, + ckpt_to_load_from=None, + tensor_name_in_ckpt=None, + max_norm=None, + trainable=True): + # Note, args ckpt_to_load_from, tensor_name_in_ckpt, max_norm and trainable + # are not supported on TPU. They are solely for matching the signature of + # __new__ of parent class fc._EmbeddingColumn. + return fc._EmbeddingColumn.__new__( + cls, + categorical_column, + dimension, + combiner=combiner, + layer_creator=layer_creator, + ckpt_to_load_from=ckpt_to_load_from, + tensor_name_in_ckpt=tensor_name_in_ckpt, + max_norm=max_norm, + trainable=trainable) + + def __init__(self, + categorical_column, + dimension, + combiner='mean', + layer_creator=None, + ckpt_to_load_from=None, + tensor_name_in_ckpt=None, + max_norm=None, + trainable=True): + _TPUBaseEmbeddingColumn.__init__(self, categorical_column) + self._key = None + + def get_combiner(self): + return self.combiner + + def get_embedding_table_size(self): + """Returns num_ids and width.""" + return (self.categorical_column._num_buckets, self.dimension) + + def get_feature_key_name(self): + """get_feature_key_name.""" + if self.is_categorical_column_weighted(): + return self.categorical_column.categorical_column.name + return self.categorical_column.name + + def get_weight_key_name(self): + """get_weight_key_name.""" + if self.is_categorical_column_weighted(): + return self.categorical_column.weight_feature_key + return None + + def get_embedding_var_name(self): + """get_embedding_var_name.""" + return 
self.categorical_column.name + + def get_initializer(self): + return self._tpu_initializer + + def is_categorical_column_weighted(self): + """Check if the categorical column of the embedding column is weighted.""" + if isinstance( + self.categorical_column, + ( + fc._WeightedCategoricalColumn, # pylint: disable=protected-access + fc_lib.WeightedCategoricalColumn)): + return True + return False + + def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None): + if tpu.under_tpu_inference_context(): + def host_computation(): + return fc._EmbeddingColumn._get_dense_tensor( + self, inputs, weight_collections, trainable) + return tpu.outside_compilation(host_computation) + + if _is_running_on_cpu(): + return fc._EmbeddingColumn._get_dense_tensor( + self, inputs, weight_collections, trainable) + + # TPU mode + # Get the embeddings from the LazyBuilder. + tensor = inputs.get(self.get_feature_key_name()) + + # Add to collection for _create_tpu_embedding_variables_and_ops + _record_variable_scope_and_name(self.get_embedding_var_name(), + 'embedding_weights') + + return tensor + + +class _TPUSharedEmbeddingColumn(_TPUBaseEmbeddingColumn, + fc._SharedEmbeddingColumn): + """Core Shared Embedding Column.""" + + def __new__(cls, + categorical_column, + dimension, + combiner='mean', + initializer=None, + shared_embedding_collection_name=None, + ckpt_to_load_from=None, + tensor_name_in_ckpt=None, + max_norm=None, + trainable=True): + return fc._SharedEmbeddingColumn.__new__( + cls, + categorical_column, + dimension, + combiner=combiner, + initializer=initializer, + shared_embedding_collection_name=shared_embedding_collection_name, + ckpt_to_load_from=ckpt_to_load_from, + tensor_name_in_ckpt=tensor_name_in_ckpt, + max_norm=max_norm, + trainable=trainable) + + def __init__(self, + categorical_column, + dimension, + combiner='mean', + initializer=None, + shared_embedding_collection_name=None, + ckpt_to_load_from=None, + tensor_name_in_ckpt=None, + max_norm=None, + 
trainable=True): + + _TPUBaseEmbeddingColumn.__init__(self, categorical_column) + self._key = None + + def get_combiner(self): + return self.combiner + + def get_embedding_table_size(self): + """Returns num_ids and width.""" + return (self.categorical_column._num_buckets, self.dimension) + + def get_feature_key_name(self): + """get_feature_key_name.""" + if self.is_categorical_column_weighted(): + return self.categorical_column.categorical_column.name + return self.categorical_column.name + + def get_weight_key_name(self): + """get_weight_key_name.""" + if self.is_categorical_column_weighted(): + return self.categorical_column.weight_feature_key + return None + + def get_embedding_var_name(self): + """get_embedding_var_name.""" + return self.shared_embedding_collection_name + + def get_initializer(self): + return self.initializer + + def is_categorical_column_weighted(self): + """Check if the categorical column of the embedding column is weighted.""" + if isinstance( + self.categorical_column, + ( + fc._WeightedCategoricalColumn, # pylint: disable=protected-access + fc_lib.WeightedCategoricalColumn)): + return True + return False + + def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None): + if tpu.under_tpu_inference_context(): + def host_computation(): + return fc._SharedEmbeddingColumn._get_dense_tensor( + self, inputs, weight_collections, trainable) + return tpu.outside_compilation(host_computation) + + if _is_running_on_cpu(): + return fc._SharedEmbeddingColumn._get_dense_tensor( + self, inputs, weight_collections, trainable) + + # TPU mode + # Get the embeddings from the LazyBuilder. 
+ tensor = inputs.get(self.get_feature_key_name()) + + # Add to collection for _create_tpu_embedding_variables_and_ops + _record_variable_scope_and_name( + self.get_embedding_var_name(), + 'embedding_weights', + is_shared_embedding=True) + return tensor + + +def _record_variable_scope_and_name(embedding_var_name, + embedding_var_name_in_fc, + is_shared_embedding=False): + """Add embedding variable name and scope to collection.""" + g = ops.get_default_graph() + collection = g.get_collection_ref(_TPU_FC_TO_SCOPE) + if not collection: + collection.append({}) + + var_def_dict = collection[0] + + captured_scope = None + + if is_shared_embedding and (embedding_var_name in var_def_dict): + if var_def_dict[embedding_var_name][1] != embedding_var_name_in_fc: + raise ValueError( + 'For embedding var name {}, the shared embedding name is different, ' + 'got {}; expected {}'.format(embedding_var_name, + embedding_var_name_in_fc, + var_def_dict[embedding_var_name][1])) + else: + # scope contains var_scope_name. + captured_scope = variable_scope.get_variable_scope() + var_def_dict[embedding_var_name] = (captured_scope, + embedding_var_name_in_fc) + + +def _is_running_on_cpu(): + """Returns True if the current context is CPU model.""" + return tpu_function.get_tpu_context().number_of_shards is None diff --git a/tensorflow/contrib/tpu/python/tpu/feature_column_test.py b/tensorflow/contrib/tpu/python/tpu/feature_column_test.py new file mode 100644 index 0000000000000000000000000000000000000000..75164cce4c261cc541dd6b01ee22699d286d9621 --- /dev/null +++ b/tensorflow/contrib/tpu/python/tpu/feature_column_test.py @@ -0,0 +1,286 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# =================================================================== +"""Tests for contrib.tpu.python.tpu.feature_column.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.tpu.python.tpu import feature_column as tpu_fc +from tensorflow.python.client import session +from tensorflow.python.feature_column import feature_column as fc +from tensorflow.python.feature_column import feature_column_lib as fc_lib +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import lookup_ops +from tensorflow.python.ops import parsing_ops +from tensorflow.python.ops import variables as variables_lib +from tensorflow.python.platform import test + + +def _initialized_session(): + sess = session.Session() + sess.run(variables_lib.global_variables_initializer()) + sess.run(lookup_ops.tables_initializer()) + return sess + + +class EmbeddingColumnTest(test.TestCase): + + def test_defaults(self): + categorical_column = fc_lib.categorical_column_with_identity( + key='aaa', num_buckets=3) + embedding_dimension = 2 + embedding_column = tpu_fc.embedding_column( + categorical_column, dimension=embedding_dimension) + self.assertIs(categorical_column, embedding_column.categorical_column) + self.assertEqual(embedding_dimension, embedding_column.dimension) + self.assertEqual('mean', embedding_column.combiner) + self.assertEqual('aaa_embedding', 
embedding_column.name) + self.assertEqual('aaa_embedding', embedding_column._var_scope_name) + self.assertEqual((embedding_dimension,), embedding_column._variable_shape) + self.assertEqual({ + 'aaa': parsing_ops.VarLenFeature(dtypes.int64) + }, embedding_column._parse_example_spec) + + def test_all_constructor_args(self): + categorical_column = fc_lib.categorical_column_with_identity( + key='aaa', num_buckets=3) + embedding_dimension = 2 + embedding_column = tpu_fc.embedding_column( + categorical_column, + dimension=embedding_dimension, + combiner='my_combiner', + initializer=lambda: 'my_initializer') + self.assertIs(categorical_column, embedding_column.categorical_column) + self.assertEqual(embedding_dimension, embedding_column.dimension) + self.assertEqual('my_combiner', embedding_column.combiner) + self.assertEqual('aaa_embedding', embedding_column.name) + self.assertEqual('aaa_embedding', embedding_column._var_scope_name) + self.assertEqual((embedding_dimension,), embedding_column._variable_shape) + self.assertEqual({ + 'aaa': parsing_ops.VarLenFeature(dtypes.int64) + }, embedding_column._parse_example_spec) + + def test_get_dense_tensor(self): + # Inputs. + vocabulary_size = 3 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + # example 2, ids [] + # example 3, ids [1] + indices=((0, 0), (1, 0), (1, 4), (3, 0)), + values=(2, 0, 1, 1), + dense_shape=(4, 5)) + + # Embedding variable. + embedding_dimension = 2 + embedding_values = ( + (1., 2.), # id 0 + (3., 5.), # id 1 + (7., 11.) # id 2 + ) + + def _initializer(shape, dtype, partition_info): + self.assertAllEqual((vocabulary_size, embedding_dimension), shape) + self.assertEqual(dtypes.float32, dtype) + self.assertIsNone(partition_info) + return embedding_values + + # Expected lookup result, using combiner='mean'. 
+ expected_lookups = ( + # example 0, ids [2], embedding = [7, 11] + (7., 11.), + # example 1, ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5] + (2., 3.5), + # example 2, ids [], embedding = [0, 0] + (0., 0.), + # example 3, ids [1], embedding = [3, 5] + (3., 5.), + ) + + # Build columns. + categorical_column = fc_lib.categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + embedding_column = tpu_fc.embedding_column( + categorical_column, + dimension=embedding_dimension, + initializer=_initializer) + + # Provide sparse input and get dense result. + embedding_lookup = embedding_column._get_dense_tensor( + fc._LazyBuilder({ + 'aaa': sparse_input + })) + + # Assert expected embedding variable and lookups. + global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertItemsEqual(('embedding_weights:0',), + tuple([v.name for v in global_vars])) + with _initialized_session(): + self.assertAllEqual(embedding_values, global_vars[0].eval()) + self.assertAllEqual(expected_lookups, embedding_lookup.eval()) + + +class SharedEmbeddingColumnTest(test.TestCase): + + def test_defaults(self): + categorical_column_a = fc_lib.categorical_column_with_identity( + key='aaa', num_buckets=3) + categorical_column_b = fc_lib.categorical_column_with_identity( + key='bbb', num_buckets=3) + embedding_dimension = 2 + embedding_column_b, embedding_column_a = tpu_fc.shared_embedding_columns( + [categorical_column_b, categorical_column_a], + dimension=embedding_dimension) + self.assertIs(categorical_column_a, embedding_column_a.categorical_column) + self.assertIs(categorical_column_b, embedding_column_b.categorical_column) + self.assertEqual(embedding_dimension, embedding_column_a.dimension) + self.assertEqual(embedding_dimension, embedding_column_b.dimension) + self.assertEqual('mean', embedding_column_a.combiner) + self.assertEqual('mean', embedding_column_b.combiner) + self.assertIsNotNone(embedding_column_a.initializer) + 
self.assertIsNotNone(embedding_column_b.initializer) + self.assertEqual('aaa_bbb_shared_embedding', + embedding_column_a.shared_embedding_collection_name) + self.assertEqual('aaa_bbb_shared_embedding', + embedding_column_b.shared_embedding_collection_name) + self.assertEqual('aaa_shared_embedding', embedding_column_a.name) + self.assertEqual('bbb_shared_embedding', embedding_column_b.name) + self.assertEqual('aaa_bbb_shared_embedding', + embedding_column_a._var_scope_name) + self.assertEqual('aaa_bbb_shared_embedding', + embedding_column_b._var_scope_name) + self.assertEqual((embedding_dimension,), embedding_column_a._variable_shape) + self.assertEqual((embedding_dimension,), embedding_column_b._variable_shape) + self.assertEqual({ + 'aaa': parsing_ops.VarLenFeature(dtypes.int64) + }, embedding_column_a._parse_example_spec) + self.assertEqual({ + 'bbb': parsing_ops.VarLenFeature(dtypes.int64) + }, embedding_column_b._parse_example_spec) + + def test_all_constructor_args(self): + categorical_column_a = fc_lib.categorical_column_with_identity( + key='aaa', num_buckets=3) + categorical_column_b = fc_lib.categorical_column_with_identity( + key='bbb', num_buckets=3) + embedding_dimension = 2 + embedding_column_a, embedding_column_b = tpu_fc.shared_embedding_columns( + [categorical_column_a, categorical_column_b], + dimension=embedding_dimension, + combiner='my_combiner', + initializer=lambda: 'my_initializer', + shared_embedding_collection_name='var_scope_name') + self.assertIs(categorical_column_a, embedding_column_a.categorical_column) + self.assertIs(categorical_column_b, embedding_column_b.categorical_column) + self.assertEqual(embedding_dimension, embedding_column_a.dimension) + self.assertEqual(embedding_dimension, embedding_column_b.dimension) + self.assertEqual('my_combiner', embedding_column_a.combiner) + self.assertEqual('my_combiner', embedding_column_b.combiner) + self.assertEqual('my_initializer', embedding_column_a.initializer()) + 
self.assertEqual('my_initializer', embedding_column_b.initializer()) + self.assertEqual('var_scope_name', + embedding_column_a.shared_embedding_collection_name) + self.assertEqual('var_scope_name', + embedding_column_b.shared_embedding_collection_name) + self.assertEqual('aaa_shared_embedding', embedding_column_a.name) + self.assertEqual('bbb_shared_embedding', embedding_column_b.name) + self.assertEqual('var_scope_name', embedding_column_a._var_scope_name) + self.assertEqual('var_scope_name', embedding_column_b._var_scope_name) + self.assertEqual((embedding_dimension,), embedding_column_a._variable_shape) + self.assertEqual((embedding_dimension,), embedding_column_b._variable_shape) + self.assertEqual({ + 'aaa': parsing_ops.VarLenFeature(dtypes.int64) + }, embedding_column_a._parse_example_spec) + self.assertEqual({ + 'bbb': parsing_ops.VarLenFeature(dtypes.int64) + }, embedding_column_b._parse_example_spec) + + def test_get_dense_tensor(self): + # Inputs. + vocabulary_size = 3 + # -1 values are ignored. + input_a = np.array([ + [2, -1, -1], # example 0, ids [2] + [0, 1, -1] + ]) # example 1, ids [0, 1] + input_b = np.array([ + [0, -1, -1], # example 0, ids [0] + [-1, -1, -1] + ]) # example 1, ids [] + input_features = {'aaa': input_a, 'bbb': input_b} + + # Embedding variable. + embedding_dimension = 2 + embedding_values = ( + (1., 2.), # id 0 + (3., 5.), # id 1 + (7., 11.) # id 2 + ) + + def _initializer(shape, dtype, partition_info): + self.assertAllEqual((vocabulary_size, embedding_dimension), shape) + self.assertEqual(dtypes.float32, dtype) + self.assertIsNone(partition_info) + return embedding_values + + # Expected lookup result, using combiner='mean'. 
+ expected_lookups_a = ( + # example 0: + (7., 11.), # ids [2], embedding = [7, 11] + # example 1: + (2., 3.5), # ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5] + ) + expected_lookups_b = ( + # example 0: + (1., 2.), # ids [0], embedding = [1, 2] + # example 1: + (0., 0.), # ids [], embedding = [0, 0] + ) + + # Build columns. + categorical_column_a = fc_lib.categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + categorical_column_b = fc_lib.categorical_column_with_identity( + key='bbb', num_buckets=vocabulary_size) + embedding_column_a, embedding_column_b = tpu_fc.shared_embedding_columns( + [categorical_column_a, categorical_column_b], + dimension=embedding_dimension, + initializer=_initializer) + + # Provide sparse input and get dense result. + embedding_lookup_a = embedding_column_a._get_dense_tensor( + fc._LazyBuilder(input_features)) + embedding_lookup_b = embedding_column_b._get_dense_tensor( + fc._LazyBuilder(input_features)) + + # Assert expected embedding variable and lookups. + global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertItemsEqual(('embedding_weights:0',), + tuple([v.name for v in global_vars])) + embedding_var = global_vars[0] + with _initialized_session(): + self.assertAllEqual(embedding_values, embedding_var.eval()) + self.assertAllEqual(expected_lookups_a, embedding_lookup_a.eval()) + self.assertAllEqual(expected_lookups_b, embedding_lookup_b.eval()) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/tpu/python/tpu/functional.py b/tensorflow/contrib/tpu/python/tpu/functional.py new file mode 100644 index 0000000000000000000000000000000000000000..24c85156e53a9b770f811c4cf3b903eab6553c76 --- /dev/null +++ b/tensorflow/contrib/tpu/python/tpu/functional.py @@ -0,0 +1,39 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""Functional operations.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import platform + +from tensorflow.contrib.tpu.python.tpu import gen_functional_ops + + +TPUPartitionedCall = gen_functional_ops._tpu_partitioned_call # pylint: disable=invalid-name,protected-access + + +if platform.system() != "Windows": + # pylint: disable=wildcard-import,unused-import,g-import-not-at-top + from tensorflow.contrib.tpu.ops.gen_tpu_ordinal_selector_op import * + + from tensorflow.contrib.util import loader + from tensorflow.python.platform import resource_loader + # pylint: enable=wildcard-import,unused-import,g-import-not-at-top + + _tpu_partitioned_call_op = loader.load_op_library( + resource_loader.get_path_to_datafile("../ops/_functional_ops.so") + ) diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py index cf3b2e68e940652220983c98e3a0acb68cf88d89..37fe9af8c4b154a2e20a957f6ca5d97df3d413be 100644 --- a/tensorflow/contrib/tpu/python/tpu/keras_support.py +++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py @@ -133,7 +133,7 @@ def _tpu_session_context(): An error occurred connecting or initializing your TPU. The session has been reset. re-run keras_to_tpu_model to create a new session. 
-""" + e) +""" + str(e)) def setup_tpu_session(cluster_resolver): @@ -729,7 +729,7 @@ class TPUDatasetInfeedManager(TPUInfeedManager): dummy_x_shape[0] *= tpu_assignment.num_towers dummy_y_shape = dataset.output_shapes[1].as_list() dummy_y_shape[0] *= tpu_assignment.num_towers - self._iterator = dataset.make_initializable_iterator() + self._iterator = dataset_ops.make_initializable_iterator(dataset) K.get_session().run(self._iterator.initializer) self._get_next_ops = [] @@ -1373,6 +1373,10 @@ class KerasTPUModel(models.Model): # not hashable. self._numpy_to_infeed_manager_list = [] + # Add distribution specific arguments since we don't call the Model init. + self._distribution_strategy = None + self._compile_distribution = None + self.predict_function = None self.test_function = None self.train_function = None @@ -1676,14 +1680,10 @@ class KerasTPUModel(models.Model): callbacks, self, do_validation=do_validation, - val_inputs=val_inputs, - val_targets=val_targets, - val_sample_weights=val_sample_weights, batch_size=batch_size, epochs=epochs, steps_per_epoch=steps_per_epoch, samples=num_training_samples, - validation_steps=validation_steps, verbose=verbose, count_mode=count_mode) @@ -2073,6 +2073,8 @@ class KerasTPUModel(models.Model): # tpu_model may not be compiled, e.g., loading weights and then predict. 
return for k, v in six.iteritems(cpu_optimizer_config): + if k == 'name': + continue opt_var = getattr(self._tpu_model.optimizer, k) if isinstance(opt_var, variables.Variable): logging.info('CPU -> TPU %s: %s {%s}', k, v, K.get_value(opt_var)) @@ -2101,6 +2103,8 @@ class KerasTPUModel(models.Model): self._cpu_model.set_weights(tpu_weights) for k, v in six.iteritems(tpu_optimizer_config): logging.info('TPU -> CPU %s: %s', k, v) + if k == 'name': + continue opt_var = getattr(self.cpu_optimizer, k) if isinstance(opt_var, variables.Variable): K.get_session().run(opt_var.assign(v)) diff --git a/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py b/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py index 8b0b240dc7302c203a22349d583323327fc4480b..de425626c813784ef657d17eac0c7bb77599a155 100644 --- a/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py +++ b/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py @@ -69,6 +69,7 @@ class ReplicatedVariable(object): def __init__(self, name, variables): self._name = name self._primary_var = variables[0] + self._common_name = self._primary_var.name.split(":")[0] self._vars = variables self._cached_value = None self._dtype = variables[0].dtype diff --git a/tensorflow/contrib/tpu/python/tpu/session_support.py b/tensorflow/contrib/tpu/python/tpu/session_support.py index a95275487899c4770ef99b620a7671eec2bb81eb..5cb2ca6478a1d7589cd2aa2d52c82306b3fd11f4 100644 --- a/tensorflow/contrib/tpu/python/tpu/session_support.py +++ b/tensorflow/contrib/tpu/python/tpu/session_support.py @@ -43,12 +43,19 @@ class CoordinatorShutdownException(Exception): pass +def _clone_session(session, graph=None): + return session_lib.Session( + target=session.sess_str, + config=session._config, # pylint: disable=protected-access + graph=graph if graph else session.graph) + + def _make_heartbeat_op(session, device, request_ph): """Return a heartbeat op or None if heartbeats are not supported by device.""" try: # Test if we can connect in a 
isolated graph + session with ops.Graph().as_default(): - with session_lib.Session(target=session.sess_str) as temp_session: + with _clone_session(session) as temp_session: with ops.device(device): heartbeat_op = tpu_ops.worker_heartbeat('') options = config_pb2.RunOptions(timeout_in_ms=5000) @@ -165,7 +172,8 @@ class WorkerHeartbeatManager(object): """Shutdown all workers after `shutdown_timeout_secs`.""" logging.info('Shutting down %s.', self) req = event_pb2.WorkerHeartbeatRequest( - watchdog_config=event_pb2.WatchdogConfig(timeout_ms=timeout_ms)) + watchdog_config=event_pb2.WatchdogConfig(timeout_ms=timeout_ms), + shutdown_mode=event_pb2.WAIT_FOR_COORDINATOR) self.configure(req) # Wait for workers to shutdown. This isn't strictly required @@ -178,7 +186,8 @@ def all_worker_devices(session): """Return a list of devices for each worker in the system.""" devices = session.list_devices() return [ - device.name for device in devices + device.name + for device in devices if ':CPU:' in device.name and 'coordinator' not in device.name ] @@ -220,6 +229,7 @@ class WatchdogManager(threading.Thread): self.ping_interval = ping_interval self.shutdown_timeout = shutdown_timeout self.daemon = True + self._config = session._config # pylint: disable=protected-access self._target = session.sess_str self._running = False self._devices = devices @@ -234,6 +244,7 @@ class WatchdogManager(threading.Thread): self._session = session_lib.Session( target=self._target, graph=self._graph, + config=self._config, ) if self._devices is None: @@ -246,12 +257,14 @@ class WatchdogManager(threading.Thread): self._worker_manager.configure( event_pb2.WorkerHeartbeatRequest( watchdog_config=event_pb2.WatchdogConfig( - timeout_ms=self.shutdown_timeout * 1000,))) + timeout_ms=self.shutdown_timeout * 1000,), + shutdown_mode=event_pb2.WAIT_FOR_COORDINATOR)) def configure_and_run(self): - logging.info('Enabling watchdog timer with %d second timeout ' - 'and %d second ping interval.', - 
self.shutdown_timeout, self.ping_interval) + logging.info( + 'Enabling watchdog timer with %d second timeout ' + 'and %d second ping interval.', self.shutdown_timeout, + self.ping_interval) self._reset_manager() self._running = True self.start() @@ -260,7 +273,8 @@ class WatchdogManager(threading.Thread): logging.info('Stopping worker watchdog.') self._worker_manager.configure( event_pb2.WorkerHeartbeatRequest( - watchdog_config=event_pb2.WatchdogConfig(timeout_ms=-1,))) + watchdog_config=event_pb2.WatchdogConfig(timeout_ms=-1,), + shutdown_mode=event_pb2.NOT_CONFIGURED)) self._running = False self.join() @@ -334,8 +348,7 @@ class GracefulShutdownHook(session_run_hook.SessionRunHook): with self._graph.as_default(): logging.info('Installing graceful shutdown hook.') - self._session = session_lib.Session( - target=training_session.sess_str, graph=self._graph) + self._session = _clone_session(training_session, self._graph) self._workers = WorkerHeartbeatManager.from_devices( self._session, all_worker_devices(self._session)) self._heartbeat_supported = self._workers.num_workers() > 0 diff --git a/tensorflow/contrib/tpu/python/tpu/tensor_tracer.py b/tensorflow/contrib/tpu/python/tpu/tensor_tracer.py index 70baea203cc6174bebc7d90646045efae5f2391d..2c5ea65182e404ec44b24bcd7d0f412c04f1beb1 100644 --- a/tensorflow/contrib/tpu/python/tpu/tensor_tracer.py +++ b/tensorflow/contrib/tpu/python/tpu/tensor_tracer.py @@ -21,58 +21,211 @@ from __future__ import print_function import os import os.path import re +import sys from tensorflow.contrib.tpu.python.ops import tpu_ops from tensorflow.contrib.tpu.python.tpu import tpu from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_util from 
tensorflow.python.ops import gen_math_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import logging_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging _TRACER_LOG_PREFIX = ' [>>>TT>>>]' _DEVICE_TYPE_TPU = 'tpu' _DEVICE_TYPE_CPU = 'cpu' -_GLOBAL_STEP_OP_NAME = 'GLOBAL-STEP' _TRACE_MODE_NAN_INF = 'nan-inf' _TRACE_MODE_PART_TENSOR = 'part-tensor' _TRACE_MODE_PART_TENSOR_SIZE = 3 _TRACE_MODE_FULL_TENSOR = 'full-tensor' -_RECORD_OUTSIDE_OP_RANGE = 'not-traced-outside-op-range' -_RECORD_SHOULD_NOT_TRACE = 'not-traced-should-not-trace' -_RECORD_FILTERED_OUT = 'not-traced-filtered-out' -_RECORD_SCALAR = 'not-traced-scalar' -_RECORD_DYNAMIC_SHAPE = 'not-traced-dynamic-shape' -_RECORD_GET_TRACED = 'get-traced' +_TRACE_MODE_NORM = 'norm' +_TRACE_MODE_MAX_ABS = 'max-abs' +_SUBMODE_BRIEF = 'brief' +_SUBMODE_DETAILED = 'detailed' +_REASON_OUTSIDE_OP_RANGE = 'not-traced-outside-op-range' +_REASON_UNSAFE_OP = 'not-traced-unsafe-op' +_REASON_UNSAFE_SCALAR = 'not-traced-unsafe-scalar' +_REASON_LESS_INTERESTING_OP = 'not-traced-less-interesting-op' +_REASON_DEVICE_MISMATCH = 'not-traced-device-mismatch' +_REASON_DYNAMIC_SHAPE = 'not-traced-dynamic-shape' +_REASON_SCALAR_GET_TRACED = 'traced-scalar' +_REASON_TENSOR_GET_TRACED = 'traced-tensor' +_REASON_USER_INCLUDED = 'traced-user-included' +_REASON_USER_EXCLUDED = 'not-traced-user-excluded' +_REASON_NOT_EXECUTED = 'not-traced-not-in-exec-path' +_REASON_NON_NUMERIC_TENSOR = 'not-traced-non-numeric-tensor' _MARKER_SECTION_BEGIN = '!!!!!!! section-begin:' _MARKER_SECTION_END = '!!!!!!! 
section-end:' _SECTION_NAME_CONFIG = 'configuration' _SECTION_NAME_REASON = 'reason' _SECTION_NAME_OP_LIST = 'op-list' +_SECTION_NAME_TENSOR_LIST = 'tensor-list' +_SECTION_NAME_CACHE_INDEX_MAP = 'cache-index-map' _SECTION_NAME_GRAPH = 'graph' _FIELD_NAME_VERSION = 'version:' _FIELD_NAME_DEVICE = 'device:' _FIELD_NAME_TRACE_MODE = 'trace-mode:' +_FIELD_NAME_SUBMODE = 'submode:' _FIELD_NAME_NUM_REPLICAS = 'num-replicas:' +_FIELD_NAME_NUM_REPLICAS_PER_HOST = 'num-replicas-per-host:' +_FIELD_NAME_NUM_HOSTS = 'num-hosts:' _FIELD_NAME_NUM_OPS = 'number-of-ops:' +_FIELD_NAME_NUM_TENSORS = 'number-of-tensors:' +_FIELD_NAME_NUM_CACHE_INDICES = 'number-of-indices:' _FIELD_NAME_TOPOLOGICAL_SORT_SUCCEED = 'topological-sort-succeed:' _FLAGS_ENV_VAR = 'TENSOR_TRACER_FLAGS' _FLAG_SINGLE_QUOTE_PAT = re.compile(r"\s*--([^=]+)='([^']*)'") _FLAG_DOUBLE_QUOTE_PAT = re.compile(r'\s*--([^=]+)="([^"]*)"') _FLAG_NO_QUOTE_PAT = re.compile(r'\s*--([^=]+)=(\S*)') +_FLAG_NO_EQUAL_PAT = re.compile(r'\s*--([^=]+)\s*') _FLAG_NAME_ENABLE = 'enable' _FLAG_NAME_TRACE_MODE = 'trace_mode' -_FLAG_NAME_INTERESTING_OPS = 'interesting_ops' -_FLAG_NAME_TRACE_FILE = 'trace_file_path' +_FLAG_NAME_USE_COMPACT_TRACE = 'compact_trace' +_FLAG_NAME_SUBMODE = 'submode' +_FLAG_NAME_INCLUDE_LESS_INTERESTING_OPS = 'include_less_interesting_ops' +_FLAG_NAME_EXCLUDED_OPNAMES = 'excluded_opnames' +_FLAG_NAME_EXCLUDED_OPTYPES = 'excluded_optypes' +_FLAG_NAME_INCLUDED_OPNAMES = 'included_opnames' +_FLAG_NAME_INCLUDED_OPTYPES = 'included_optypes' +_FLAG_NAME_TRACE_DIR = 'trace_dir' +_FLAG_NAME_REPORT_FILE = 'report_file' _FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR = 'use_test_undeclared_outputs_dir' _FLAG_NAME_OP_RANGE = 'op_range' _OP_RANGE_PAT = re.compile(r'(\d+):(\d+)') _OUTPUT_STREAM_ESCAPE = 'file://' _TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR = 'TEST_UNDECLARED_OUTPUTS_DIR' +_TENSOR_TRACER_COLLECTION = 'tensor_tracer_variables' +_TENSOR_TRACER_CHECKPOINT = 'tensor_tracer_checkpoint' +_TRACE_FILE_NAME = 'trace.all' 
+_COMPACT_TRACE_FILE_PREFIX = 'compact_trace.' +_COMPACT_TRACE_ENTRY_INIT_VALUE = -1.0 +_TENSOR_TRACER_STORAGE = 'tensor_tracer_storage' +_TENSOR_VALUES_CACHE = 'tensor_values_cache' +_REPLICA_ID_TAG = '#replica-id: ' + +def tensor_tracepoint(tensor, checkpoint_name): + """Adds a checkpoint with the given checkpoint name for the given tensor. + + The tensor will be added to the list of tensors that will be traced by the + tensor tracer. + + Args: + tensor: the tensor object for which the tracing is requested. + checkpoint_name: a string name for the checkpoint. This name has to be a + unique name if used within model comparison. The tensors that have the same + checkpoint identifier is compared in model comparison. + Returns: + The provided tensor. + """ + + tensor.graph.get_collection(_TENSOR_TRACER_COLLECTION) + tensor.graph.add_to_collection(_TENSOR_TRACER_COLLECTION, + (tensor, checkpoint_name)) + return tensor + + +def keras_layer_tracepoint(layer, checkpoint_name): + """An interface for adding the tensor outputs of a keras layer. + + Encapsulates tensor_tracepoint. + + Args: + layer: A keras layer. + checkpoint_name: a string name for the checkpoint. This name has to be a + unique name if used within model comparison. The tensors that have the same + checkpoint identifier is compared in model comparison. + + Returns: + The provided layer. 
+ """ + try: + outputs = layer.output + if tensor_util.is_tensor(outputs): + tensor_tracepoint(outputs, '%s' % (checkpoint_name)) + else: + idx = 0 + for output_tensor in outputs: + if tensor_util.is_tensor(outputs): + tensor_tracepoint(output_tensor, '%s_%d' % (checkpoint_name, idx)) + idx += 1 + except AttributeError: + pass + except RuntimeError: + pass + return layer + + +def _trace_files_need_precreated(output_dir): + """Return True if trace files must be pre-created by users.""" + + if not output_dir.startswith('/'): + return False + if len(output_dir) < 5: + return False + if output_dir[2] != 'n': + return False + if output_dir[3] != 's': + return False + if output_dir[1] != 'c': + return False + if output_dir[4] != '/': + return False + return True + + +def _get_tensor_values_cache(graph=None): + """Returns the variable that implements tensor-value caching.""" + + graph = graph or ops.get_default_graph() + collection = graph.get_collection(_TENSOR_TRACER_STORAGE) + if len(collection) == 1: + return collection[0] + elif not collection: + raise RuntimeError('%s has not been created'%_TENSOR_VALUES_CACHE) + else: + raise RuntimeError('Multiple %s created'%_TENSOR_VALUES_CACHE) + return None + + +def _create_tensor_values_cache(graph, num_tensors): + """Creates a variable as the cache to store intermediate tensor values.""" + + graph = graph or ops.get_default_graph() + # Create in proper graph and base name_scope. 
+ with graph.as_default() as g, g.name_scope(None): + return variable_scope.get_variable( + _TENSOR_VALUES_CACHE, + shape=[num_tensors], + dtype=dtypes.float32, + initializer=init_ops.constant_initializer( + _COMPACT_TRACE_ENTRY_INIT_VALUE), + trainable=False, + use_resource=True, + collections=[_TENSOR_TRACER_STORAGE, ops.GraphKeys.GLOBAL_VARIABLES]) + + +def _set_fetches(result_tensor, train_op): + """Sets the fetches from the result tensor and training op.""" + + fetches = [] + if result_tensor is not None: + fetches.append(result_tensor) + if train_op is not None: + fetches.append(train_op) + if not fetches: + return None + return fetches class TensorTracer(object): @@ -94,16 +247,64 @@ class TensorTracer(object): @staticmethod def _match_next_flag(flags, pos): - """Returns the match for the next TensorTracer flag.""" + """Returns the match for the next TensorTracer flag. + + Args: + flags: a string that contains the flags. + pos: where in flags to start the search. + + Returns: + A pair where the first element is the regular-expression + match found and the second element indicates if the match + has a value. + """ match = _FLAG_DOUBLE_QUOTE_PAT.match(flags, pos) if match: - return match + return match, True match = _FLAG_SINGLE_QUOTE_PAT.match(flags, pos) if match: - return match + return match, True match = _FLAG_NO_QUOTE_PAT.match(flags, pos) - return match + if match: + return match, True + match = _FLAG_NO_EQUAL_PAT.match(flags, pos) + if match: + # The flag is found but is not given a value. + return match, False + # The flag is not found. 
+ return None, False + + @staticmethod + def validate_flag_names(): + """Validates if the TensorTrace flags passed are valid.""" + valid_flag_names = [_FLAG_NAME_ENABLE, _FLAG_NAME_TRACE_MODE, + _FLAG_NAME_USE_COMPACT_TRACE, + _FLAG_NAME_SUBMODE, + _FLAG_NAME_EXCLUDED_OPNAMES, + _FLAG_NAME_EXCLUDED_OPTYPES, + _FLAG_NAME_INCLUDED_OPNAMES, + _FLAG_NAME_INCLUDED_OPTYPES, + _FLAG_NAME_TRACE_DIR, + _FLAG_NAME_REPORT_FILE, + _FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR, + _FLAG_NAME_INCLUDE_LESS_INTERESTING_OPS, + _FLAG_NAME_OP_RANGE] + tensor_tracer_flags = os.environ.get(_FLAGS_ENV_VAR) + if not tensor_tracer_flags: + return + pos = 0 + while True: + match, _ = TensorTracer._match_next_flag(tensor_tracer_flags, pos) + if not match: + break + flag_name = match.group(1) + if flag_name not in valid_flag_names: + raise ValueError( + 'The flag name "%s" passed via the environment variable "%s" ' + 'is invalid. Valid flag names are:' + '\n%s'%(flag_name, _FLAGS_ENV_VAR, valid_flag_names)) + pos = match.end() @staticmethod def print_flag_values(): @@ -117,11 +318,15 @@ class TensorTracer(object): result += 'Individual flag value:\n' pos = 0 while True: - match = TensorTracer._match_next_flag(tensor_tracer_flags, pos) + match, has_value = TensorTracer._match_next_flag( + tensor_tracer_flags, pos) if not match: break flag_name = match.group(1) - flag_value = match.group(2) + if has_value: + flag_value = match.group(2) + else: + flag_value = None result += ' %s: %s\n'%(flag_name, flag_value) pos = match.end() result += '\n' @@ -129,50 +334,92 @@ class TensorTracer(object): @staticmethod def get_flag_value(wanted_flag_name): - """Returns the value of a TensorTracer flags.""" + """Returns the value of a TensorTracer flags. + + Args: + wanted_flag_name: the name the the flag we are looking for. + + Returns: + A pair where the first element indicates if the flag is + found and the second element is the value of the flag. + + Raises: + RuntimeError: If supposedly deadcode is reached. 
+ """ tensor_tracer_flags = os.getenv(_FLAGS_ENV_VAR) if not tensor_tracer_flags: - return '' + return False, None pos = 0 while True: - match = TensorTracer._match_next_flag(tensor_tracer_flags, pos) + match, has_value = TensorTracer._match_next_flag( + tensor_tracer_flags, pos) if not match: - return '' + return False, None flag_name = match.group(1) - flag_value = match.group(2) + if has_value: + flag_value = match.group(2) + else: + flag_value = None if flag_name == wanted_flag_name: - return flag_value + return True, flag_value pos = match.end() - return '' + raise RuntimeError('Should not reach here.') @staticmethod - def is_enabled(): - """Returns True if TensorTracer is enabled.""" + def flag_value_to_re_list(flag_name): + """Converts list of strings to compiled RE.""" + + re_list = [] + found, flag_value = TensorTracer.get_flag_value(flag_name) + if not found or not flag_value: + return re_list + list_of_values = flag_value.split() + for v in list_of_values: + r = re.compile(v) + re_list.append(r) + return re_list + + @staticmethod + def _is_flag_on(flag_name): + """Returns True if the given flag is on.""" - flag_value = TensorTracer.get_flag_value(_FLAG_NAME_ENABLE) + found, flag_value = TensorTracer.get_flag_value(flag_name) + if not found: + return False + if flag_value is None: + return True + # Depends on the flag value. flag_value = flag_value.lower() enabled = flag_value in ['1', 't', 'true', 'y', 'yes'] return enabled + @staticmethod + def is_enabled(): + """Returns True if TensorTracer is enabled.""" + + return TensorTracer._is_flag_on(_FLAG_NAME_ENABLE) + @staticmethod def use_test_undeclared_outputs_dir(): - """Decides the output directory of the trace file. + """Decides the output directory of the report and trace files. Args: None. Returns: - True if the output trace file should be written to the + True if the output files should be written to the test-undeclared-outputs-directory defined via an env variable. 
""" - flag_value = TensorTracer.get_flag_value( + return TensorTracer._is_flag_on( _FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR) - flag_value = flag_value.lower() - enabled = flag_value in ['1', 't', 'true', 'y', 'yes'] - return enabled + + @staticmethod + def use_compact_trace(): + return TensorTracer._is_flag_on( + _FLAG_NAME_USE_COMPACT_TRACE) @staticmethod def check_device_type(device_type): @@ -186,29 +433,80 @@ class TensorTracer(object): """Checks if the given trace mode is valid.""" valid_trace_modes = [_TRACE_MODE_NAN_INF, _TRACE_MODE_PART_TENSOR, - _TRACE_MODE_FULL_TENSOR] + _TRACE_MODE_FULL_TENSOR, _TRACE_MODE_NORM, + _TRACE_MODE_MAX_ABS] if trace_mode not in valid_trace_modes: raise ValueError('Invalid trace mode "%s" given to the Tensor_Tracer.' 'Valid trace modes are: %s'%(trace_mode, valid_trace_modes)) @staticmethod - def should_trace(device_type, op): - """Returns True if the given Op should be traced.""" + def check_submode(submode): + """Checks if the given submode is valid.""" + + if not submode: + return + valid_submodes = [_SUBMODE_DETAILED, _SUBMODE_BRIEF] + if submode not in valid_submodes: + raise ValueError('Invalid submode "%s" given to the Tensor_Tracer.' + 'Valid submodes are: %s'%(submode, + valid_submodes)) + + @staticmethod + def unsafe_op(op): + """Returns True if this op is not safe to be traced.""" - if device_type != _DEVICE_TYPE_TPU: - raise ValueError('Non TPU device type is not supported') if control_flow_util.IsInCond(op): + return True + # Reasons for not including following op types: + # Assign: cause incorrect result with CPU tracing. 
+ if op.type in ['Assign']: + return True + return False + + @staticmethod + def device_mismatch(device_type, op): + if device_type == _DEVICE_TYPE_TPU: + # pylint: disable=protected-access + return tpu._TPU_REPLICATE_ATTR not in op.node_def.attr + # pylint: enable=protected-access + return False + + @staticmethod + def unsafe_scalar_trace(op): + """Return true if scalar output tensor from Op is not safe to be traced.""" + + # Tracing the following causes cycle in the graph on TPU. + if op.type in ['LoopCond', 'Enter', 'Merge', 'Const', + 'Switch', 'Less', 'ReadVariableOp']: + return True + # Tracing the following will cause casting-issue + # with the norm tracing mode or other compilation issues on CPU. + if op.type in ['VarHandleOp', 'IteratorToStringHandle', + 'IteratorGetNext', 'OneShotIterator', + 'IteratorV2', 'MakeIterator', + 'BatchDatasetV2', 'MapDataset', + 'FixedLengthRecordDataset', 'TakeDataset', 'ZipDataset', + 'Placeholder', 'PlaceholderWithDefault', 'StridedSlice']: + return True + return False + + @staticmethod + def less_interesting_op(op): + """Returns True if the given Op is not an interesting one to be traced.""" + + found, _ = TensorTracer.get_flag_value( + _FLAG_NAME_INCLUDE_LESS_INTERESTING_OPS) + if found: + # users force to include all ops. return False - if op.type in ['Reshape', 'ArgMin', 'ArgMax']: - return False - # pylint: disable=protected-access - return tpu._TPU_REPLICATE_ATTR in op.node_def.attr - # pylint: enable=protected-access + # Following ops are highly unlikey to cause bugs. 
+ return op.type in ['Const', 'Identity', 'Cast', 'Shape'] @staticmethod def reason(op_idx, details): - """Returns why the Op at op_idx is traced or not.""" + """Returns reason why the Op at op_idx is traced or not.""" + return '%d %s'%(op_idx, details) @staticmethod @@ -253,7 +551,7 @@ class TensorTracer(object): temporarily_marked_ops, sorted_ops) # pylint: disable=protected-access for ctrl_output_op in op._control_outputs: - # pylint: enable=protected-access + # pylint: enable=protected-access visit(ctrl_output_op, cycle, permanently_marked_ops, temporarily_marked_ops, sorted_ops) temporarily_marked_ops.remove(op) @@ -274,6 +572,33 @@ class TensorTracer(object): assert len(unsorted_ops) == len(sorted_ops) return (True, sorted_ops) + @staticmethod + def _make_op_and_tensor_maps(op_list): + """Creates various maps and lists from op_list. + + Args: + op_list: a list of Ops + + Returns: + opname_idx_map: a map from Op's name to its index in op_list. + tensor_list: a list of output tensors of the Ops in op_list. + tensorname_idx_map: a map from output tensor name to its index + in tensor_list. + """ + + opname_idx_map = {} + tensor_list = [] + tensorname_idx_map = {} + for op_id, op in enumerate(op_list): + if op.name in opname_idx_map: + raise ValueError('Duplicated Op name: %s'%op.name) + opname_idx_map[op.name] = op_id + for output_tensor in op.outputs: + if output_tensor.name not in tensorname_idx_map: + tensor_list.append(output_tensor) + tensorname_idx_map[output_tensor.name] = len(tensor_list)-1 + return (opname_idx_map, tensor_list, tensorname_idx_map) + def __init__(self): """Initializes a TensorTracer. 
@@ -281,28 +606,36 @@ class TensorTracer(object): """ self._version = 'use-outside-compilation' self._device_type = None - self._trace_mode = TensorTracer.get_flag_value(_FLAG_NAME_TRACE_MODE) - if not self._trace_mode: + TensorTracer.validate_flag_names() + found, self._trace_mode = TensorTracer.get_flag_value(_FLAG_NAME_TRACE_MODE) + if not found or not self._trace_mode: self._trace_mode = _TRACE_MODE_NAN_INF TensorTracer.check_trace_mode(self._trace_mode) + found, self._submode = TensorTracer.get_flag_value(_FLAG_NAME_SUBMODE) + if not found or not self._submode: + self._submode = _SUBMODE_DETAILED + TensorTracer.check_submode(self._submode) self._part_tensor_size = _TRACE_MODE_PART_TENSOR_SIZE self._instrument_records = {} - interesting_ops = TensorTracer.get_flag_value(_FLAG_NAME_INTERESTING_OPS) - self._selected_ops = interesting_ops.split() - self._set_trace_file_path() + self._set_trace_dir() + self._set_report_file() self._set_op_range() + self._set_excluded_opnames() + self._set_excluded_optypes() + self._set_included_opnames() + self._set_included_optypes() self._num_replicas = None + self._num_replicas_per_host = None + self._num_hosts = None self._replica_id = None - def _add_replica_id_to_graph(self, num_replicas, result_tensor): + def _add_replica_id_to_graph(self, result_tensor): """Adds nodes for computing the replica ID to the graph.""" - if not num_replicas: + if not self._num_replicas: self._replica_id = 'unknown' return result_tensor - self._num_replicas = num_replicas - with ops.control_dependencies(None): # Uses None as dependency to run outside of TPU graph rewrites. self._replica_id = tpu_ops.tpu_replicated_input( @@ -314,27 +647,47 @@ class TensorTracer(object): # the replica_id to ensure that replica_id will be added to the graph. 
return array_ops.identity(result_tensor) - def _set_trace_file_path(self): - """Sets the path of the output trace file.""" - - self._trace_file_path = TensorTracer.get_flag_value(_FLAG_NAME_TRACE_FILE) - if not self._trace_file_path: - raise ValueError('--%s is not set in the environment variable %s' - %(_FLAG_NAME_TRACE_FILE, _FLAGS_ENV_VAR)) - elif TensorTracer.use_test_undeclared_outputs_dir(): - if os.path.isabs(self._trace_file_path): + def _set_trace_dir(self): + found, self._trace_dir = TensorTracer.get_flag_value(_FLAG_NAME_TRACE_DIR) + if found and self._trace_dir \ + and TensorTracer.use_test_undeclared_outputs_dir(): + raise ValueError('Cannot not use --%s and --%s at the same time' + %(_FLAG_NAME_TRACE_DIR, + _FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR)) + if TensorTracer.use_test_undeclared_outputs_dir(): + self._trace_dir = os.environ.get(_TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR) + + def _set_report_file(self): + """Sets the path of the output report file.""" + + found, self._report_file_path = TensorTracer.get_flag_value( + _FLAG_NAME_REPORT_FILE) + if found and self._report_file_path \ + and TensorTracer.use_test_undeclared_outputs_dir(): + if os.path.isabs(self._report_file_path): raise ValueError('If use_test_undeclared_outputs_dir is set,' - 'trace_file_path cannot be an absolute path (%s)' - %self._trace_file_path) + 'report_file_path cannot be an absolute path (%s)' + %self._report_file_path) outputs_dir = os.environ.get(_TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR) - self._trace_file_path = os.path.join(outputs_dir, - self._trace_file_path) + self._report_file_path = os.path.join(outputs_dir, + self._report_file_path) + if not self._report_file_path: + self._report_file = None + return + try: + self._report_file = gfile.Open(self._report_file_path, 'w') + except IOError as e: + raise e + + def _close_report_file(self): + if self._report_file: + self._report_file.close() def _set_op_range(self): """Sets the index range of the Ops that we will consider 
tracing.""" - op_range = TensorTracer.get_flag_value(_FLAG_NAME_OP_RANGE) - if not op_range: + found, op_range = TensorTracer.get_flag_value(_FLAG_NAME_OP_RANGE) + if not found or not op_range: self._op_range = (-1, -1) # this means including all ops. return match = _OP_RANGE_PAT.match(op_range) @@ -350,20 +703,68 @@ class TensorTracer(object): return False return self._op_range[1] < 0 or idx <= self._op_range[1] - def _write_report(self, content): - """Writes the given content to the report.""" + def _set_excluded_opnames(self): + self._excluded_opname_re_list = TensorTracer.flag_value_to_re_list( + _FLAG_NAME_EXCLUDED_OPNAMES) + + def _set_excluded_optypes(self): + self._excluded_optype_re_list = TensorTracer.flag_value_to_re_list( + _FLAG_NAME_EXCLUDED_OPTYPES) + + def _set_included_opnames(self): + self._included_opname_re_list = TensorTracer.flag_value_to_re_list( + _FLAG_NAME_INCLUDED_OPNAMES) + + def _set_included_optypes(self): + self._included_optype_re_list = TensorTracer.flag_value_to_re_list( + _FLAG_NAME_INCLUDED_OPTYPES) + + def _is_user_included_op(self, op): + for opname_re in self._included_opname_re_list: + if opname_re.match(op.name): + return True + for optype_re in self._included_optype_re_list: + if optype_re.match(op.type): + return True + return False - logging.info('%s %s'%(_TRACER_LOG_PREFIX, content)) + def _is_user_excluded_op(self, op): + for opname_re in self._excluded_opname_re_list: + if opname_re.match(op.name): + return True + for optype_re in self._excluded_optype_re_list: + if optype_re.match(op.type): + return True + return False - def _is_selected_op(self, op_name): - """Returns True if the Op with op_name is selected to be traced.""" + def _use_tensor_values_cache(self): + """Returns True if immediate tensors should be first saved to a cache.""" - if not self._selected_ops: + if self._trace_mode not in set([_TRACE_MODE_NAN_INF, + _TRACE_MODE_NORM, _TRACE_MODE_MAX_ABS]): + return False + if self._trace_dir and 
_trace_files_need_precreated(self._trace_dir): return True - if op_name in self._selected_ops: + if TensorTracer.use_compact_trace(): return True return False + def _save_tensor_value_to_cache_op(self, graph, cache_idx, updates): + """Returns an Op that will save the given updates to an entry in the cache.""" + + cache = _get_tensor_values_cache(graph) + indices = constant_op.constant([cache_idx]) + return state_ops.scatter_update(cache, indices, updates).op + + def _write_report(self, content): + """Writes the given content to the report.""" + + line = '%s %s'%(_TRACER_LOG_PREFIX, content) + if self._report_file: + self._report_file.write(line) + else: + logging.info(line) + def _write_config_section(self): """Writes the config section of the report.""" @@ -371,7 +772,11 @@ class TensorTracer(object): self._write_report('%s %s\n'%(_FIELD_NAME_VERSION, self._version)) self._write_report('%s %s\n'%(_FIELD_NAME_DEVICE, self._device_type)) self._write_report('%s %s\n'%(_FIELD_NAME_TRACE_MODE, self._trace_mode)) + self._write_report('%s %s\n'%(_FIELD_NAME_SUBMODE, self._submode)) self._write_report('%s %s\n'%(_FIELD_NAME_NUM_REPLICAS, self._num_replicas)) + self._write_report('%s %s\n'%(_FIELD_NAME_NUM_REPLICAS_PER_HOST, + self._num_replicas_per_host)) + self._write_report('%s %s\n'%(_FIELD_NAME_NUM_HOSTS, self._num_hosts)) self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_CONFIG)) def _write_reason_section(self): @@ -388,9 +793,50 @@ class TensorTracer(object): self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_OP_LIST)) self._write_report('%s %d\n'%(_FIELD_NAME_NUM_OPS, len(op_list))) for i in range(0, len(op_list)): - self._write_report('%d "%s" %s\n'%(i, op_list[i].name, op_list[i].type)) + op = op_list[i] + line = '%d "%s" %s'%(i, op.name, op.type) + for out_tensor in op.outputs: + if out_tensor.name not in self._tensorname_idx_map: + raise ValueError( + 'out_tensor %s is not in tensorname_idx_map'%out_tensor.name) + line += ' 
%d'%self._tensorname_idx_map[out_tensor.name] + line += '\n' + self._write_report(line) self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_OP_LIST)) + def _write_tensor_list_section(self, tensor_list, opname_idx_map): + """Writes the tensor-list section of the report.""" + + self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, + _SECTION_NAME_TENSOR_LIST)) + self._write_report('%s %d\n'%(_FIELD_NAME_NUM_TENSORS, len(tensor_list))) + for i in range(0, len(tensor_list)): + tensor = tensor_list[i] + line = '%d "%s"'%(i, tensor.name) + for consumer_op in tensor.consumers(): + if consumer_op.name not in opname_idx_map: + raise ValueError( + 'consumer_op %s is not in opname_idx_map'%consumer_op.name) + line += ' %d'%opname_idx_map[consumer_op.name] + line += '\n' + self._write_report(line) + self._write_report('%s %s\n'%(_MARKER_SECTION_END, + _SECTION_NAME_TENSOR_LIST)) + + def _write_cache_index_map_section(self): + """Writes the mapping from cache index to tensor index to the report.""" + + self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, + _SECTION_NAME_CACHE_INDEX_MAP)) + self._write_report('%s %d\n'%(_FIELD_NAME_NUM_CACHE_INDICES, + len(self._cache_idx_to_tensor_idx))) + for cache_idx in range(0, len(self._cache_idx_to_tensor_idx)): + tensor_idx = self._cache_idx_to_tensor_idx[cache_idx] + line = '%d %d\n'%(cache_idx, tensor_idx) + self._write_report(line) + self._write_report('%s %s\n'%(_MARKER_SECTION_END, + _SECTION_NAME_CACHE_INDEX_MAP)) + def _write_graph_section(self, succeed, sorted_or_cycle): """Writes the graph section of the report.""" @@ -402,12 +848,67 @@ class TensorTracer(object): self._write_report('%d "%s"\n'%(i, l[i].name)) self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_GRAPH)) - def _make_tensor_trace_fun(self, op_name, output_idx): + def _preprocess_traced_tensor(self, tensor): + """Computes NAN/Norm/Max on TPUs before sending to CPU. + + Args: + tensor: The tensor to be traced. 
+ Returns: + A tensor that should be input to the trace_function. + Raises: + RuntimeError: If the trace mode is invalid. + """ + + def _detect_nan_inf(tensor): + """Trace function for detecting any NaN/Inf in the tensor.""" + + if tensor.dtype.is_floating: + mask = math_ops.reduce_any( + gen_math_ops.logical_or( + gen_math_ops.is_nan(tensor), gen_math_ops.is_inf(tensor))) + output_tensor = control_flow_ops.cond(mask, + lambda: constant_op.constant(1.0), + lambda: constant_op.constant(0.0)) + else: + output_tensor = constant_op.constant(0.0) + # The shape has to be 1. Set it if it does not have the information. + output_tensor = array_ops.reshape(output_tensor, [1]) + return output_tensor + + def _show_norm(tensor): + tensor = math_ops.cast(tensor, dtypes.float32) + output_tensor = linalg_ops.norm(tensor) + # The shape has to be 1. Set it if it does not have the information. + output_tensor = array_ops.reshape(output_tensor, [1]) + return output_tensor + + def _show_max_abs(tensor): + tensor = math_ops.cast(tensor, dtypes.float32) + output_tensor = math_ops.reduce_max(math_ops.abs(tensor)) + zero = constant_op.constant(0, dtypes.float32) + output_tensor = gen_math_ops.maximum(zero, output_tensor) + # The shape has to be 1. Set it if it does not have the information. + output_tensor = array_ops.reshape(output_tensor, [1]) + return output_tensor + + if self._trace_mode == _TRACE_MODE_NAN_INF: + return _detect_nan_inf(tensor) + if self._trace_mode == _TRACE_MODE_PART_TENSOR: + return tensor + if self._trace_mode == _TRACE_MODE_FULL_TENSOR: + return tensor + if self._trace_mode == _TRACE_MODE_NORM: + return _show_norm(tensor) + if self._trace_mode == _TRACE_MODE_MAX_ABS: + return _show_max_abs(tensor) + raise RuntimeError( + 'Tensor trace fun for %s is not yet implemented' % self._trace_mode) + + def _make_tensor_trace_fun(self, tensor_name): """Makes the tensor tracing function called by outside compilation. 
Args: - op_name: the name of the Op that outputs the tensor to be traced. - output_idx: which output of the Op it is (0 means the first output). + tensor_name: name of the tensor being traced. Returns: A function to be passed as the first argument to outside compilation. @@ -416,77 +917,414 @@ class TensorTracer(object): RuntimeError: If the trace mode is invalid. """ - def _print_tensor(op_name, output_idx, num_elements, tensor, output_tensor): + def _print_tensor(tensor_name, num_elements, tensor, output_tensor): """Prints a tensor value to a file. Args: - op_name: the name of the Op that outputs the tensor to be printed. - output_idx: which output of the Op it is (0 means the first output). - num_elements: number of elements to print. + tensor_name: name of the tensor being traced. + num_elements: number of elements to print (-1 means print all). tensor: the tensor needs to be returned. output_tensor: the tensor needs to be printed. Returns: The same tensor passed via the "tensor" argument. + + Raises: + ValueError: If tensor_name is not already in + self._tensorname_idx_map. 
""" - msg = '"%s:%d" '%(op_name, output_idx) - output_stream = _OUTPUT_STREAM_ESCAPE + self._trace_file_path + + if self._submode == _SUBMODE_BRIEF: + if tensor_name not in self._tensorname_idx_map: + raise ValueError( + 'Tensor name %s is not in the tensorname_idx_map'%tensor_name) + msg = '%d'%self._tensorname_idx_map[tensor_name] + else: + msg = '"%s"'%tensor_name + + if self._trace_dir: + output_path = os.path.join(self._trace_dir, _TRACE_FILE_NAME) + output_stream = _OUTPUT_STREAM_ESCAPE + output_path + else: + output_stream = sys.stderr print_op = logging_ops.print_v2(msg, array_ops.shape(output_tensor), - ' @', self._replica_id, - '\n', output_tensor, + '@', self._replica_id, + '\n', output_tensor, '\n', summarize=num_elements, output_stream=output_stream) with ops.control_dependencies([print_op]): return array_ops.identity(tensor).op - def _detect_nan_inf(tensor): - """Trace function for detecting any NaN/Inf in the tensor.""" - - if tensor.dtype.is_floating: - # Since host can't handle bf16, always convert tensor to f32. 
- tensor = math_ops.cast(tensor, dtypes.float32) - output_tensor = math_ops.reduce_any( - gen_math_ops.logical_or(gen_math_ops.is_nan(tensor), - gen_math_ops.is_inf(tensor))) - else: - output_tensor = constant_op.constant(0) - return _print_tensor(op_name, output_idx, 1, tensor, output_tensor) - - def _show_global_step(tensor): - """Trace function for printing the global step count.""" - - return _print_tensor(op_name, output_idx, 1, tensor, tensor) def _show_part_tensor(tensor): """Trace function for printing part of the tensor.""" - return _print_tensor(op_name, output_idx, self._part_tensor_size, + return _print_tensor(tensor_name, self._part_tensor_size, tensor, tensor) def _show_full_tensor(tensor): """Trace function for printing the entire tensor.""" - return _print_tensor(op_name, output_idx, -1, tensor, tensor) + return _print_tensor(tensor_name, -1, tensor, tensor) - if op_name == _GLOBAL_STEP_OP_NAME: - return _show_global_step - if self._trace_mode == _TRACE_MODE_NAN_INF: - return _detect_nan_inf if self._trace_mode == _TRACE_MODE_PART_TENSOR: return _show_part_tensor - if self._trace_mode == _TRACE_MODE_FULL_TENSOR: + # The input tensor has a shape of "[1]" for _TRACE_MODE_NAN_INF, + # _TRACE_MODE_NORM, and _TRACE_MODE_MAX_ABS, as related computations are + # performed within TPUs and only their results are transferred to CPU. + # Simply, print the full tensor for these trace modes. + if self._trace_mode in [ + _TRACE_MODE_NAN_INF, _TRACE_MODE_NORM, _TRACE_MODE_FULL_TENSOR, + _TRACE_MODE_MAX_ABS + ]: return _show_full_tensor raise RuntimeError('Tensor trace fun for %s is not yet implemented' %self._trace_mode) - def trace_tpu(self, graph, result_tensor, num_replicas=None): - """Traces the tensors generated by TPU Ops in a TF graph. 
+ def _skip_op(self, op_id, op, user_included, user_excluded, + in_exec_path=True): + """Returns True if we should not trace Op.""" + + if user_included: + self._instrument_records[op.name] = TensorTracer.reason( + op_id, _REASON_USER_INCLUDED) + return False + if user_excluded: + self._instrument_records[op.name] = TensorTracer.reason( + op_id, _REASON_USER_EXCLUDED) + return True + if not in_exec_path: + self._instrument_records[op.name] = TensorTracer.reason( + op_id, _REASON_NOT_EXECUTED) + return True + if not self._inside_op_range(op_id): + self._instrument_records[op.name] = TensorTracer.reason( + op_id, _REASON_OUTSIDE_OP_RANGE) + return True + if TensorTracer.unsafe_op(op): + self._instrument_records[op.name] = TensorTracer.reason( + op_id, _REASON_UNSAFE_OP) + return True + if TensorTracer.device_mismatch(self._device_type, op): + self._instrument_records[op.name] = TensorTracer.reason( + op_id, _REASON_DEVICE_MISMATCH) + return True + if TensorTracer.less_interesting_op(op): + self._instrument_records[op.name] = TensorTracer.reason( + op_id, _REASON_LESS_INTERESTING_OP) + return True + return False + + def _skip_tensor(self, op_id, out_tensor, user_included, + user_excluded): + """Returns True if we should not trace out_tensor.""" + + # Skips a tensor if the tensor has a non-numeric type. + # Note: we cannot use check_ops.is_numeric_tensor(out_tensor) + # because it also excludes tensors with dtypes, bool, and + # float32_ref, which we actually want to trace. 
+ non_numeric_tensor_types = set([dtypes.variant, dtypes.resource, + dtypes.string]) + if out_tensor.dtype in non_numeric_tensor_types: + self._instrument_records[out_tensor.name] = TensorTracer.reason( + op_id, _REASON_NON_NUMERIC_TENSOR) + return True + + if user_included: + self._instrument_records[out_tensor.name] = TensorTracer.reason( + op_id, _REASON_USER_INCLUDED) + return False + if user_excluded: + self._instrument_records[out_tensor.name] = TensorTracer.reason( + op_id, _REASON_USER_EXCLUDED) + return True + if not out_tensor.get_shape().is_fully_defined(): + # If trace mode is nan-inf, norm or max, then the tensor will be reduced + # to a scalar before the outside compilation call. + if self._trace_mode in [ + _TRACE_MODE_NAN_INF, _TRACE_MODE_NORM, _TRACE_MODE_MAX_ABS + ]: + self._instrument_records[out_tensor.name] = TensorTracer.reason( + op_id, _REASON_TENSOR_GET_TRACED) + return False + else: + self._instrument_records[out_tensor.name] = TensorTracer.reason( + op_id, _REASON_DYNAMIC_SHAPE) + return True + rank = len(out_tensor.shape) + if rank < 1: + # scalar + if TensorTracer.unsafe_scalar_trace(out_tensor.op): + self._instrument_records[out_tensor.name] = TensorTracer.reason( + op_id, _REASON_UNSAFE_SCALAR) + return True + else: + self._instrument_records[out_tensor.name] = TensorTracer.reason( + op_id, _REASON_SCALAR_GET_TRACED) + return False + else: + # tensor + self._instrument_records[out_tensor.name] = TensorTracer.reason( + op_id, _REASON_TENSOR_GET_TRACED) + return False + + def _filter_execution_path_operations(self, operations, fetches): + """Returns the set of ops in the execution path to compute given fetches.""" + + # If no fetch provided, then return all operations. + if fetches is None: + return set(operations) + # Convert to list, if a single element is provided. + if not isinstance(fetches, (list, tuple)): + fetches = [fetches] + # If a tensor is given as fetch, convert it to op. 
+ op_fetches = [] + for fetch in fetches: + if isinstance(fetch, ops.Operation): + op_fetches.append(fetch) + elif isinstance(fetch, ops.Tensor): + op_fetches.append(fetch.op) + else: + raise RuntimeError('Given fetch:%s is neither a tensor nor an op.' + %fetch) + + execution_path_operations = set(op_fetches) + traverse_stack = list(op_fetches) + while True: + if not traverse_stack: + break + head_op = traverse_stack.pop() + input_ops = [tensor_input.op for tensor_input in head_op.inputs] + input_ops.extend(head_op.control_inputs) + + for input_op in input_ops: + if input_op not in execution_path_operations: + execution_path_operations.add(input_op) + traverse_stack.append(input_op) + return execution_path_operations + + def _determine_traced_tensors(self, graph, fetches): + """Determines the tensors that will be traced.""" + + self._traced_tensorname_to_cache_idx_map = {} + self._cache_idx_to_tensor_idx = [] + operations = graph.get_operations() + # Filter out the operations that won't be executed. 
+ # if fetches=None, then ops_in_exec_path = set(operations) + ops_in_exec_path = self._filter_execution_path_operations(operations, + fetches) + checkpoint_operations = self._get_checkpoints(graph) + for op_id, op in enumerate(operations): + if checkpoint_operations and op.name not in checkpoint_operations: + continue + user_included = self._is_user_included_op(op) + user_excluded = self._is_user_excluded_op(op) + in_exec_path = op in ops_in_exec_path + if self._skip_op(op_id, op, user_included, user_excluded, in_exec_path): + continue + for i in range(len(op.outputs)): + out_tensor = op.outputs[i] + if self._skip_tensor(op_id, out_tensor, user_included, + user_excluded): + continue + tensor_name = out_tensor.name + if tensor_name in self._traced_tensorname_to_cache_idx_map: + raise ValueError( + 'Tensor name %s should not be already in ' + 'traced_tensorname_to_cache_idx_map'%tensor_name) + if tensor_name not in self._tensorname_idx_map: + raise ValueError( + 'Tensor name %s is not in the tensorname_idx_map'%tensor_name) + tensor_idx = self._tensorname_idx_map[tensor_name] + cache_idx = len(self._traced_tensorname_to_cache_idx_map) + self._traced_tensorname_to_cache_idx_map[tensor_name] = cache_idx + self._cache_idx_to_tensor_idx.append(tensor_idx) + if len(self._traced_tensorname_to_cache_idx_map) != len( + self._cache_idx_to_tensor_idx): + raise RuntimeError('len(self._traced_tensorname_to_cache_idx_map) != ' + 'len(self._cache_idx_to_tensor_idx') + + def _check_trace_files(self): + """Checks if any requirements for trace files are satisfied.""" + + if not self._trace_dir: + # traces will be written to stderr. No need to check trace files. 
+ return + if _trace_files_need_precreated(self._trace_dir): + for replica_id in range(0, self._num_replicas): + trace_file_path = os.path.join( + self._trace_dir, + _COMPACT_TRACE_FILE_PREFIX) + '%d'%replica_id + if not gfile.Exists(trace_file_path): + raise RuntimeError( + '%s must be pre-created with the ' + 'appropriate properties.'%trace_file_path) + else: + if not gfile.Exists(self._trace_dir): + gfile.MkDir(self._trace_dir) + if not gfile.Exists(self._trace_dir): + raise RuntimeError('Failed to create %s'%self._trace_dir) + + def _pre_tracing(self, graph, fetches): + """Work needs to be done prior to TPU or CPU tracing.""" + + self._check_trace_files() + operations = graph.get_operations() + (opname_idx_map, tensor_list, self._tensorname_idx_map) = ( + TensorTracer._make_op_and_tensor_maps(operations)) + self._write_config_section() + self._write_op_list_section(operations) + self._write_tensor_list_section(tensor_list, opname_idx_map) + self._determine_traced_tensors(graph, fetches) + self._write_cache_index_map_section() + # Does the topological sort before adding any nodes to the graph. + (succeed, sorted_or_cycle) = TensorTracer.topological_sort(graph) + if self._use_tensor_values_cache(): + _create_tensor_values_cache(graph, + len(self._cache_idx_to_tensor_idx)) + return (operations, succeed, sorted_or_cycle) + + def _post_tracing(self, succeed, sorted_or_cycle): + """Work needs to be done after TPU or CPU tracing.""" + + self._write_reason_section() + self._write_graph_section(succeed, sorted_or_cycle) + self._close_report_file() + + def _get_checkpoints(self, graph): + """Returns the list of Ops that produce the tensors traced with API. Args: graph: the graph of Ops. + + Returns: + A set of operation names which should be traced. 
+ """ + + self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, + _TENSOR_TRACER_CHECKPOINT)) + checkpoint_operations = set() + tensor_tracer_variables = graph.get_collection(_TENSOR_TRACER_COLLECTION) + for (tensor, checkpoint_name) in tensor_tracer_variables: + self._write_report('%s %s\n'%(tensor.name, checkpoint_name)) + checkpoint_operations.add(tensor.op.name) + self._write_report('%s %s\n'%(_MARKER_SECTION_END, + _TENSOR_TRACER_CHECKPOINT)) + return checkpoint_operations + + def _generate_flush_cache_op(self, graph, start_replica, on_tpu): + """Generates an Op that will flush the cache to file. + + Args: + graph: the graph of Ops + start_replica: the ID of the first replica being flushed by this Op. + on_tpu: if the graph is executed on TPU. + + Returns: + The Op to flush the cache to file. + """ + def _make_flush_fun(replica_id): + """Makes a function for flushing the cache for the given replica.""" + + def _fun(): + """A function that flushes the cache to a file.""" + + def _flush_fun(cache): + """Flushes the cache to a file.""" + + if isinstance(replica_id, str): + replica_id_str = replica_id + else: + replica_id_str = '%d'%replica_id + output_path = os.path.join(self._trace_dir, + _COMPACT_TRACE_FILE_PREFIX) \ + + replica_id_str + output_stream = _OUTPUT_STREAM_ESCAPE + output_path + new_step_line = _REPLICA_ID_TAG + replica_id_str + print_op = logging_ops.print_v2( + new_step_line, '\n', + cache, '\n', + summarize=-1, + output_stream=output_stream) + with ops.control_dependencies([print_op]): + return constant_op.constant(0).op + + cache = _get_tensor_values_cache(graph) + if on_tpu: + flush_op = tpu.outside_compilation(_flush_fun, cache.value()) + else: + flush_op = _flush_fun(cache.value()) + with ops.control_dependencies([flush_op]): + reset_value = constant_op.constant(_COMPACT_TRACE_ENTRY_INIT_VALUE, + dtype=cache.dtype, + shape=cache.shape) + assign_op = state_ops.assign(cache, reset_value).op + with ops.control_dependencies([assign_op]): + return 
flush_op.outputs[0] + + return _fun + + def _f(replica_id): + return _make_flush_fun(replica_id) + def _eq(x): + return math_ops.equal(x, self._replica_id) + def _do_nothing(): + return constant_op.constant(0) + + return control_flow_ops.case({\ + _eq(start_replica): _f(start_replica), \ + _eq(start_replica+1): _f(start_replica+1), \ + _eq(start_replica+2): _f(start_replica+2), \ + _eq(start_replica+3): _f(start_replica+3), \ + _eq(start_replica+4): _f(start_replica+4), \ + _eq(start_replica+5): _f(start_replica+5), \ + _eq(start_replica+6): _f(start_replica+6), \ + _eq(start_replica+7): _f(start_replica+7), \ + }, + default=_do_nothing, + exclusive=True).op + + def _flush_tensor_values_cache(self, graph, result_tensor, train_op, on_tpu): + """Flushes the intermediate tensor values in the graph to the cache. + + Args: + graph: the graph of Ops + result_tensor: a result tensor of evaluating the graph. + train_op: the training op. + on_tpu: if the graph is executed on TPU. + + Returns: + An identical copy of result tensor. + """ + + train_op_list = [] + if train_op is not None: + train_op_list.append(train_op) + with ops.control_dependencies(train_op_list): + flush_cache_op_list = [] + for host in range(self._num_hosts): + start_replica = host * 8 + flush_op = self._generate_flush_cache_op(graph, start_replica, on_tpu) + flush_cache_op_list.append(flush_op) + with ops.control_dependencies(flush_cache_op_list): + return array_ops.identity(result_tensor) + + def trace_tpu(self, graph, + result_tensor, + train_op, + num_replicas=None, + num_replicas_per_host=None, + num_hosts=None): + """Traces the tensors generated by TPU Ops in a TF graph. + + Args: + graph: the graph of Ops executed on the TPU. result_tensor: a result tensor of evaluating the graph. + train_op: the training op. num_replicas: number of replicas used on the TPU. + num_replicas_per_host: number of replicas per TPU host. + num_hosts: total number of TPU hosts. 
Returns: A tuple (result_tensor_copy, tracing_ops), where: @@ -496,58 +1334,148 @@ class TensorTracer(object): should pose control dependencies upon these Ops so that they will be executed when the graph is evaluated. + + Raises: + RuntimeError: If num_replicas_per_host > 8. """ + def _cast_unsupported_dtypes(tensor): + """Casts tensor to a supported type.""" + + if tensor.dtype.__eq__(dtypes.int64): + # outside-compilation doesn't support int64 input yet. + return math_ops.cast(tensor, dtypes.int32) + if tensor.dtype.__eq__(dtypes.bfloat16) or tensor.dtype.__eq__( + dtypes.float16): + # Since host can't handle bf16, convert tensor to f32. + return math_ops.cast(tensor, dtypes.float32) + return tensor + self._device_type = _DEVICE_TYPE_TPU + self._num_replicas = num_replicas + self._num_replicas_per_host = num_replicas_per_host + self._num_hosts = num_hosts + if self._num_replicas_per_host > 8: + # Checks for the assumption in _generate_flush_cache_op(). + raise RuntimeError( + 'num_replicas_per_host (%d) is ' + 'greater than 8'%self._num_replicas_per_host) + TensorTracer.check_device_type(self._device_type) - result_tensor_copy = self._add_replica_id_to_graph(num_replicas, - result_tensor) - self._write_config_section() + result_tensor_copy = self._add_replica_id_to_graph(result_tensor) + fetches = _set_fetches(result_tensor, train_op) + (operations, succeed, sorted_or_cycle) = self._pre_tracing(graph, fetches) + tracing_ops = [] - operations = graph.get_operations() - self._write_op_list_section(operations) - # Does the topological sort before adding any nodes to the graph. 
- (succeed, sorted_or_cycle) = TensorTracer.topological_sort(graph) - for op_id, op in enumerate(operations): - if not self._inside_op_range(op_id): - self._instrument_records[op.name] = TensorTracer.reason( - op_id, _RECORD_OUTSIDE_OP_RANGE) - continue - if not TensorTracer.should_trace(self._device_type, op): - self._instrument_records[op.name] = TensorTracer.reason( - op_id, _RECORD_SHOULD_NOT_TRACE) - continue - if not self._is_selected_op(op.name): - self._instrument_records[op.name] = TensorTracer.reason( - op_id, _RECORD_FILTERED_OUT) - continue + for op in operations: for i in range(len(op.outputs)): out_tensor = op.outputs[i] - if not out_tensor.get_shape().is_fully_defined(): - self._instrument_records[out_tensor.name] = TensorTracer.reason( - op_id, _RECORD_DYNAMIC_SHAPE) - continue # cannot trace tensors with dynamic shape. - rank = len(out_tensor.shape) - if rank < 1: - self._instrument_records[out_tensor.name] = TensorTracer.reason( - op_id, _RECORD_SCALAR) - continue # cannot trace scalar. - self._instrument_records[out_tensor.name] = TensorTracer.reason( - op_id, _RECORD_GET_TRACED) + tensor_name = out_tensor.name + if tensor_name not in self._traced_tensorname_to_cache_idx_map: + continue + # Create the list of consumers before calling _preprocess_traced_tensor. + # Otherwise, adding control input below, will introduce a cycle in the + # graph. 
+ consumers = out_tensor.consumers() + if not consumers: + continue + processed_out_tensor = self._preprocess_traced_tensor(out_tensor) + processed_out_tensor = _cast_unsupported_dtypes(processed_out_tensor) + if self._use_tensor_values_cache(): + cache_idx = self._traced_tensorname_to_cache_idx_map[tensor_name] + trace_op = self._save_tensor_value_to_cache_op(graph, + cache_idx, + processed_out_tensor) + else: + trace_op = tpu.outside_compilation( + self._make_tensor_trace_fun(tensor_name), processed_out_tensor) + for consumer_op in consumers: + # pylint: disable=protected-access + consumer_op._add_control_input(trace_op) + # pylint: enable=protected-access + if self._use_tensor_values_cache(): + result_tensor_final = self._flush_tensor_values_cache(graph, + result_tensor_copy, + train_op, + on_tpu=True) + else: + result_tensor_final = result_tensor_copy + self._post_tracing(succeed, sorted_or_cycle) + return (result_tensor_final, tracing_ops) + + def _generate_cpu_result(self, result_tensor, train_op, graph): + """Generates the final CPU result.""" + + if self._use_tensor_values_cache(): + result_tensor_final = self._flush_tensor_values_cache(graph, + result_tensor, + train_op, + on_tpu=False) + else: + result_tensor_final = array_ops.identity(result_tensor) + return result_tensor_final + + def trace_cpu(self, graph, result_tensor, train_op): + """Traces the tensors generated by CPU Ops in a TF graph. + + Args: + graph: the graph of Ops executed on the CPU. + result_tensor: a result tensor of evaluating the graph. + train_op: the training op. + + Returns: + A pair (final_result_tensor, tracing_calls) where: + final_result_tensor: an identical copy of result_tensor. + tracing_calls: a map from keys to trace calls. + A key is constructed from an Op's name. + A trace call consists of a function and a tensor ( + the function will be invoked with the tensor). 
+ """ + + if result_tensor is None: + raise ValueError( + 'The result_tensor passed to trace_cpu should not be None') + + self._device_type = _DEVICE_TYPE_CPU + TensorTracer.check_device_type(self._device_type) + self._num_replicas = 1 + self._num_replicas_per_host = 1 + self._num_hosts = 1 + self._replica_id = 0 + fetches = _set_fetches(result_tensor, train_op) + (operations, succeed, sorted_or_cycle) = self._pre_tracing(graph, fetches) + + tracing_calls = {} + for op in operations: + for i in range(len(op.outputs)): + out_tensor = op.outputs[i] + tensor_name = out_tensor.name + if tensor_name not in self._traced_tensorname_to_cache_idx_map: + continue + # Create the list of consumers before calling _preprocess_traced_tensor. + # Otherwise, adding control input below, will introduce a cycle in the + # graph. consumers = out_tensor.consumers() - trace_op = tpu.outside_compilation( - self._make_tensor_trace_fun(op.name, i), out_tensor) - if consumers: + if not consumers: + continue + processed_out_tensor = self._preprocess_traced_tensor(out_tensor) + if self._use_tensor_values_cache(): + cache_idx = self._traced_tensorname_to_cache_idx_map[tensor_name] + trace_op = self._save_tensor_value_to_cache_op(graph, + cache_idx, + processed_out_tensor) for consumer_op in consumers: # pylint: disable=protected-access consumer_op._add_control_input(trace_op) # pylint: enable=protected-access else: - # if there is no consumer, we will add the control dependence later - # when we add the control dependency to the output operations. 
- tracing_ops.append(trace_op) - - self._write_reason_section() - self._write_graph_section(succeed, sorted_or_cycle) - - return (result_tensor_copy, tracing_ops) + trace_fun = self._make_tensor_trace_fun(tensor_name) + trace_call = (trace_fun, [processed_out_tensor]) + trace_call_key = 'tensor_tracing_cpu-%s:%d'%(op.name, i) + tracing_calls[trace_call_key] = trace_call + + self._post_tracing(succeed, sorted_or_cycle) + final_result_tensor = self._generate_cpu_result(result_tensor, + train_op, + graph) + return (final_result_tensor, tracing_calls) diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py index def57da20d6018dcf27ccb7a9d04592f38ce2f7c..de2bfd49eca50c87dc506d9aa690d49c8da20460 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu.py @@ -19,23 +19,29 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.contrib.compiler import xla from tensorflow.contrib.framework.python.framework import experimental +from tensorflow.contrib.tpu.proto import dynamic_padding_pb2 as dynamic_padding from tensorflow.contrib.tpu.python.ops import tpu_ops from tensorflow.contrib.tpu.python.tpu import tpu_function from tensorflow.core.framework import attr_value_pb2 from tensorflow.python.compat import compat as api_compat from tensorflow.python.framework import device as pydev +from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope from tensorflow.python.platform import tf_logging as logging from 
tensorflow.python.util import compat +from tensorflow.python.util import nest # Operations that indicate some error in the users graph, e.g. a placeholder @@ -322,6 +328,30 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): def HostComputeCore(self): return self._host_compute_core + def _RemoveExternalControlEdges(self, op): + """Remove any external control dependency on this op.""" + internal_control_inputs = [] + external_control_inputs = [] + for x in op.control_inputs: + # pylint: disable=protected-access + is_internal_op = False + ctxt = x._get_control_flow_context() + while ctxt is not None: + if ctxt == self: + is_internal_op = True + break + ctxt = ctxt._outer_context + if is_internal_op: + internal_control_inputs.append(x) + else: + external_control_inputs.append(x) + # pylint: enable=protected-access + # pylint: disable=protected-access + op._remove_all_control_inputs() + op._add_control_inputs(internal_control_inputs) + # pylint: enable=protected-access + return internal_control_inputs, external_control_inputs + def AddOp(self, op): # pylint: disable=protected-access if op.type in _BLACKLISTED_OPS: @@ -372,11 +402,14 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): if external_control_inputs: # Use an identity to pull control inputs as data inputs. Note that we # ignore ops which don't have outputs. TODO(phawkins): fix that. 
- external_control_inputs = [ - array_ops.identity(x.outputs[0]).op - for x in external_control_inputs - if x.outputs - ] + with ops.control_dependencies(None): + self.Enter() + external_control_inputs = [ + array_ops.identity(x.outputs[0]).op + for x in external_control_inputs + if x.outputs + ] + self.Exit() # pylint: disable=protected-access op._add_control_inputs(external_control_inputs) # pylint: enable=protected-access @@ -480,14 +513,19 @@ def replicate(computation, inputs=None, infeed_queue=None, device_assignment=None, - name=None): + name=None, + maximum_shapes=None): """Builds a graph operator that runs a replicated TPU computation. Args: computation: A Python function that builds the computation to replicate. inputs: A list of lists of input tensors or `None` (equivalent to `[[]]`), indexed by `[replica_num][input_num]`. All replicas must - have the same number of inputs. + have the same number of inputs. Each input can be a nested structure + containing values that are convertible to tensors. Note that passing an + N-dimension list of compatible values will result in a N-dimention list of + scalar tensors rather than a single Rank-N tensors. If you need different + behavior, convert part of inputs to tensors with `tf.convert_to_tensor`. infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple of arguments as inputs to computation. device_assignment: If not `None`, a `DeviceAssignment` describing the @@ -497,15 +535,125 @@ def replicate(computation, only one core, and there is either only one replica, or the number of replicas is equal to the number of cores in the TPU system. name: (Deprecated) Does nothing. + maximum_shapes: A nested structure of tf.TensorShape representing the shape + to which the respective component of each input element in each replica + should be padded. Any unknown dimensions (e.g. 
tf.Dimension(None) in a + tf.TensorShape or -1 in a tensor-like object) will be padded to the + maximum size of that dimension over all replicas. Note that if the input + dimension is already static, we won't do padding on it and we require the + maximum_shapes to have the same value or None on that dimension. The + structure of `maximum_shapes` needs to be the same as `inputs[0]`. Returns: - A list of lists of output tensors, indexed by `[replica_num][output_num]`. + A list of outputs, indexed by `[replica_num]` each output can be a nested + structure same as what computation() returns with a few exceptions. + + Exceptions include: + 1) None output: a NoOp would be returned which control-depends on + computation. + 2) Single value output: A tuple containing the value would be returned. + 3) Operation-only outputs: a NoOp would be returned which + control-depends on computation. + TODO(b/121383831): Investigate into removing these special cases. + Raises: ValueError: If all replicas do not have equal numbers of input tensors. ValueError: If the number of inputs per replica does not match the number of formal parameters to `computation`. + ValueError: If the static `inputs` dimensions don't match with the values + given in `maximum_shapes`. + ValueError: If the structure of inputs per replica does not match + the structure of `maximum_shapes`. """ - return split_compile_and_replicate(computation, inputs, infeed_queue, - device_assignment, name)[1] + return split_compile_and_replicate( + computation, + inputs, + infeed_queue, + device_assignment, + name, + maximum_shapes=maximum_shapes)[1] + + +def _pad_all_input(inputs, padded_shapes): + """Pad all input tensors given padded_shapes. + + The real shape tensors will be concatenated with the padded original inputs. + + Args: + inputs: The original inputs. + padded_shapes: A list of padded shapes for each input. 
+ + Returns: + The padded inputs and a PaddingMap list which maps the padded input + dimension to the real shape argument index. + """ + input_shape_tensors = [] + for core_idx, inputs_per_core in enumerate(inputs): + for idx, input_tensor in enumerate(inputs_per_core): + if core_idx == 0: + input_shape_tensors.append([]) + input_shape_tensors[idx].append(array_ops.shape(input_tensor)) + + maximum_shapes = [] + for shapes_per_input in input_shape_tensors: + maximum_shapes.append( + math_ops.reduce_max(array_ops.stack(shapes_per_input), axis=0)) + + padded_inputs = [] + real_shapes = [] + padding_maps = [] + for core_idx, inputs_per_core in enumerate(inputs): + padded_inputs.append([]) + real_shapes.append([]) + real_shape_idx = len(inputs_per_core) - 1 + for idx, input_tensor in enumerate(inputs_per_core): + input_shape_tensor = input_shape_tensors[idx][core_idx] + input_shape = input_tensor.get_shape() + padded_shape = padded_shapes[idx] + + # The static shape of inputs should be compatible with the given padded + # shapes. + input_shape.assert_is_compatible_with(padded_shape) + + if input_shape.is_fully_defined(): + # Do nothing if the shape of the whole tensor is already static. + padded_inputs[core_idx].append(input_tensor) + else: + # Only pad the non static shape dimension. + for i, s in enumerate(input_shape): + if s.value is None: + if core_idx == 0: + real_shape_idx += 1 + padding_map = dynamic_padding.PaddingMap() + padding_map.arg_index = idx + padding_map.shape_index = i + padding_map.padding_arg_index = real_shape_idx + padding_maps.append(padding_map) + real_shapes[core_idx].append( + math_ops.cast(input_shape_tensor[i], dtypes.uint32)) + + paddings = [] + for i, s in enumerate(padded_shape): + if input_shape[i].value: + # Don't pad if input shape is already static. + padding = [0, 0] + else: + if s.value: + # Pad to the given maximum value. 
+ padding = [0, s.value - input_shape_tensor[i]] + else: + # If maximum value is not given, then pad to the maximum dimension + # among all the cores. + padding = [0, maximum_shapes[idx][i] - input_shape_tensor[i]] + paddings.append(padding) + + padded_input = array_ops.pad(input_tensor, paddings) + padded_inputs[core_idx].append(padded_input) + + num_replicas = len(padded_inputs) + for i in range(num_replicas): + padded_inputs[i].extend(real_shapes[i]) + + return padded_inputs, padding_maps def split_compile_and_replicate(computation, @@ -513,7 +661,8 @@ def split_compile_and_replicate(computation, infeed_queue=None, device_assignment=None, name=None, - use_tpu=True): + use_tpu=True, + maximum_shapes=None): """Builds graph operators that runs compilation and replicated computation. This is a lower level interface than replicate that returns a separate compile @@ -526,7 +675,11 @@ def split_compile_and_replicate(computation, computation: A Python function that builds the computation to replicate. inputs: A list of lists of input tensors or `None` (equivalent to `[[]]`), indexed by `[replica_num][input_num]`. All replicas must - have the same number of inputs. + have the same number of inputs. Each input can be a nested structure + containing values that are convertible to tensors. Note that passing an + N-dimension list of compatible values will result in a N-dimention list of + scalar tensors rather than a single Rank-N tensors. If you need different + behavior, convert part of inputs to tensors with `tf.convert_to_tensor`. infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple of arguments as inputs to computation. device_assignment: If not `None`, a `DeviceAssignment` describing the @@ -539,6 +692,15 @@ def split_compile_and_replicate(computation, use_tpu: When false, the input `computation` is executed on the XLA CPU/GPU backends. 
Currently, only supports a default placement (computation is placed on GPU if one is available, and on CPU if not). + maximum_shapes: A nested structure of tf.TensorShape representing the shape + to which the respective component of each input element in each replica + should be padded. Any unknown dimensions (e.g. tf.Dimension(None) in a + tf.TensorShape or -1 in a tensor-like object) will be padded to the + maximum size of that dimension over all replicas. Note that if the input + dimension is already static, we won't do padding on it and we require the + maximum_shapes to have the same value or None on that dimension. The + structure of `maximum_shapes` needs to be the same as `inputs[0]`. + Returns: A list of lists with the first list corresponding to the compile op and the second a list of output tensors, indexed by `[replica_num][output_num]`. @@ -546,6 +708,10 @@ def split_compile_and_replicate(computation, ValueError: If all replicas do not have equal numbers of input tensors. ValueError: If the number of inputs per replica does not match the number of formal parameters to `computation`. + ValueError: If the static `inputs` dimensions don't match with the values + given in `maximum_shapes`. + ValueError: If the structure of inputs per replica does not match + the structure of `maximum_shapes`. """ del name inputs = [[]] if inputs is None else inputs @@ -580,24 +746,32 @@ def split_compile_and_replicate(computation, if num_replicas == 0: return [] + # Checks all replicas have the same structure. + for i in xrange(1, num_replicas): + nest.assert_same_structure(inputs[0], inputs[i]) + + # Flatten inputs. + flat_inputs = [ + nest.flatten(per_replica_input) for per_replica_input in inputs + ] # Converts inputs to Tensors. 
- inputs = [[ops.convert_to_tensor(x) for x in inp] for inp in inputs] + flat_inputs = [[ops.convert_to_tensor(x) for x in inp] for inp in flat_inputs] # Verifies that all replicas have matching numbers and types of inputs - input_types = [x.dtype for x in inputs[0]] - input_arity = len(input_types) + flat_input_types = [x.dtype for x in flat_inputs[0]] + input_arity = len(inputs[0]) + flat_input_arity = len(flat_input_types) for i in range(num_replicas): if len(inputs[i]) != input_arity: raise ValueError("Replicas must have the same number of inputs. " "Replica 0 had {} inputs, replica {} had {} " "inputs.".format(input_arity, i, len(inputs[i]))) - types = [x.dtype for x in inputs[i]] - if types != input_types: - raise ValueError( - "Replicas must have matching input types. Replica 0 had " - "input types {}, replica {} had input types {}".format( - input_types, i, types)) + types = [x.dtype for x in flat_inputs[i]] + if types != flat_input_types: + raise ValueError("Replicas must have matching input types. Replica 0 had " + "input types {}, replica {} had input types {}".format( + flat_input_types, i, types)) arg_error = xla.check_function_argument_count( computation, input_arity, infeed_queue) @@ -616,13 +790,34 @@ def split_compile_and_replicate(computation, for i in inputs[0]]), infeed_queue.number_of_tuple_elements, arg_error)) + if maximum_shapes: + if infeed_queue: + raise ValueError( + "Dynamic input shapes are not supported with infeed queues") + + # Make sure maximum_shapes has the same structure as inputs. + nest.assert_same_structure(inputs[0], maximum_shapes, check_types=False) + + # Flatten padded shapes. 
+ flat_maximum_shapes = nest.flatten(maximum_shapes) + flat_maximum_shapes = [ + tensor_shape.TensorShape(s) for s in flat_maximum_shapes + ] + + flat_inputs, padding_maps = _pad_all_input(flat_inputs, flat_maximum_shapes) + + serialized_padding_maps = [] + for padding_map in padding_maps: + serialized_padding_maps.append(padding_map.SerializeToString()) + metadata_kwargs["padding_map"] = serialized_padding_maps + graph = ops.get_default_graph() # Fan-in: Builds a TPUReplicatedInput node for each input. - computation_inputs = [] - for i in range(0, input_arity): - replicas = [inputs[replica][i] for replica in xrange(num_replicas)] - computation_inputs.append( + flat_replicated_inputs = [] + for i in range(0, len(flat_inputs[0])): + replicas = [flat_inputs[replica][i] for replica in xrange(num_replicas)] + flat_replicated_inputs.append( tpu_ops.tpu_replicated_input(replicas, name="input{}".format(i))) cluster_name = graph.unique_name("cluster") @@ -642,10 +837,26 @@ def split_compile_and_replicate(computation, # computation. This is to avoid orphaned TPUReplicatedInput nodes. # TODO(phawkins): consider instead pruning unused TPUReplicatedInput # and eliding trivial TPUReplicatedInput/TPUReplicatedOutput pairs. - computation_inputs = [ + flat_replicated_inputs = [ array_ops.identity(x, name="replicated_input_{}".format(i)) - for i, x in enumerate(computation_inputs) + for i, x in enumerate(flat_replicated_inputs) ] + for i in flat_replicated_inputs: + # pylint: disable=protected-access + # Add an attribute to the identity node so that they could be removed in + # encapsulate TPU computation pass if unused. However we don't remove + # inputs when dynamic padding is enabled. + # TODO(rxsang): Use other ways except argument index in padding_map so + # outside compilation can work with dynamic padding correctly. 
+ if maximum_shapes is None: + i.op._set_attr("_tpu_input_identity", + attr_value_pb2.AttrValue(b=True)) + # pylint: enable=protected-access + + # Unflatten the computation inputs to match original input structure. + computation_inputs = nest.pack_sequence_as( + structure=inputs[0], + flat_sequence=flat_replicated_inputs[:flat_input_arity]) # If there is an infeed queue, adds the dequeued values to the # computation's inputs. @@ -687,47 +898,12 @@ def split_compile_and_replicate(computation, vscope.set_use_resource(saved_use_resource) vscope.set_custom_getter(saved_custom_getter) - # If the computation returns `None`, make it an empty tuple. - if outputs is None: - outputs = tuple() - # If the computation only returned one value, makes it a tuple. - if not isinstance(outputs, (list, tuple)): - outputs = (outputs,) - - # Append `no_op` here so that fetching any return value of this function - # will trigger TPUExecute node. - outputs += (control_flow_ops.no_op(),) - try: - with ops.device(core(0)): - outputs = [ - o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o) - for o in outputs - ] - except Exception as e: - raise ValueError( - "TPU function return values must all either be Operations or " - "convertible to Tensors. Got '%s'" % str(e)) - - # Separates the returned Operations and Tensors. - output_operations = [o for o in outputs if isinstance(o, ops.Operation)] - output_tensors = [o for o in outputs if not isinstance(o, ops.Operation)] - - if outputs != output_tensors + output_operations: - raise ValueError( - "TPU functions must return zero-or more Tensor values followed by " - "zero or more Operations.") - output_arity = len(output_tensors) + outputs_is_flat = xla.is_flat(outputs) + if outputs_is_flat: + output_tensors, control_deps = _postprocess_flat_outputs(outputs) + else: + output_tensors, control_deps = _postprocess_non_flat_outputs(outputs) - # Wraps outputs in Identity ops. 
Otherwise a replicated input copied - # straight to an output would bypass the replicate(). This would be bad - # because the TPUReplicatedInput/TPUReplicatedOutput operator would not - # be rewritten away, leading to a runtime error. - # TODO(phawkins): extend the rewrite to elide these nodes instead. - new_output_tensors = [] - for t in output_tensors: - with ops.device(t.device if t.device else core(0)): - new_output_tensors.append(array_ops.identity(t)) - output_tensors = new_output_tensors context.ExitResult(output_tensors) finally: context.report_unsupported_operations() @@ -739,11 +915,6 @@ def split_compile_and_replicate(computation, attr_value.list.s.extend([compat.as_bytes(x) for x in host_compute_core]) metadata._set_attr("host_compute_core", attr_value) # pylint: disable=protected-access - # Fan-out: Builds a TPUReplicatedOutput node for each output. - outputs = [tpu_ops.tpu_replicated_output(output_tensors[i], num_replicas, - name="output{}".format(i)) - for i in xrange(output_arity)] - with ops.control_dependencies([metadata]): if use_tpu: compile_status = tpu_ops.tpu_compilation_result() @@ -753,39 +924,157 @@ def split_compile_and_replicate(computation, else: compile_status = control_flow_ops.no_op(name="compilation_status") - with ops.control_dependencies(output_operations): - if output_arity == 0: - # Returns a list of NoOps dependent on the replication Op, indexed by - # [replica_num]. - return [ - compile_status, [ - control_flow_ops.no_op(name="shard_%d" % i) - for i in range(num_replicas) - ] - ] - else: - # Wraps the outputs in identity operators so the names of any possible - # `fetch` nodes are preserved by the replication rewrite. 
- return [ - compile_status, [[ - array_ops.identity( - outputs[out][replica], - name="output_%d_shard_%d" % (out, replica)) - for out in xrange(output_arity) - ] - for replica in xrange(num_replicas)] + if not output_tensors: + # Returns a list of NoOps dependent on the replication Op, indexed by + # [replica_num]. + return [ + compile_status, + [ + control_flow_ops.group(control_deps, name="shard_%d" % i) + for i in range(num_replicas) + ] + ] + + # Fan-out: Builds a TPUReplicatedOutput node for each output. + replicated_outputs = [[] for i in xrange(num_replicas)] + for i, t in enumerate(output_tensors): + # Fan-out: Builds a TPUReplicatedOutput node for each output. + ys = tpu_ops.tpu_replicated_output( + t, num_replicas, name="output{}".format(i)) + + # Wraps the outputs in identity operators so the names of any possible + # `fetch` nodes are preserved by the replication rewrite. + with ops.control_dependencies(control_deps): + for replica in xrange(num_replicas): + replicated_outputs[replica].append( + array_ops.identity( + ys[replica], name="output_%d_shard_%d" % (i, replica))) + + if not outputs_is_flat: + replicated_outputs = [ + nest.pack_sequence_as(outputs, replica_outs) + for replica_outs in replicated_outputs + ] + + return [compile_status, replicated_outputs] + + +def _postprocess_flat_outputs(outputs): + """Validates flat outputs, adds back device assignments and other attrs. + + Args: + outputs: Output from `computation` inside `tpu.rewrite`. + + Returns: + Tensors and Operations extracted from outputs. + """ + # Following code segment is to preserve legacy behavior. Previously we only + # supported flat outputs and thus for consistency it was nice to convert even + # single element into a tuple. But now that we support arbitrary output + # structure, this is no longer necessary. + # TODO(b/121383831): Migrate all legacy use cases and delete this special + # case. + # If the computation returns `None`, make it an empty tuple.
+ if outputs is None: + outputs = tuple() + # If the computation only returned one value, makes it a tuple. + if not isinstance(outputs, collections.Sequence): + outputs = (outputs,) + + # Append `no_op` here so that fetching any return value of this function + # will trigger TPUExecute node. + outputs += (control_flow_ops.no_op(),) + try: + with ops.device(core(0)): + outputs = [ + o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o) + for o in outputs ] + except Exception as e: + raise ValueError( + "TPU function return values must all either be Operations or " + "convertible to Tensors. Got '%s'" % str(e)) + + # Separates the returned Operations and Tensors. + output_operations = [o for o in outputs if isinstance(o, ops.Operation)] + output_tensors = [o for o in outputs if not isinstance(o, ops.Operation)] + + if outputs != output_tensors + output_operations: + raise ValueError( + "TPU functions must return zero-or more Tensor values followed by " + "zero or more Operations.") + + # Wraps outputs in Identity ops. Otherwise a replicated input copied + # straight to an output would bypass the replicate(). This would be bad + # because the TPUReplicatedInput/TPUReplicatedOutput operator would not + # be rewritten away, leading to a runtime error. + # TODO(phawkins): extend the rewrite to elide these nodes instead. 
+ new_output_tensors = [] + for t in output_tensors: + with ops.device(t.device if t.device else core(0)): + o = array_ops.identity(t) + # pylint: disable=protected-access + o.op._set_attr("_tpu_output_identity", attr_value_pb2.AttrValue(b=True)) + # pylint: enable=protected-access + new_output_tensors.append(o) + return new_output_tensors, output_operations -def shard(computation, - inputs=None, - num_shards=1, - input_shard_axes=None, - outputs_from_all_shards=True, - output_shard_axes=None, - infeed_queue=None, - device_assignment=None, - name=None): +def _postprocess_non_flat_outputs(outputs): + """Validates non-flat outputs, adds back device assignments and other attrs. + + Args: + outputs: Output from `computation` inside `tpu.rewrite`. + + Returns: + Tensors extracted from outputs and an empty list because Operations are not + allowed in non-flat outputs. + """ + + # Flatten output items. + flat_outputs = nest.flatten(outputs) + + # Convert all non-Operation outputs to Tensors. + for i, o in enumerate(flat_outputs): + if isinstance(o, ops.Operation): + raise ValueError( + "tpu.rewrite does not support Operation as return value in non-flat " + "output structure. You can set returned Operations as control " + "dependencies of returned Tensors so Operations are triggered when " + 'Tensors are evaluated. Operation found: "%s"' % o.name) + + try: + o = ops.convert_to_tensor(o) + except Exception as e: + raise ValueError( + "TPU function return values must all either be Operations or " + 'convertible to Tensors. Got error: "%s"' % str(e)) + + # Wraps outputs in Identity ops. Otherwise a replicated input copied + # straight to an output would bypass the replicate(). This would be bad + # because the TPUReplicatedInput/TPUReplicatedOutput operator would not + # be rewritten away, leading to a runtime error. + # TODO(phawkins): extend the rewrite to elide these nodes instead.
+ with ops.device(core(0)): + o = array_ops.identity(o) + # pylint: disable=protected-access + o.op._set_attr("_tpu_output_identity", attr_value_pb2.AttrValue(b=True)) + # pylint: enable=protected-access + flat_outputs[i] = array_ops.identity(o) + + # All flat_outputs are Tensors, and no Operations. + return flat_outputs, [] + + +def split_compile_and_shard(computation, + inputs=None, + num_shards=1, + input_shard_axes=None, + outputs_from_all_shards=True, + output_shard_axes=None, + infeed_queue=None, + device_assignment=None, + name=None): """Shards `computation` for parallel execution. `inputs` must be a list of Tensors or None (equivalent to an empty list), each @@ -801,9 +1090,6 @@ def shard(computation, return x + 3 ... = shard(computation, ...) - TODO(phawkins): consider adding support for broadcasting Tensors passed - as inputs. - If `outputs_from_all_shards` is true, the outputs from all shards of `computation` are concatenated back together along their `output_shards_axes`. Otherwise, each output is taken from an arbitrary shard. @@ -839,12 +1125,14 @@ def shard(computation, is equal to the number of cores in the TPU system. name: (Deprecated) Does nothing. Returns: - A list of output tensors. + A tuple of (compile op, [output tensors]). Raises: ValueError: If num_shards <= 0 ValueError: If len(input_shard_axes) != len(inputs) ValueError: If len(output_shard_axes) != len(outputs from `computation`) """ + # TODO(phawkins): consider adding support for broadcasting Tensors passed as + # inputs. if num_shards <= 0: raise ValueError("num_shards must be a positive integer.") @@ -874,7 +1162,7 @@ def shard(computation, else: transposed_inputs = [[]] * num_shards - outputs = replicate( + compile_op, outputs = split_compile_and_replicate( computation, transposed_inputs, infeed_queue=infeed_queue, @@ -891,7 +1179,7 @@ def shard(computation, # one so it can be used as a control dependency or fetch node. # TODO(b/36647078) remove disable when pylint bug is fixed. 
# pylint: disable=indexing-exception - return [outputs[0]] + return compile_op, [outputs[0]] # pylint: enable=indexing-exception # TODO(b/36647078) remove disable when pylint bug is fixed. @@ -925,7 +1213,87 @@ def shard(computation, # TODO(phawkins): use a smarter policy, e.g., round-robin across shards. results.append(x[0]) - return results + return compile_op, results + + +def shard(computation, + inputs=None, + num_shards=1, + input_shard_axes=None, + outputs_from_all_shards=True, + output_shard_axes=None, + infeed_queue=None, + device_assignment=None, + name=None): + """Shards `computation` for parallel execution. + + `inputs` must be a list of Tensors or None (equivalent to an empty list), each + of which has a corresponding split axis (from `input_shard_axes`). Each input + is split into `num_shards` pieces along the corresponding axis, and + computation is applied to each shard in parallel. + + Tensors are broadcast to all shards if they are lexically captured by + `computation`. e.g., + + x = tf.constant(7) + def computation(): + return x + 3 + ... = shard(computation, ...) + + TODO(phawkins): consider adding support for broadcasting Tensors passed + as inputs. + + If `outputs_from_all_shards` is true, the outputs from all shards of + `computation` are concatenated back together along their `output_shards_axes`. + Otherwise, each output is taken from an arbitrary shard. + + Inputs and outputs of the computation must be at least rank-1 Tensors. + + Args: + computation: A Python function that builds a computation to apply to each + shard of the input. + inputs: A list of input tensors or None (equivalent to an empty list). Each + input tensor has a corresponding shard axes, given by `input_shard_axes`, + which must have size divisible by `num_shards`. + num_shards: The number of shards. + input_shard_axes: A list of dimensions along which to shard `inputs`, or + `None`. `None` means "shard all inputs along dimension 0". 
If not `None`, + there must be one dimension per input. + outputs_from_all_shards: Boolean or list of boolean. For each output, if + `True`, outputs from all shards are concatenated along the corresponding + `output_shard_axes` entry. Otherwise, each output is taken + from an arbitrary shard. If the argument is a boolean, the argument's + value is used for each output. + output_shard_axes: A list of dimensions along which to concatenate the + outputs of `computation`, or `None`. `None` means "concatenate all outputs + along dimension 0". If not `None`, there must be one dimension per output. + Ignored if `outputs_from_all_shards` is False. + infeed_queue: If not `None`, the `InfeedQueue` to use to augment the inputs + of `computation`. + device_assignment: If not `None`, a `DeviceAssignment` describing the + mapping between logical cores in the computation with physical cores in + the TPU topology. Uses a default device assignment if `None`. The + `DeviceAssignment` may be omitted if each shard of the computation uses + only one core, and there is either only one shard, or the number of shards + is equal to the number of cores in the TPU system. + name: (Deprecated) Does nothing. + Returns: + A list of output tensors. + Raises: + ValueError: If num_shards <= 0 + ValueError: If len(input_shard_axes) != len(inputs) + ValueError: If len(output_shard_axes) != len(outputs from `computation`) + """ + return split_compile_and_shard( + computation, + inputs=inputs, + num_shards=num_shards, + input_shard_axes=input_shard_axes, + outputs_from_all_shards=outputs_from_all_shards, + output_shard_axes=output_shard_axes, + infeed_queue=infeed_queue, + device_assignment=device_assignment, + name=name)[1] def batch_parallel(computation, @@ -1004,6 +1372,11 @@ def rewrite(computation, All `Operation`s constructed during `computation` will be executed when evaluating any of the returned output tensors, not just the ones returned. 
inputs: A list of input tensors or `None` (equivalent to an empty list). + Each input can be a nested structure containing values that are + convertible to tensors. Note that passing an N-dimension list of + compatible values will result in an N-dimension list of scalar tensors + rather than a single Rank-N tensor. If you need different behavior, + convert part of inputs to tensors with `tf.convert_to_tensor`. infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple of arguments as inputs to `computation`. device_assignment: if not `None`, a `DeviceAssignment` describing the @@ -1012,11 +1385,15 @@ def rewrite(computation, case the core attached to task 0, TPU device 0 is used. name: (Deprecated) Does nothing. Returns: - A list of output tensors. + Same data structure as if computation(*inputs) is called directly with some + exceptions for correctness. Exceptions include: + 1) None output: a NoOp would be returned which control-depends on + computation. + 2) Single value output: A tuple containing the value would be returned. + 3) Operation-only outputs: a NoOp would be returned which + control-depends on computation. + TODO(b/121383831): Investigate into removing these special cases. """ - if inputs is not None and not isinstance(inputs, (list, tuple)): - raise TypeError("tpu.rewrite() inputs must be a list or tuple") - # TODO(b/36647078) remove disable when pylint bug is fixed.
# pylint: disable=indexing-exception return replicate( diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_context.py b/tensorflow/contrib/tpu/python/tpu/tpu_context.py index 672462447944b777375331d49727c4d5366cf295..ed1e0f0401a96c34e6ff9323685857b64e10bd14 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_context.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_context.py @@ -21,6 +21,7 @@ from __future__ import print_function from contextlib import contextmanager import copy +from tensorflow.contrib.tpu.python.tpu import _tpu_estimator_embedding from tensorflow.contrib.tpu.python.tpu import device_assignment as tpu_device_assignment from tensorflow.contrib.tpu.python.tpu import tpu_config from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_metadata_lib @@ -192,8 +193,14 @@ class _InternalTPUContext(object): ``` """ - def __init__(self, config, train_batch_size, eval_batch_size, - predict_batch_size, use_tpu, eval_on_tpu=True): + def __init__(self, + config, + train_batch_size, + eval_batch_size, + predict_batch_size, + use_tpu, + eval_on_tpu=True, + embedding_config_spec=None): self._config = config self._train_batch_size = train_batch_size self._eval_batch_size = eval_batch_size @@ -208,7 +215,7 @@ class _InternalTPUContext(object): use_tpu and config.tpu_config.num_cores_per_replica) self._mode = None num_cores_per_replica = config.tpu_config.num_cores_per_replica - if num_cores_per_replica: + if self._model_parallelism_enabled: self._computation_shape = _NUM_CORES_TO_COMPUTATION_SHAPE[ num_cores_per_replica] else: @@ -216,6 +223,8 @@ class _InternalTPUContext(object): self._lazy_tpu_system_metadata_dict = {} # key by master address self._lazy_device_assignment_dict = {} # key by master address self._lazy_validation_dict = {} # key by ModeKeys + self._embedding_config_spec = embedding_config_spec + self._lazy_embedding_config_dict = {} # key by master address def _assert_mode(self): if self._mode is None: @@ -293,6 +302,30 @@ class 
_InternalTPUContext(object): self._lazy_device_assignment_dict[master] = device_assignment return device_assignment + @property + def embedding_config(self): + """Returns the embedding config based on current mode.""" + master = self._get_master_address() + if master in self._lazy_embedding_config_dict: + embedding_config = self._lazy_embedding_config_dict[master] + else: + embedding_config = None + if self._use_tpu and self._embedding_config_spec: + embedding_config = _tpu_estimator_embedding.EmbeddingConfig( + self._embedding_config_spec, self._train_batch_size, + self._eval_batch_size, self.num_hosts, self.num_cores, master) + if not embedding_config.has_embedding_tables(): + embedding_config = None + self._lazy_embedding_config_dict[master] = embedding_config + + if embedding_config is not None: + mode = self._assert_mode() + # Dynamically attach tpu_embedding based on mode. With + # this, we could keep embedding_config immutable but call site always + # accesses the unified API '.tpu_embedding'. + embedding_config.tpu_embedding = embedding_config.get_tpu_embedding(mode) + return embedding_config + @property def model_parallelism_enabled(self): return self._model_parallelism_enabled @@ -710,11 +743,15 @@ class _OneCoreTPUContext(_InternalTPUContext): def _get_tpu_context(config, train_batch_size, eval_batch_size, - predict_batch_size, use_tpu, eval_on_tpu): + predict_batch_size, use_tpu, eval_on_tpu, + embedding_config_spec): """Returns an instance of `_InternalTPUContext`.""" if (config.tpu_config.num_shards == 1 and config.tpu_config.num_cores_per_replica is None): + if embedding_config_spec is not None: + raise ValueError('Setting TPUConfig.num_shards==1 is unsupported ' + 'when embedding_config_spec is not None.') logging.warning( 'Setting TPUConfig.num_shards==1 is an unsupported behavior. 
' 'Please fix as soon as possible (leaving num_shards as None.)') @@ -722,4 +759,5 @@ def _get_tpu_context(config, train_batch_size, eval_batch_size, predict_batch_size, use_tpu) return _InternalTPUContext(config, train_batch_size, eval_batch_size, - predict_batch_size, use_tpu, eval_on_tpu) + predict_batch_size, use_tpu, eval_on_tpu, + embedding_config_spec) diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_embedding.py b/tensorflow/contrib/tpu/python/tpu/tpu_embedding.py index ccba8a46c7cad0337119672e02314684f4451479..0e4597bd6fae500c93f74fcb1b16a39739d2310c 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_embedding.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_embedding.py @@ -28,6 +28,7 @@ from tensorflow.contrib.framework.python.framework import experimental from tensorflow.contrib.tpu.ops import gen_tpu_ops from tensorflow.contrib.tpu.proto import tpu_embedding_configuration_pb2 as elc from tensorflow.contrib.tpu.python.ops import tpu_ops +from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_metadata_lib from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor @@ -43,19 +44,6 @@ from tensorflow.python.ops import variables TRAINING = elc.TPUEmbeddingConfiguration.TRAINING INFERENCE = elc.TPUEmbeddingConfiguration.INFERENCE -# TODO(shizhiw): A better interface is to make `num_hosts` and -# `num_cores_per_host` optional parameters for `TPUEmbedding` -# constructor. Usually they can be automatically detected, but -# user can also specify them for debugging (b/112112496). -# Auto-detection can be done with `tpu_system_metadata.py`. 
-_MASTER_JOB = 'tpu_worker' -_HOST_PATTERN = '/job:tpu_worker/task:{}/device:CPU:0' -_NUM_CORES_PER_HOST = 8 - -_TEST_MASTER_JOB = None -_TEST_HOST = '/replica:0/task:0/device:CPU:0' -_TEST_NUM_CORES_PER_HOST = 2 - class TableConfig( collections.namedtuple( @@ -112,6 +100,25 @@ class TableConfig( initializer, combiner) +AdamSlotVariableNames = collections.namedtuple( + 'AdamSlotVariableNames', ['m', 'v']) + +AdagradSlotVariableName = collections.namedtuple( + 'AdagradSlotVariableName', ['accumulator']) + +AdamSlotVariables = collections.namedtuple( + 'AdamSlotVariables', ['m', 'v']) + +AdagradSlotVariable = collections.namedtuple( + 'AdagradSlotVariable', ['accumulator']) + +VariablesAndOps = collections.namedtuple( + 'VariablesAndOps', + ['embedding_variables_by_table', 'slot_variables_by_table', + 'load_ops', 'retrieve_ops'] +) + + # TODO(shizhiw): Factor `use_gradient_accumulation` and # `pipeline_execution_with_tensor_core` out of `_OptimizationParameters`. class _OptimizationParameters(object): @@ -248,6 +255,7 @@ class TPUEmbedding(object): sparse_features_list.append(sparse_features) enqueue_ops = embedding.generate_enqueue_ops(sparse_features_list) + embedding_variables_and_ops = embedding.create_variables_and_ops() def computation(): activations = embedding.get_activations() @@ -273,6 +281,7 @@ class TPUEmbedding(object): embedding.config_proto)) sess.run(variables.global_variables_initializer()) sess.run(embedding.init_ops) + sess.run(embedding_variables_and_ops.load_ops) sess.run(enqueue_ops) loss_val = sess.run(loss) ``` @@ -301,10 +310,9 @@ class TPUEmbedding(object): table_to_config_dict, feature_to_table_dict, batch_size, - num_hosts, mode, - optimization_parameters=None, - tpu_embedding_test=False): + master, + optimization_parameters=None): """API for using TPU for embedding lookups. Args: @@ -315,12 +323,11 @@ class TPUEmbedding(object): to string of table name. Feature refers to ids to lookup in embedding table, e.g. 
`sp_ids` argument to `tf.nn.embedding_lookup_sparse()`. batch_size: An `int` representing the global batch size. - num_hosts: An `int` representing the number of TPU hosts. mode: `TRAINING` or `INFERENCE`. + master: A `string` representing the TensorFlow master to use. optimization_parameters: `AdagradParameters`, `AdamParameters`, `Stochasticgradientdescentparameters`. Must be set in training and must be `None` in inference. - tpu_embedding_test: A `bool`. Only used for testing. Raises: ValueError: if any input is invalid. @@ -337,15 +344,17 @@ class TPUEmbedding(object): self._batch_size = batch_size - if tpu_embedding_test: - self._num_hosts = 1 - self._hosts = [_TEST_HOST] - self._num_cores_per_host = _TEST_NUM_CORES_PER_HOST - else: - self._num_hosts = num_hosts - self._hosts = [_HOST_PATTERN.format(i) for i in range(self._num_hosts)] - self._num_cores_per_host = _NUM_CORES_PER_HOST - self._num_cores = self._num_cores_per_host * self._num_hosts + self._master = master + self._tpu_system_metadata = ( + tpu_system_metadata_lib._query_tpu_system_metadata(self._master)) # pylint: disable=protected-access + if self._tpu_system_metadata.num_cores == 0: + raise ValueError('TPUEmbedding needs TPUs, but master {} does not have ' + 'TPUs.'.format(self._master)) + self._num_hosts = self._tpu_system_metadata.num_hosts + self._hosts = [device.name for device in self._tpu_system_metadata.devices + if 'device:CPU:' in device.name] + self._num_cores_per_host = self._tpu_system_metadata.num_of_cores_per_host + self._num_cores = self._tpu_system_metadata.num_cores _validate_batch_size(self._batch_size, self._num_cores) self._batch_size_per_core = self._batch_size // self._num_cores @@ -379,9 +388,6 @@ class TPUEmbedding(object): self._config_proto = self._create_config_proto() - self._create_variables_and_ops() - self._init_ops.extend(self._load_parameters_ops) - @property def hosts(self): """A list of device names for CPU hosts. 
@@ -389,7 +395,7 @@ class TPUEmbedding(object): Returns: A list of device names for CPU hosts. """ - return self._hosts + return copy.copy(self._hosts) # TODO(shizhiw): change to num_tensor_cores_per_host to be more explicit and # to be consistent with `tpu_embedding_configuration.proto`. @@ -447,23 +453,9 @@ class TPUEmbedding(object): """ return self._init_ops - # TODO(shizhiw): get table variables the same way as getting slot variables. @property - def table_to_table_variables_dict(self): - return copy.copy(self._table_to_table_variables_dict) - - def get_slot_names(self): - """Return a list of the names of slots created by `TPUEmbedding`.""" - return self._optimizer_handler.get_slot_names() - - def get_slot(self, table, name): - """Return a slot named `name` create for `table` by `TPUEmbedding`.""" - return self._optimizer_handler.get_slot(table, name) - - # TODO(shizhiw): expose load to user too? - @property - def retrieve_parameters_ops(self): - return self._retrieve_parameters_ops + def feature_to_table_dict(self): + return copy.copy(self._feature_to_table_dict) def _create_config_proto(self): """Create `TPUEmbeddingConfiguration`.""" @@ -495,30 +487,63 @@ class TPUEmbedding(object): return config_proto - def _create_variables_and_ops(self): - """Create embedding variables and return ops to load them into TPU.""" - self._load_parameters_ops = [] - self._retrieve_parameters_ops = [] - self._table_to_table_variables_dict = {} + def create_variables_and_ops(self, embedding_variable_name_by_table=None, + slot_variable_names_by_table=None): + """Create embedding and slot variables, with ops to load and retrieve them. + + Args: + embedding_variable_name_by_table: A dictionary mapping from string of + table name to string of embedding variable name. If `None`, + the table name itself will be used as the embedding variable name. + slot_variable_names_by_table: A dictionary mapping from string of table + name to `AdamSlotVariableNames`, `AdagradSlotVariableName` etc.
If + `None`, defaults from `get_default_slot_variable_names()` will be used. + + Returns: + `tpu_embedding.VariablesAndOps` with: + A dictionary mapping from string of table name to embedding variables, + A dictionary mapping from string of table name to AdagradSlotVariable, + AdamSlotVariables etc with slot variables, + A list of ops to load embedding and slot variables on CPU to TPU, + A list of ops to retrieve embedding and slot variables from TPU to CPU. + """ + embedding_variables_by_table = {} + slot_variables_by_table = {} + load_ops = [] + retrieve_ops = [] for table in self._table_to_config_dict: + if embedding_variable_name_by_table: + embedding_variable_name = embedding_variable_name_by_table[table] + else: + embedding_variable_name = table + if slot_variable_names_by_table: + slot_variable_names = slot_variable_names_by_table[table] + else: + slot_variable_names = ( + self._optimizer_handler.get_default_slot_variable_names(table)) + device_fn = _create_device_fn(self._hosts) with ops.device(device_fn): - # TODO(shizhiw): allow user to specify variable name so that - # they could make the name consistent with CPU etc. 
- variable_name = table table_variables = _create_partitioned_variables( - name=variable_name, + name=embedding_variable_name, num_hosts=self._num_hosts, vocabulary_size=self._table_to_config_dict[table].vocabulary_size, embedding_dimension=self._table_to_config_dict[table].dimension, initializer=self._table_to_config_dict[table].initializer, collections=[ops.GraphKeys.GLOBAL_VARIABLES]) - self._table_to_table_variables_dict[table] = table_variables - - self._optimizer_handler.create_variables_and_ops( - table, variable_name, self._num_hosts, - self._table_to_config_dict[table], table_variables, - self._load_parameters_ops, self._retrieve_parameters_ops) + embedding_variables_by_table[table] = table_variables + + slot_variables_for_table, load_ops_for_table, retrieve_ops_for_table = ( + self._optimizer_handler.create_variables_and_ops( + table, slot_variable_names, self._num_hosts, + self._table_to_config_dict[table], table_variables) + ) + slot_variables_by_table[table] = slot_variables_for_table + load_ops.extend(load_ops_for_table) + retrieve_ops.extend(retrieve_ops_for_table) + return VariablesAndOps(embedding_variables_by_table, + slot_variables_by_table, + load_ops, retrieve_ops) def _create_dummy_table_variables(self): """Create dummy embedding table variables. 
@@ -812,13 +837,11 @@ class _OptimizerHandler(object): def set_optimization_parameters(self, table_descriptor): raise NotImplementedError() - def create_variables_and_ops(self, table, variable_name): - raise NotImplementedError() - - def get_slot_names(self): + def get_default_slot_variable_names(self, table): raise NotImplementedError() - def get_slot(self, table, name): + def create_variables_and_ops(self, table, slot_variable_names, num_hosts, + table_config, table_variables): raise NotImplementedError() @@ -832,21 +855,24 @@ class _AdagradHandler(_OptimizerHandler): def set_optimization_parameters(self, table_descriptor): table_descriptor.optimization_parameters.adagrad.SetInParent() - def create_variables_and_ops(self, table, variable_name, num_hosts, - table_config, table_variables, - load_parameters_ops, retrieve_parameters_ops): - optimizer_name = 'Adagrad' + def get_default_slot_variable_names(self, table): + return AdagradSlotVariableName('{}/{}'.format(table, 'Adagrad')) + + def create_variables_and_ops(self, table, slot_variable_names, num_hosts, + table_config, table_variables): accumulator_initializer = init_ops.constant_initializer( self._optimization_parameters.initial_accumulator) accumulator_variables = _create_partitioned_variables( - name='%s/%s' % (variable_name, optimizer_name), + name=slot_variable_names.accumulator, num_hosts=num_hosts, vocabulary_size=table_config.vocabulary_size, embedding_dimension=table_config.dimension, collections=[ops.GraphKeys.GLOBAL_VARIABLES], initializer=accumulator_initializer) + slot_variables = AdagradSlotVariable(accumulator_variables) - self._table_to_accumulator_variables_dict[table] = accumulator_variables + load_ops = [] + retrieve_ops = [] for host_id, table_variable, accumulator_variable in (zip( range(num_hosts), table_variables, accumulator_variables)): with ops.colocate_with(table_variable): @@ -866,17 +892,9 @@ class _AdagradHandler(_OptimizerHandler): state_ops.assign(table_variable, 
retrieved_table), state_ops.assign(accumulator_variable, retrieved_accumulator)) - load_parameters_ops.append(load_parameters_op) - retrieve_parameters_ops.append(retrieve_parameters_op) - - def get_slot_names(self): - return ['accumulator'] - - def get_slot(self, table, name): - if name not in self.get_slot_names(): - raise ValueError('Adagrad has {} as slot names; got {}.' - .format(self.get_slot_names(), name)) - return self._table_to_accumulator_variables_dict[table] + load_ops.append(load_parameters_op) + retrieve_ops.append(retrieve_parameters_op) + return slot_variables, load_ops, retrieve_ops class _AdamHandler(_OptimizerHandler): @@ -899,13 +917,15 @@ class _AdamHandler(_OptimizerHandler): table_descriptor.optimization_parameters.adam.use_sum_inside_sqrt = ( self._optimization_parameters.sum_inside_sqrt) - def create_variables_and_ops(self, table, variable_name, num_hosts, - table_config, table_variables, - load_parameters_ops, retrieve_parameters_ops): - optimizer_name = 'Adam' + def get_default_slot_variable_names(self, table): + return AdamSlotVariableNames('{}/{}/m'.format(table, 'Adam'), + '{}/{}/v'.format(table, 'Adam')) + + def create_variables_and_ops(self, table, slot_variable_names, num_hosts, + table_config, table_variables): m_initializer = init_ops.zeros_initializer() m_variables = _create_partitioned_variables( - name='%s/%s/m' % (variable_name, optimizer_name), + name=slot_variable_names.m, num_hosts=num_hosts, vocabulary_size=table_config.vocabulary_size, embedding_dimension=table_config.dimension, @@ -913,16 +933,16 @@ class _AdamHandler(_OptimizerHandler): initializer=m_initializer) v_initializer = init_ops.zeros_initializer() v_variables = _create_partitioned_variables( - name='%s/%s/v' % (variable_name, optimizer_name), + name=slot_variable_names.v, num_hosts=num_hosts, vocabulary_size=table_config.vocabulary_size, embedding_dimension=table_config.dimension, collections=[ops.GraphKeys.GLOBAL_VARIABLES], initializer=v_initializer) + 
slot_variables = AdamSlotVariables(m_variables, v_variables) - self._table_to_m_variables_dict[table] = m_variables - self._table_to_v_variables_dict[table] = v_variables - + load_ops = [] + retrieve_ops = [] for host_id, table_variable, m_variable, v_variable in (zip( range(num_hosts), table_variables, m_variables, v_variables)): @@ -945,20 +965,9 @@ class _AdamHandler(_OptimizerHandler): state_ops.assign(m_variable, retrieved_m), state_ops.assign(v_variable, retrieved_v)) - load_parameters_ops.append(load_parameters_op) - retrieve_parameters_ops.append(retrieve_parameters_op) - - def get_slot_names(self): - return ['m', 'v'] - - def get_slot(self, table, name): - if name == 'm': - return self._table_to_m_variables_dict[table] - elif name == 'v': - return self._table_to_v_variables_dict[table] - else: - raise ValueError('Adam has {} as slot names; got {}.' - .format(self.get_slot_names(), name)) + load_ops.append(load_parameters_op) + retrieve_ops.append(retrieve_parameters_op) + return slot_variables, load_ops, retrieve_ops class _StochasticGradientDescentHandler(_OptimizerHandler): @@ -968,11 +977,15 @@ class _StochasticGradientDescentHandler(_OptimizerHandler): (table_descriptor.optimization_parameters.stochastic_gradient_descent .SetInParent()) - def create_variables_and_ops(self, table, variable_name, num_hosts, - table_config, table_variables, - load_parameters_ops, retrieve_parameters_ops): + def get_default_slot_variable_names(self, table): + return None + + def create_variables_and_ops(self, table, slot_variable_names, num_hosts, + table_config, table_variables): del table_config + load_ops = [] + retrieve_ops = [] for host_id, table_variable in (zip( range(num_hosts), table_variables)): with ops.colocate_with(table_variable): @@ -992,14 +1005,9 @@ class _StochasticGradientDescentHandler(_OptimizerHandler): retrieve_parameters_op = control_flow_ops.group( state_ops.assign(table_variable, retrieved_table)) - load_parameters_ops.append(load_parameters_op) - 
retrieve_parameters_ops.append(retrieve_parameters_op) - - def get_slot_names(self): - return [] - - def get_slot(self, table, name): - raise ValueError('Stochastic gradient descent does not have slot variable.') + load_ops.append(load_parameters_op) + retrieve_ops.append(retrieve_parameters_op) + return None, load_ops, retrieve_ops def _get_optimization_handler(optimization_parameters): @@ -1077,34 +1085,3 @@ def _create_partitioned_variables(name, initializer=initializer, collections=collections, trainable=False)) - - -@ops.RegisterGradient('TPUEmbeddingActivations') -def _embedding_activations_grad(activations_op, grad_wrt_activations): - """Saves the gradient of embedding activations ops in a graph collection.""" - g = ops.get_default_graph() - table_id = activations_op.get_attr('table_id') - lookup_id = activations_op.get_attr('lookup_id') - table_gradients = g.get_collection_ref( - 'tpu_embedding_gradients_table_%d' % table_id) - - if not table_gradients: - raise RuntimeError( - 'Gradients for TPUEmbedding have been generated in non-training mode. ' - 'This is not expected. Consider putting your Optimizer.minimize code ' - 'behind the training mode condition check. For Estimator, you can ' - 'do \n\n' - ' if mode == tf.estimator.ModeKeys.TRAIN:\n' - ' train_op = opt.minimize(loss)\n' - '\n') - - table_gradients[lookup_id] = array_ops.identity(grad_wrt_activations) - return [ - # RegisterGradient requires that value be returned for all inputs. Since - # the first argument (tpu_gradient_variable_{table_name}) has shape [1], - # we will return zeros(shape=[1]). The actual gradient w.r.t. the - # embedding activations (grad_wrt_activations) has the same shape as the - # activations returned by embedding_activations. 
- array_ops.zeros(arg.shape, dtype=dtypes.float32) - for arg in activations_op.inputs - ] diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index 7171587ff7298982423a5046d85d1970a4d6b1cb..6a3ed9bb79502505d64156c6a405d9d57ee83eb5 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -31,20 +31,27 @@ import six from six.moves import queue as Queue # pylint: disable=redefined-builtin from six.moves import xrange # pylint: disable=redefined-builtin -from tensorflow.contrib.tpu.python.tpu import tensor_tracer +from tensorflow.contrib.tpu.proto import compilation_result_pb2 as tpu_compilation_result from tensorflow.contrib.tpu.python.ops import tpu_ops +from tensorflow.contrib.tpu.python.ops import tpu_ordinal_selector_op +from tensorflow.contrib.tpu.python.tpu import _tpu_estimator_embedding from tensorflow.contrib.tpu.python.tpu import error_handling +from tensorflow.contrib.tpu.python.tpu import functional as tpu_functional from tensorflow.contrib.tpu.python.tpu import session_support +from tensorflow.contrib.tpu.python.tpu import tensor_tracer from tensorflow.contrib.tpu.python.tpu import tpu from tensorflow.contrib.tpu.python.tpu import tpu_config from tensorflow.contrib.tpu.python.tpu import tpu_context from tensorflow.contrib.tpu.python.tpu import tpu_feed from tensorflow.contrib.tpu.python.tpu import training_loop from tensorflow.contrib.tpu.python.tpu import util as util_lib +from tensorflow.contrib.tpu.python.tpu._tpu_estimator_embedding import AdamParameters # pylint: disable=unused-import +from tensorflow.contrib.tpu.python.tpu._tpu_estimator_embedding import EmbeddingConfigSpec # pylint: disable=unused-import from tensorflow.contrib.training.python.training import hparam from tensorflow.core.framework import variable_pb2 from tensorflow.core.framework.summary_pb2 import Summary from tensorflow.core.protobuf import 
config_pb2 +from tensorflow.python.client import session as tf_session from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest as data_nest from tensorflow.python.estimator import estimator as estimator_lib @@ -53,6 +60,7 @@ from tensorflow.python.estimator.export import export_output as export_output_li from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops @@ -88,6 +96,7 @@ _ONE_GIGABYTE = 1024 * 1024 * 1024 _TPU_ENQUEUE_OPS = '_tpu_enqueue_ops' _TPU_TRAIN_OP = '_tpu_train_op' _REWRITE_FOR_INFERENCE_MODE = '_rewrite_for_inference' +_KEY_WHEN_PREDICTIONS_IS_A_TENSOR = '_key_when_predictions_is_a_tensor' # Ideally _USE_TPU_KEY should be reserved as well. However there are already # models that make use of this key, thus it can not be reserved now to prevent @@ -118,6 +127,16 @@ def _is_iterable(obj): return False +class CatchInvalidHostcallFunctions(control_flow_ops.XLAControlFlowContext): + + def AddOp(self, op): + if op.type in [ + 'AudioSummary', 'AudioSummaryV2', 'HistogramSummary', 'ImageSummary', + 'MergeSummary', 'ScalarSummary', 'TensorSummary', 'TensorSummaryV2' + ]: + raise ValueError('Use tf.contrib.summary inside of host_calls.') + + def _create_global_step(graph): graph = graph or ops.get_default_graph() if training.get_global_step(graph) is not None: @@ -335,12 +354,25 @@ class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=prote hooks = None if self.host_call is not None: hooks = [_OutfeedHostCallHook(host_call_ret['host_call'])] + loss = self.loss + if tensor_tracer.TensorTracer.is_enabled() \ + and self.train_op is not None: + tt = tensor_tracer.TensorTracer() + (loss, tracing_calls) = 
tt.trace_cpu(ops.get_default_graph(), + loss, self.train_op) + tracing_call_ret = _OutfeedHostCall.create_cpu_hostcall(tracing_calls) + tracing_functions = tracing_call_ret.values() + if tracing_functions: + if hooks: + hooks.extend([_OutfeedHostCallHook(tracing_functions)]) + else: + hooks = [_OutfeedHostCallHook(tracing_functions)] hooks = tuple(hooks or []) scaffold = self.scaffold_fn() if self.scaffold_fn else None return model_fn_lib.EstimatorSpec( mode=self.mode, predictions=self.predictions, - loss=self.loss, + loss=loss, train_op=self.train_op, eval_metric_ops=eval_metric_ops, export_outputs=self.export_outputs, @@ -411,13 +443,24 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): ctx, enqueue_ops, dequeue_ops, + tpu_compile_op, run_infeed_loop_on_coordinator=True, - rendezvous=None): + rendezvous=None, + master=None, + session_config=None, + tpu_init_ops=None): self._master_job = ctx.master_job self._enqueue_ops = enqueue_ops self._dequeue_ops = dequeue_ops self._rendezvous = rendezvous - + self._master = master + self._session_config = session_config + self._init_ops = list(tpu_init_ops or []) + if ctx.embedding_config is None: + self._embedding_layer_config = None + else: + self._embedding_layer_config = ( + ctx.embedding_config.tpu_embedding.config_proto) self._run_infeed_loop_on_coordinator = run_infeed_loop_on_coordinator self._initial_infeed_sleep_secs = ( ctx.config.tpu_config.initial_infeed_sleep_secs) @@ -425,15 +468,14 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): self._feed_error = None self._finished = False self._should_initialize_tpu = True + self._tpu_compile_op = tpu_compile_op def begin(self): logging.info('TPU job name %s', self._master_job) self._iterations_per_loop_var = _create_or_get_iterations_per_loop() if self._should_initialize_tpu: - self._init_ops = [tpu.initialize_system(job=self._master_job)] self._finalize_ops = [tpu.shutdown_system(job=self._master_job)] else: - self._init_ops = 
[] self._finalize_ops = [] summary_writer_init_ops = contrib_summary.summary_writer_initializer_op() @@ -474,12 +516,34 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): def _create_infeed_controller(self, name, target, args): return _OpQueueContext(name=name, target=target, args=args) + def _assertCompilationSucceeded(self, result, coord): + proto = tpu_compilation_result.CompilationResultProto() + proto.ParseFromString(result) + if proto.status_error_message: + logging.error('Compilation failed: {}'.format(proto.status_error_message)) + coord.request_stop() + else: + logging.info('Compilation succeeded') + def after_create_session(self, session, coord): - logging.info('Init TPU system') - start = time.time() + if self._should_initialize_tpu: + logging.info('Init TPU system') + start = time.time() + with ops.Graph().as_default(): + with tf_session.Session( + self._master, config=self._session_config) as sess: + sess.run( + tpu.initialize_system( + job=self._master_job, + embedding_config=self._embedding_layer_config)) + logging.info('Initialized TPU in %d seconds', time.time() - start) + session.run(self._init_ops, options=config_pb2.RunOptions(timeout_in_ms=5 * 60 * 1000)) - logging.info('Initialized TPU in %d seconds', time.time() - start) + + if os.environ.get('TPU_SPLIT_COMPILE_AND_EXECUTE', '') == '1': + logging.info('Compiling user program: this may take a while...') + self._assertCompilationSucceeded(session.run(self._tpu_compile_op), coord) self._infeed_controller = self._create_infeed_controller( name='InfeedController', target=self._run_infeed, args=(session,)) @@ -521,13 +585,17 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): class TPUInfeedOutfeedSessionHookForPrediction(TPUInfeedOutfeedSessionHook): - def __init__(self, ctx, enqueue_ops, dequeue_ops, rendezvous=None): + def __init__(self, ctx, enqueue_ops, dequeue_ops, tpu_compile_op, + rendezvous=None, master=None, session_config=None): 
super(TPUInfeedOutfeedSessionHookForPrediction, self).__init__( ctx, enqueue_ops, dequeue_ops, + tpu_compile_op=tpu_compile_op, run_infeed_loop_on_coordinator=False, - rendezvous=rendezvous) + rendezvous=rendezvous, + master=master, + session_config=session_config) def _create_infeed_controller(self, name, target, args): return _OpSignalOnceQueueContext(name=name, target=target, args=args) @@ -809,6 +877,7 @@ def generate_per_host_v2_enqueue_ops_fn_for_host( """Generates the per_host enqueue ops.""" control_deps = [] per_host_sharded_inputs = [] + sparse_features_list = [] num_replicas_per_host = ctx.num_of_replicas_per_host cached_signals = None with ops.device(device): @@ -827,6 +896,10 @@ def generate_per_host_v2_enqueue_ops_fn_for_host( else: cached_signals = signals + features, labels, sparse_features = ( + _tpu_estimator_embedding.split_inputs(ctx, features, labels)) + sparse_features_list.append(sparse_features) + inputs_structure_recorder.validate_and_record_structure( features, labels) flattened_inputs = ( @@ -855,6 +928,11 @@ def generate_per_host_v2_enqueue_ops_fn_for_host( tpu_ordinal_function=tpu_ordinal_function_impl) captured_infeed_queue.capture(infeed_queue) + if ctx.embedding_config: + per_host_enqueue_ops.extend( + ctx.embedding_config.tpu_embedding.generate_enqueue_ops( + sparse_features_list)) + if signals is None: return per_host_enqueue_ops else: @@ -1264,6 +1342,44 @@ class _InputPipeline(object): logging.warn(err_msg) +def call_computation(computation, + experimental_exported_model_uses_all_cores=True): + """Call computation. + + computation uses a single-core for TPU inference. If + `experimental_exported_model_uses_all_cores` is `True`, this function will + round-robin + computation among all TPU cores visible to the host; otherwise, it will use + a single core. + + Args: + computation: A Python function that takes no inputs and builds computation + graph. 
If `computation` returns m outputs, this function will return a + list of m Tensors. + experimental_exported_model_uses_all_cores: Whether to round-robin among all + cores visible to the host, or to use a single core. + + Returns: + A list of output tensors. + """ + if experimental_exported_model_uses_all_cores: + # Using `TPUPartitionedCall` makes it possible to target a different + # TPU core with every `Session.run()` call. Note that the entire inference + # graph executes on a single core, and that invocations of this graph + # will round-robin among the cores attached to a host. + @function.Defun() + def tpu_subgraph(): + return computation() + + return tpu_functional.TPUPartitionedCall( + args=tpu_subgraph.captured_inputs, + device_ordinal=tpu_ordinal_selector_op.tpu_ordinal_selector(), + Tout=[o.type for o in tpu_subgraph.definition.signature.output_arg], + f=tpu_subgraph) + else: + return computation() + + class _ModelFnWrapper(object): """A `model_fn` wrapper. @@ -1283,6 +1399,12 @@ class _ModelFnWrapper(object): def call_without_tpu(self, features, labels, is_export_mode): return self._call_model_fn(features, labels, is_export_mode=is_export_mode) + def _add_embedding_features(self, features): + if self._ctx.embedding_config: + tpu_embedding_ = self._ctx.embedding_config.tpu_embedding + embedding_activations = tpu_embedding_.get_activations() + features.update(embedding_activations) + def convert_to_single_tpu_train_step(self, dequeue_fn): """Converts user provided model_fn` as a single train step on TPU. @@ -1315,6 +1437,7 @@ class _ModelFnWrapper(object): del loss # unused; required in function signature. 
inputs = dequeue_fn() features, labels = inputs.features_and_labels() + self._add_embedding_features(features) estimator_spec = self._verify_estimator_spec( self._call_model_fn(features, labels)) @@ -1330,12 +1453,22 @@ class _ModelFnWrapper(object): tracing_ops = [] if tensor_tracer.TensorTracer.is_enabled(): tt = tensor_tracer.TensorTracer() - loss, tracing_ops = tt.trace_tpu(ops.get_default_graph(), loss, - self._ctx.num_replicas) + loss, tracing_ops = tt.trace_tpu(ops.get_default_graph(), + loss, train_op, + self._ctx.num_replicas, + self._ctx.num_of_replicas_per_host, + self._ctx.num_hosts) + + if self._ctx.embedding_config is None: + apply_sparse_grads = [] + else: + tpu_embedding_ = self._ctx.embedding_config.tpu_embedding + apply_sparse_grads = [tpu_embedding_.generate_send_gradients_op()] # We must run train_op to update the variables prior to running the # outfeed. - with ops.control_dependencies([train_op]+tracing_ops): + with ops.control_dependencies([train_op] + tracing_ops + + apply_sparse_grads): host_call_outfeed_ops = [] if (isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec) # pylint: disable=protected-access and estimator_spec.host_call is not None): @@ -1381,6 +1514,7 @@ class _ModelFnWrapper(object): """Evaluation step function for use inside a while loop.""" inputs = dequeue_fn() features, labels = inputs.features_and_labels() + self._add_embedding_features(features) tpu_estimator_spec = self._call_model_fn(features, labels) if not isinstance(tpu_estimator_spec, model_fn_lib._TPUEstimatorSpec): # pylint: disable=protected-access @@ -1633,7 +1767,7 @@ class _OutfeedHostCall(object): 'Exception while calling %s: %s. It is likely the tensors ' '(%s[1]) do not match the ' 'function\'s arguments', name, e, name) - raise e + raise return ret def record(self, host_calls): @@ -1720,6 +1854,10 @@ class _OutfeedHostCall(object): dequeue_ops[j].append(item) # Deconstruct dequeue ops. 
+ flat_dequeue_ops = [] + for l in dequeue_ops: + flat_dequeue_ops.extend(l) + dequeue_ops_by_name = {} pos = 0 for name in self._names: @@ -1727,6 +1865,14 @@ class _OutfeedHostCall(object): len(self._tensors[name])] pos += len(self._tensors[name]) + def _call_host_fn(fn, *args, **kw): + context = CatchInvalidHostcallFunctions() + context.Enter() + result = fn(*args, **kw) + context.Exit() + context.ExitResult(result) + return result + # It is assumed evaluation always happens on single host TPU system. So, # place all ops on tpu host if possible. # @@ -1739,24 +1885,39 @@ class _OutfeedHostCall(object): raise RuntimeError( 'All tensors outfed from TPU should preserve batch size ' 'dimension, but got scalar {}'.format(dequeue_ops[i][0])) - # TODO(xiejw): Allow users to specify the axis for batch size - # dimension. - dequeue_ops[i] = array_ops.concat(dequeue_ops[i], axis=0) + # TODO(xiejw): Make the specification of the outfeed combinaton + # function more explicit and well-documented. We may want to give the + # user the option of concatenating along any axis. + if (self._ctx.config.tpu_config.per_host_input_for_training is + tpu_config.InputPipelineConfig.BROADCAST): + # If the infeed is in BROADCAST mode (each core recieving the same + # input), then we assume that the cores also produce identical + # copies of the same output, and we simply take the output from + # the first core. This mode is used by Mesh-TensorFlow. + with ops.control_dependencies(dequeue_ops[i]): + dequeue_ops[i] = array_ops.identity(dequeue_ops[i][0]) + else: + # Assume that the input has been batch-split and that axis 0 of the + # output tensors represents the batch size. Concatenate along + # the axis 0 to re-combine the batch. + dequeue_ops[i] = array_ops.concat(dequeue_ops[i], axis=0) if self._tensor_keys[name] is not None: # The user-provided eval_metrics[1] is a dict. 
dequeue_ops = dict(zip(self._tensor_keys[name], dequeue_ops)) try: - ret[name] = self._host_fns[name](**dequeue_ops) + ret[name] = _call_host_fn(self._host_fns[name], **dequeue_ops) except TypeError as e: logging.warning( 'Exception while calling %s: %s. It is likely the tensors ' '(%s[1]) do not match the ' 'function\'s arguments', name, e, name) - raise e + raise else: - ret[name] = self._host_fns[name](*dequeue_ops) + ret[name] = _call_host_fn(self._host_fns[name], *dequeue_ops) + # force all dequeue operations to be run if not consumed by the host calls + ret['__force_dequeue'] = control_flow_ops.group(*flat_dequeue_ops) return ret @@ -2048,7 +2209,11 @@ class TPUEstimator(estimator_lib.Estimator): batch_axis=None, eval_on_tpu=True, export_to_tpu=True, - warm_start_from=None): + export_to_cpu=True, + warm_start_from=None, + experimental_exported_model_uses_all_cores=False, + experimental_export_device_assignment=False, + experimental_embedding_config_spec=None): """Constructs an `TPUEstimator` instance. Args: @@ -2091,12 +2256,29 @@ class TPUEstimator(estimator_lib.Estimator): eval_on_tpu: If False, evaluation runs on CPU or GPU. In this case, the model_fn must return `EstimatorSpec` when called with `mode` as `EVAL`. export_to_tpu: If True, `export_savedmodel()` exports a metagraph for - serving on TPU besides the one on CPU. + serving on TPU. Note that unsupported export modes such as EVAL will be + ignored. For those modes, only a CPU model will be exported. + Currently, export_to_tpu only supports PREDICT. + export_to_cpu: If True, `export_savedmodel()` exports a metagraph for + serving on CPU. warm_start_from: Optional string filepath to a checkpoint or SavedModel to warm-start from, or a `tf.estimator.WarmStartSettings` object to fully configure warm-starting. If the string filepath is provided instead of a `WarmStartSettings`, then all variables are warm-started, and it is assumed that vocabularies and Tensor names are unchanged. 
+ experimental_exported_model_uses_all_cores: Whether to round-robin among + all cores visible to the host which is serving the saved model, or to + use a single core. This is a temporary flag to enable using all TPU + cores for inference with TPUPartitionedCall(). Once outside compilation + is supported in TPUPartitionedCall(), this flag will be enabled by + default. + experimental_export_device_assignment: Whether to include the device + assignment in the exported model. Doing so is useful in case of model + parallel inference but will tie the exported model to the TPU topology + used to export the model. + experimental_embedding_config_spec: Optional EmbeddingConfigSpec instance + to support using TPU embedding. IT IS STILL WORK IN PROGRESS, SO PLEASE + DO NOT USE. Raises: ValueError: `params` has reserved keys already. @@ -2158,9 +2340,19 @@ class TPUEstimator(estimator_lib.Estimator): # pylint: disable=protected-access self._ctx = tpu_context._get_tpu_context( self._config, train_batch_size, eval_batch_size, predict_batch_size, - use_tpu, eval_on_tpu) + use_tpu, eval_on_tpu, experimental_embedding_config_spec) + self._export_to_cpu = export_to_cpu self._export_to_tpu = export_to_tpu + self._experimental_exported_model_uses_all_cores = ( + experimental_exported_model_uses_all_cores) + self._experimental_export_device_assignment = ( + experimental_export_device_assignment) + if (experimental_exported_model_uses_all_cores and + experimental_export_device_assignment): + raise ValueError('experimental_exported_model_uses_all_cores and ' + 'experimental_export_device_assignment is not supported ' + 'at the same time.') self._is_input_fn_invoked = None self._rendezvous = {} @@ -2174,35 +2366,43 @@ class TPUEstimator(estimator_lib.Estimator): export_tags=None, check_variables=True): if self._export_to_tpu and mode != model_fn_lib.ModeKeys.PREDICT: - raise NotImplementedError( - 'TPUEstimator only handles mode PREDICT for exporting ' - 'when `export_to_tpu` is `True`; 
' - 'got {}.'.format(mode)) - - (super(TPUEstimator, self)._add_meta_graph_for_mode( - builder, - input_receiver_fn_map, - checkpoint_path, - save_variables, - mode=mode, - export_tags=export_tags, - check_variables=check_variables)) + logging.warning('TPUEstimator only handles mode PREDICT for exporting ' + 'when `export_to_tpu` is `True`; Mode {} will be ignored ' + 'for TPU.'.format(mode)) + + if not self._export_to_cpu and not self._export_to_tpu: + raise ValueError('One of export_to_cpu and export_to_tpu must be true.') - if self._export_to_tpu: + if self._export_to_cpu: + (super(TPUEstimator, self)._add_meta_graph_for_mode( + builder, + input_receiver_fn_map, + checkpoint_path, + save_variables, + mode=mode, + export_tags=export_tags, + check_variables=check_variables)) + + if self._export_to_tpu and mode == model_fn_lib.ModeKeys.PREDICT: input_receiver_fn_map = { _REWRITE_FOR_INFERENCE_MODE: input_receiver_fn_map[mode] } export_tags = [tag_constants.SERVING, tag_constants.TPU] mode = _REWRITE_FOR_INFERENCE_MODE + # See b/110052256 for why `check_variables` is `False`. 
+ if not self._export_to_cpu: + check_variables = save_variables = True + else: + check_variables = save_variables = False (super(TPUEstimator, self)._add_meta_graph_for_mode( builder, input_receiver_fn_map, checkpoint_path, - save_variables=False, + save_variables=save_variables, mode=mode, export_tags=export_tags, - check_variables=False)) + check_variables=check_variables)) def _call_model_fn(self, features, labels, mode, config): if mode == _REWRITE_FOR_INFERENCE_MODE: @@ -2217,6 +2417,88 @@ class TPUEstimator(estimator_lib.Estimator): raise ValueError('mode must be {}; ' 'got {}.'.format(_REWRITE_FOR_INFERENCE_MODE, mode)) + computation, capture = self._build_computation_for_inference( + features, labels, mode, config) + tensors = call_computation( + computation, + experimental_exported_model_uses_all_cores=self + ._experimental_exported_model_uses_all_cores) + estimator_spec, export_outputs_dict, predictions_dict, none_indices = ( + capture.get()) + predictions_list = tensors[:len(predictions_dict)] + export_outputs_list_without_none = tensors[len(predictions_dict):] + + # Reinsert `None`s which we've taken out in + # `_build_computation_for_inference()`. + export_outputs_list = [] + while none_indices or export_outputs_list_without_none: + if none_indices and none_indices[0] == len(export_outputs_list): + export_outputs_list.append(None) + none_indices.pop(0) + else: + export_outputs_list.append(export_outputs_list_without_none.pop(0)) + + # Reconstruct `export_outputs` with updated tensors. + new_export_outputs_dict = nest.pack_sequence_as(export_outputs_dict, + export_outputs_list) + export_outputs = estimator_spec.export_outputs + new_export_outputs = collections.OrderedDict( + (k, _clone_export_output_with_tensors(export_outputs[k], v)) + for k, v in six.iteritems(new_export_outputs_dict)) + # Reconstruct `predictions` with updated tensors. 
+ new_predictions = nest.pack_sequence_as(predictions_dict, predictions_list) + if (len(new_predictions) == 1 and + _KEY_WHEN_PREDICTIONS_IS_A_TENSOR in new_predictions): + new_predictions = new_predictions[_KEY_WHEN_PREDICTIONS_IS_A_TENSOR] + + return estimator_spec._replace( + export_outputs=new_export_outputs, predictions=new_predictions) + + def _build_computation_for_inference(self, features, labels, mode, config): + capture = _CapturedObject() + + def computation(): + """Computation to be passed to `TPUPartitionedCall()`.""" + tpu_computation, tpu_capture = self._build_tpu_computation_for_inference( + features, labels, mode, config) + + if self._experimental_export_device_assignment: + # Export the device assignment as part of the model. This is useful for + # model parallel usecases where the model relies on the mapping between + # logical and physical devices. + with self._ctx.with_mode(mode) as ctx: + device_assignment = ctx.device_assignment + else: + device_assignment = None + tensors_on_cpu = tpu.rewrite_for_inference( + tpu_computation, device_assignment=device_assignment) + (estimator_spec, export_outputs_dict, export_outputs_list, + predictions_dict) = ( + tpu_capture.get()) + predictions_list = tensors_on_cpu[:len(predictions_dict)] + export_outputs_tpu_on_cpu_list = tensors_on_cpu[len(predictions_dict):] + + # Reconstruct tensors used in export_outputs, with TPU tensors replaced + # with their CPU counterpart returned from `rewrite_for_inference()`. + # `function.Defun()` does not like `None`s in return values, so we leave + # `None`s out but record their positions for later reconstruction. 
+ export_outputs_list_without_none = [] + none_indices = [] + for i, t in enumerate(export_outputs_list): + if t is None: + none_indices.append(i) + else: + export_outputs_list_without_none.append( + export_outputs_tpu_on_cpu_list.pop(0)) + + capture.capture((estimator_spec, export_outputs_dict, predictions_dict, + none_indices)) + return predictions_list + export_outputs_list_without_none + + return computation, capture + + def _build_tpu_computation_for_inference(self, features, labels, mode, + config): capture = _CapturedObject() def computation(): @@ -2237,46 +2519,30 @@ class TPUEstimator(estimator_lib.Estimator): # We pick the TPU tensors out from `export_output` and later return them # from `computation` for rewriting. - tensors_dict = collections.OrderedDict( + export_outputs_dict = collections.OrderedDict( (k, _export_output_to_tensors(v)) for k, v in six.iteritems(estimator_spec.export_outputs)) - tensors = nest.flatten(tensors_dict) - tpu_tensors = [t for t in tensors if _is_tpu_tensor(t)] - - # We cannot return anything other than `tpu_tensors` here so we capture - # the rest for later use. - capture.capture((estimator_spec, tensors_dict, tensors)) - return tpu_tensors - - tpu_tensors_on_cpu = tpu.rewrite_for_inference(computation) - estimator_spec, tensors_dict, tensors = capture.get() - - # Reconstruct `tensors`, but with `tpu_tensors` replaced with - # `tpu_tensors_on_cpu`. 
- new_tensors = [] - for t in tensors: - if _is_tpu_tensor(t): - new_tensors.append(tpu_tensors_on_cpu.pop(0)) - elif t is None: - new_tensors.append(None) + export_outputs_list = nest.flatten(export_outputs_dict) + export_outputs_tpu_list = [ + t for t in export_outputs_list if t is not None + ] + + if isinstance(estimator_spec.predictions, dict): + predictions_dict = collections.OrderedDict( + (k, v) for k, v in six.iteritems(estimator_spec.predictions)) else: - # Only fetching `tpu_tensors_on_cpu` does not trigger - # TPU computation and blocks, so we add the control dependency here. - control_inputs = ( - tpu_tensors_on_cpu if _is_iterable(tpu_tensors_on_cpu) else - (tpu_tensors_on_cpu,)) - with ops.control_dependencies(control_inputs): - new_tensors.append(array_ops.identity(t)) - - # Reconstruct `tensors_dict`. - new_tensors_dict = nest.pack_sequence_as(tensors_dict, new_tensors) - # Reconstruct `export_outputs`. - export_outputs = estimator_spec.export_outputs - new_export_outputs = collections.OrderedDict( - (k, _clone_export_output_with_tensors(export_outputs[k], v)) - for k, v in six.iteritems(new_tensors_dict)) + predictions_dict = { + _KEY_WHEN_PREDICTIONS_IS_A_TENSOR: estimator_spec.predictions + } + predictions_list = nest.flatten(predictions_dict) + + # We cannot return everything we want through the return values, so + # capture the rest here for later use. + capture.capture((estimator_spec, export_outputs_dict, export_outputs_list, + predictions_dict)) + return predictions_list + export_outputs_tpu_list - return estimator_spec._replace(export_outputs=new_export_outputs) + return computation, capture def _create_global_step(self, graph): """Creates a global step suitable for TPUs. 
@@ -2494,7 +2760,11 @@ class TPUEstimator(estimator_lib.Estimator): if self._log_every_n_steps is not None: examples_hook = ExamplesPerSecondHook( ctx.global_batch_size, - output_dir=self.model_dir, + # pylint:disable=g-long-ternary + output_dir=(self.model_dir + if not config or config.save_summary_steps + else None), + # pylint:enable=g-long-ternary every_n_steps=self._log_every_n_steps) if ctx.is_running_on_cpu(is_export_mode=is_export_mode): @@ -2511,6 +2781,13 @@ class TPUEstimator(estimator_lib.Estimator): assert callable(features), '`input_fn` is not callable.' input_fn = features + tpu_init_ops = [] + if ctx.embedding_config: + tpu_init_ops.extend(ctx.embedding_config.tpu_embedding.init_ops) + embedding_variables_and_ops = ( + ctx.embedding_config.tpu_embedding.create_variables_and_ops()) + tpu_init_ops.extend(embedding_variables_and_ops.load_ops) + input_holders = _InputPipeline(input_fn, batch_axis, ctx) enqueue_ops, dequeue_fn, input_hooks, run_infeed_loop_on_coordinator = ( input_holders.generate_infeed_enqueue_ops_and_dequeue_fn()) @@ -2523,7 +2800,7 @@ class TPUEstimator(estimator_lib.Estimator): graph.add_to_collection(_TPU_ENQUEUE_OPS, enqueue_op) if mode == model_fn_lib.ModeKeys.TRAIN: - loss, host_call, scaffold, training_hooks = ( + compile_op, loss, host_call, scaffold, training_hooks = ( _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn)) host_ops = host_call.create_tpu_hostcall() if host_ops is None: @@ -2558,10 +2835,13 @@ class TPUEstimator(estimator_lib.Estimator): ctx, enqueue_ops, host_ops, + tpu_compile_op=compile_op, run_infeed_loop_on_coordinator=( run_infeed_loop_on_coordinator), rendezvous=self._rendezvous[mode], - ), + master=self._config.master, + session_config=self._session_config, + tpu_init_ops=tpu_init_ops), InstallSignalHandlerHook() ]) if self._log_every_n_steps is not None: @@ -2598,6 +2878,9 @@ class TPUEstimator(estimator_lib.Estimator): with ops.control_dependencies([loss]): update_ops = _sync_variables_ops(ctx) + 
if ctx.embedding_config: + update_ops.extend(embedding_variables_and_ops.retrieve_ops) + # Validate the TPU training graph to catch basic errors _validate_tpu_training_graph() @@ -2613,8 +2896,8 @@ class TPUEstimator(estimator_lib.Estimator): scaffold=scaffold) if mode == model_fn_lib.ModeKeys.EVAL: - total_loss, host_calls, scaffold, eval_hooks = _eval_on_tpu_system( - ctx, model_fn_wrapper, dequeue_fn) + compile_op, total_loss, host_calls, scaffold, eval_hooks = ( + _eval_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn)) iterations_per_loop_var = _create_or_get_iterations_per_loop() mean_loss = math_ops.div( total_loss, @@ -2661,9 +2944,13 @@ class TPUEstimator(estimator_lib.Estimator): ctx, enqueue_ops, eval_update_ops + host_ops, + tpu_compile_op=compile_op, run_infeed_loop_on_coordinator=( run_infeed_loop_on_coordinator), - rendezvous=self._rendezvous[mode]), + rendezvous=self._rendezvous[mode], + master=self._config.evaluation_master, + session_config=self._session_config, + tpu_init_ops=tpu_init_ops) ] + input_hooks if eval_hooks: @@ -2679,7 +2966,7 @@ class TPUEstimator(estimator_lib.Estimator): # Predict assert mode == model_fn_lib.ModeKeys.PREDICT - (dummy_predict_op, host_calls, + (compile_op, dummy_predict_op, host_calls, scaffold, prediction_hooks) = _predict_on_tpu_system( ctx, model_fn_wrapper, dequeue_fn) with ops.control_dependencies([dummy_predict_op]): @@ -2735,7 +3022,10 @@ class TPUEstimator(estimator_lib.Estimator): hooks = [ _StoppingPredictHook(scalar_stopping_signal), TPUInfeedOutfeedSessionHookForPrediction( - ctx, enqueue_ops, host_ops, rendezvous=self._rendezvous[mode]), + ctx, enqueue_ops, host_ops, rendezvous=self._rendezvous[mode], + tpu_compile_op=compile_op, + master=self._config.master, + session_config=self._session_config), ] + input_hooks if prediction_hooks: @@ -2750,17 +3040,6 @@ class TPUEstimator(estimator_lib.Estimator): return _model_fn -def _is_tpu_tensor(tensor): - if not isinstance(tensor, ops.Tensor): - return False - 
try: - tensor.op.get_attr(tpu._OUTSIDE_COMPILATION_ATTR) # pylint: disable=protected-access - except ValueError: - return True - else: - return False - - def _export_output_to_tensors(export_output): """Get a list of `Tensors` used in `export_output`. @@ -2832,15 +3111,16 @@ def _eval_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): return training_loop.repeat(iterations_per_loop_var, single_tpu_eval_step, [_ZERO_LOSS]) - (loss,) = tpu.shard( + (compile_op, loss,) = tpu.split_compile_and_shard( multi_tpu_eval_steps_on_single_shard, inputs=[], num_shards=ctx.num_replicas, outputs_from_all_shards=False, device_assignment=ctx.device_assignment) + loss = loss[0] scaffold = _get_scaffold(captured_scaffold_fn) - return loss, host_calls, scaffold, captured_eval_hooks.get() + return compile_op, loss, host_calls, scaffold, captured_eval_hooks.get() def _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): @@ -2855,15 +3135,16 @@ def _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): return training_loop.repeat(iterations_per_loop_var, single_tpu_train_step, [_INITIAL_LOSS]) - (loss,) = tpu.shard( + (compile_op, loss,) = tpu.split_compile_and_shard( multi_tpu_train_steps_on_single_shard, inputs=[], num_shards=ctx.num_replicas, outputs_from_all_shards=False, device_assignment=ctx.device_assignment) + loss = loss[0] scaffold = _get_scaffold(captured_scaffold_fn) - return loss, host_call, scaffold, captured_training_hooks.get() + return compile_op, loss, host_call, scaffold, captured_training_hooks.get() def _predict_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): @@ -2883,15 +3164,17 @@ def _predict_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): cond, single_tpu_predict_step, inputs=inputs, name=b'loop') return outputs - (dummy_predict_op,) = tpu.shard( + (compile_op, dummy_predict_op,) = tpu.split_compile_and_shard( multi_tpu_predict_steps_on_single_shard, inputs=[], num_shards=ctx.num_replicas, outputs_from_all_shards=False, 
device_assignment=ctx.device_assignment) + dummy_predict_op = dummy_predict_op[0] scaffold = _get_scaffold(captured_scaffold_fn) - return dummy_predict_op, host_calls, scaffold, captured_predict_hooks.get() + return (compile_op, dummy_predict_op, host_calls, scaffold, + captured_predict_hooks.get()) def _wrap_computation_in_while_loop(device, op_fn): @@ -3081,7 +3364,7 @@ class _Inputs(object): The initializer must be run before calling `features_and_labels`. """ - self._iterator = self._dataset.make_initializable_iterator() + self._iterator = dataset_ops.make_initializable_iterator(self._dataset) return self._iterator.initializer def features_and_labels(self): diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator_signals_test.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator_signals_test.py index 3786e52b949dfac8c1587d1ea3041b625f00183f..e3ea983abfd24d03c964fbc647b56262e15e0a96 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator_signals_test.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator_signals_test.py @@ -21,8 +21,8 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.tpu.python.tpu import tpu_estimator -from tensorflow.python import data as dataset_lib from tensorflow.python.client import session +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.platform import test @@ -34,10 +34,10 @@ def make_input_fn(num_samples): def input_fn(params): batch_size = params['batch_size'] - da1 = dataset_lib.Dataset.from_tensor_slices(a) - da2 = dataset_lib.Dataset.from_tensor_slices(b) + da1 = dataset_ops.Dataset.from_tensor_slices(a) + da2 = dataset_ops.Dataset.from_tensor_slices(b) - dataset = dataset_lib.Dataset.zip((da1, da2)) + dataset = dataset_ops.Dataset.zip((da1, da2)) dataset = dataset.map(lambda fa, fb: {'a': fa, 'b': fb}) dataset = dataset.batch(batch_size) return dataset @@ -50,10 +50,10 @@ 
def make_input_fn_with_labels(num_samples): def input_fn(params): batch_size = params['batch_size'] - da1 = dataset_lib.Dataset.from_tensor_slices(a) - da2 = dataset_lib.Dataset.from_tensor_slices(b) + da1 = dataset_ops.Dataset.from_tensor_slices(a) + da2 = dataset_ops.Dataset.from_tensor_slices(b) - dataset = dataset_lib.Dataset.zip((da1, da2)) + dataset = dataset_ops.Dataset.zip((da1, da2)) dataset = dataset.map(lambda fa, fb: ({'a': fa}, fb)) dataset = dataset.batch(batch_size) return dataset @@ -71,7 +71,7 @@ class TPUEstimatorStoppingSignalsTest(test.TestCase): with ops.Graph().as_default(): dataset = input_fn(params) - features = dataset.make_one_shot_iterator().get_next() + features = dataset_ops.make_one_shot_iterator(dataset).get_next() # With tf.data.Dataset.batch, the batch is None, i.e., dynamic shape. self.assertIsNone(features['a'].shape.as_list()[0]) diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py b/tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py index ec682e5829c4df536a043334b74200f0b6259df3..d66ecfcf4a56b8da1c2d2f518bebe4baa76b315e 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py @@ -52,6 +52,7 @@ def _query_tpu_system_metadata(master_address, cluster_def=None, devices = [] device_dict = collections.defaultdict(list) + # TODO(b/120564445): Replace with standard library for retries. 
retry_count = 1 while True: logging.info('Querying Tensorflow master (%s) for TPU system metadata.', diff --git a/tensorflow/contrib/tpu/tpu_estimator.md b/tensorflow/contrib/tpu/tpu_estimator.md index b6514e19dc92fe4c7cdcdb6582a7c0ad5ad573d5..552febd80bd35b37a95cdaaf8d5923278311ac8e 100644 --- a/tensorflow/contrib/tpu/tpu_estimator.md +++ b/tensorflow/contrib/tpu/tpu_estimator.md @@ -89,12 +89,9 @@ handle training: dataset = tf.data.TFRecordDataset( filename, buffer_size=FLAGS.dataset_reader_buffer_size) - dataset = dataset.map(parser).cache().repeat().batch(batch_size) - images, labels = dataset.make_one_shot_iterator().get_next() - # set_shape to give inputs statically known shapes. - images.set_shape([batch_size, 28 * 28]) - labels.set_shape([batch_size]) - return images, labels + dataset = dataset.map(parser).cache().repeat().batch( + batch_size, drop_remainder=True) + return dataset return input_fn diff --git a/tensorflow/contrib/tpu/utils/tpu_embedding_optimization_parameters_utils.cc b/tensorflow/contrib/tpu/utils/tpu_embedding_optimization_parameters_utils.cc index 76cb5531cd0bc3a375d1434c31fa14a9d7f42476..d98e0b7a5ed52c00a8cf2b1a1bbc53f1b1cd28c7 100644 --- a/tensorflow/contrib/tpu/utils/tpu_embedding_optimization_parameters_utils.cc +++ b/tensorflow/contrib/tpu/utils/tpu_embedding_optimization_parameters_utils.cc @@ -134,12 +134,16 @@ Status GetGradientAccumulationSupport(OptimizationAlgorithm alg, } } namespace { -// Make a normal state variable specification. +// Make a normal state variable specification. Please refer to +// //third_party/tensorflow/contrib/tpu/proto/optimization_parameters.proto +// (StateVariableSpecification message) for instructions on how to set the +// padding_initial_value field. 
StateVariableSpecification MakeStandardStateVariableSpecification( - const string& name) { + const string& name, double padding_initial_value) { StateVariableSpecification result; result.set_name(name); - result.mutable_user_defined(); + result.mutable_user_defined()->set_padding_initial_value( + padding_initial_value); return result; } } // namespace @@ -149,14 +153,14 @@ Status GetOptimizationAlgorithmStateVariables( std::vector* state_variables) { // The first parameter set is always the weights themselves. state_variables->push_back( - MakeStandardStateVariableSpecification("parameters")); + MakeStandardStateVariableSpecification("parameters", 0.0)); // The order of the returned parameters needs to match the offsets used by // the algorithm implementations in test_util.cc and // address_handler_program_creator.cc. switch (alg) { case OptimizationAlgorithm::kAdagrad: { state_variables->push_back( - MakeStandardStateVariableSpecification("accumulators")); + MakeStandardStateVariableSpecification("accumulators", 0.1)); break; } case OptimizationAlgorithm::kStochasticGradientDescent: { @@ -165,53 +169,58 @@ Status GetOptimizationAlgorithmStateVariables( } case OptimizationAlgorithm::kFtrl: { state_variables->push_back( - MakeStandardStateVariableSpecification("accumulators")); + MakeStandardStateVariableSpecification("accumulators", 0.1)); state_variables->push_back( - MakeStandardStateVariableSpecification("linears")); + MakeStandardStateVariableSpecification("linears", 0.0)); break; } case OptimizationAlgorithm::kAdam: { state_variables->push_back( - MakeStandardStateVariableSpecification("momenta")); + MakeStandardStateVariableSpecification("momenta", 0.0)); state_variables->push_back( - MakeStandardStateVariableSpecification("velocities")); + MakeStandardStateVariableSpecification("velocities", 0.0)); break; } case OptimizationAlgorithm::kMomentum: { state_variables->push_back( - MakeStandardStateVariableSpecification("momenta")); + 
MakeStandardStateVariableSpecification("momenta", 0.0)); break; } case OptimizationAlgorithm::kRmsProp: { - state_variables->push_back(MakeStandardStateVariableSpecification("ms")); - state_variables->push_back(MakeStandardStateVariableSpecification("mom")); + state_variables->push_back( + MakeStandardStateVariableSpecification("ms", 1.0)); + state_variables->push_back( + MakeStandardStateVariableSpecification("mom", 0.0)); break; } case OptimizationAlgorithm::kCenteredRmsProp: { - state_variables->push_back(MakeStandardStateVariableSpecification("ms")); - state_variables->push_back(MakeStandardStateVariableSpecification("mom")); - state_variables->push_back(MakeStandardStateVariableSpecification("mg")); + state_variables->push_back( + MakeStandardStateVariableSpecification("ms", 1.0)); + state_variables->push_back( + MakeStandardStateVariableSpecification("mom", 0.0)); + state_variables->push_back( + MakeStandardStateVariableSpecification("mg", 0.0)); break; } case OptimizationAlgorithm::kMdlAdagradLight: { state_variables->push_back( - MakeStandardStateVariableSpecification("accumulators")); + MakeStandardStateVariableSpecification("accumulators", 0.1)); state_variables->push_back( - MakeStandardStateVariableSpecification("weights")); + MakeStandardStateVariableSpecification("weights", 0.0)); state_variables->push_back( - MakeStandardStateVariableSpecification("benefits")); + MakeStandardStateVariableSpecification("benefits", 0.0)); break; } case OptimizationAlgorithm::kAdadelta: { state_variables->push_back( - MakeStandardStateVariableSpecification("accumulators")); + MakeStandardStateVariableSpecification("accumulators", 0.0)); state_variables->push_back( - MakeStandardStateVariableSpecification("updates")); + MakeStandardStateVariableSpecification("updates", 0.0)); break; } case OptimizationAlgorithm::kProximalAdagrad: { state_variables->push_back( - MakeStandardStateVariableSpecification("accumulators")); + 
MakeStandardStateVariableSpecification("accumulators", 0.1)); break; } case OptimizationAlgorithm::PARAMETERS_NOT_SET: { diff --git a/tensorflow/contrib/training/BUILD b/tensorflow/contrib/training/BUILD index 00295f57f60858db5234ce28cc643ea9eee44daa..5bc4c3b88efd641b6f17a54753a29b0603c2b98c 100644 --- a/tensorflow/contrib/training/BUILD +++ b/tensorflow/contrib/training/BUILD @@ -26,7 +26,6 @@ py_library( "python/training/resample.py", "python/training/sampling_ops.py", "python/training/sequence_queueing_state_saver.py", - "python/training/tensor_queue_dataset.py", "python/training/training.py", "python/training/tuner.py", ], @@ -265,9 +264,9 @@ py_test( py_test( name = "training_test", - size = "large", + size = "medium", srcs = ["python/training/training_test.py"], - shard_count = 3, + shard_count = 8, srcs_version = "PY2AND3", tags = ["notsan"], deps = [ @@ -287,28 +286,6 @@ py_test( ], ) -py_test( - name = "tensor_queue_dataset_test", - size = "large", - srcs = ["python/training/tensor_queue_dataset_test.py"], - srcs_version = "PY2AND3", - tags = ["notsan"], - deps = [ - ":training_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:gradients", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform", - "//tensorflow/python:random_seed", - "//tensorflow/python:training", - "//tensorflow/python:variables", - "//tensorflow/python/data", - "//tensorflow/python/data/experimental/kernel_tests/serialization:dataset_serialization_test_base", - "//third_party/py/numpy", - ], -) - tf_proto_library( name = "protos_all", srcs = glob(["**/*.proto"]), diff --git a/tensorflow/contrib/training/__init__.py b/tensorflow/contrib/training/__init__.py index 3547e71184ec2b99163ea4247c01d24487811b47..87ce57ef060a0eb9383248255713421c14988416 100644 --- a/tensorflow/contrib/training/__init__.py +++ b/tensorflow/contrib/training/__init__.py @@ -59,8 +59,6 @@ from 
tensorflow.contrib.training.python.training.hparam import * from tensorflow.contrib.training.python.training.resample import * from tensorflow.contrib.training.python.training.sampling_ops import * from tensorflow.contrib.training.python.training.sequence_queueing_state_saver import * -from tensorflow.contrib.training.python.training.tensor_queue_dataset import enqueue_in_queue_dataset -from tensorflow.contrib.training.python.training.tensor_queue_dataset import prepend_from_queue_and_padded_batch_dataset from tensorflow.contrib.training.python.training.training import add_gradients_summaries from tensorflow.contrib.training.python.training.training import clip_gradient_norms from tensorflow.contrib.training.python.training.training import clip_gradient_norms_fn @@ -79,7 +77,6 @@ _allowed_symbols = [ 'FeedingQueueRunner', 'get_or_create_eval_step', 'StopAfterNEvalsHook', 'SummaryAtEndHook', 'wait_for_new_checkpoint', 'add_gradients_summaries', 'clip_gradient_norms', 'clip_gradient_norms_fn', 'create_train_op', - 'multiply_gradients', 'enqueue_in_queue_dataset', - 'prepend_from_queue_and_padded_batch_dataset', 'train'] + 'multiply_gradients', 'train'] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/training/python/training/hparam.py b/tensorflow/contrib/training/python/training/hparam.py index 3beb7bfe3048a8f0294f7e9149b5a07b5fcc7d17..27f0d9b2e38c433d4fb4573285ecb8c9946112e8 100644 --- a/tensorflow/contrib/training/python/training/hparam.py +++ b/tensorflow/contrib/training/python/training/hparam.py @@ -187,7 +187,7 @@ def _cast_to_type_if_compatible(name, param_type, value): return param_type(value) -def parse_values(values, type_map): +def parse_values(values, type_map, ignore_unknown=False): """Parses hyperparameter values from a string into a python map. `values` is a string containing comma-separated `name=value` pairs. 
@@ -233,6 +233,9 @@ def parse_values(values, type_map): type T if either V has type T, or V is a list of elements of type T. Hence, for a multidimensional parameter 'x' taking float values, 'x=[0.1,0.2]' will parse successfully if type_map['x'] = float. + ignore_unknown: Bool. Whether values that are missing a type in type_map + should be ignored. If set to True, a ValueError will not be raised for + unknown hyperparameter type. Returns: A python map mapping each name to either: @@ -260,6 +263,8 @@ def parse_values(values, type_map): m_dict = m.groupdict() name = m_dict['name'] if name not in type_map: + if ignore_unknown: + continue raise ValueError('Unknown hyperparameter type for %s' % name) type_ = type_map[name] @@ -494,6 +499,7 @@ class HParams(object): value: New value of the hyperparameter. Raises: + KeyError: If the hyperparameter doesn't exist. ValueError: If there is a type mismatch. """ param_type, is_list = self._hparam_types[name] @@ -512,6 +518,8 @@ class HParams(object): def del_hparam(self, name): """Removes the hyperparameter with key 'name'. + Does nothing if it isn't present. + Args: name: Name of the hyperparameter. """ @@ -520,19 +528,20 @@ class HParams(object): del self._hparam_types[name] def parse(self, values): - """Override hyperparameter values, parsing new values from a string. + """Override existing hyperparameter values, parsing new values from a string. See parse_values for more detail on the allowed format for values. Args: - values: String. Comma separated list of `name=value` pairs where - 'value' must follow the syntax described above. + values: String. Comma separated list of `name=value` pairs where 'value' + must follow the syntax described above. Returns: The `HParams` instance. Raises: - ValueError: If `values` cannot be parsed. + ValueError: If `values` cannot be parsed or a hyperparameter in `values` + doesn't exist. 
""" type_map = dict() for name, t in self._hparam_types.items(): @@ -543,7 +552,7 @@ class HParams(object): return self.override_from_dict(values_map) def override_from_dict(self, values_dict): - """Override hyperparameter values, parsing new values from a dictionary. + """Override existing hyperparameter values, parsing new values from a dictionary. Args: values_dict: Dictionary of name:value pairs. @@ -552,6 +561,7 @@ class HParams(object): The `HParams` instance. Raises: + KeyError: If a hyperparameter in `values_dict` doesn't exist. ValueError: If `values_dict` cannot be parsed. """ for name, value in values_dict.items(): @@ -591,7 +601,7 @@ class HParams(object): sort_keys=sort_keys) def parse_json(self, values_json): - """Override hyperparameter values, parsing new values from a json object. + """Override existing hyperparameter values, parsing new values from a json object. Args: values_json: String containing a json object of name:value pairs. @@ -600,6 +610,7 @@ class HParams(object): The `HParams` instance. Raises: + KeyError: If a hyperparameter in `values_json` doesn't exist. ValueError: If `values_json` cannot be parsed. 
""" values_map = json.loads(values_json) diff --git a/tensorflow/contrib/training/python/training/hparam_test.py b/tensorflow/contrib/training/python/training/hparam_test.py index 660c97f25e8458c345c8914bcaf98f37d047e50e..a990e04711ce68bd928a508484f0d6f657dd2f8c 100644 --- a/tensorflow/contrib/training/python/training/hparam_test.py +++ b/tensorflow/contrib/training/python/training/hparam_test.py @@ -216,6 +216,14 @@ class HParamsTest(test.TestCase): self.assertTrue(isinstance(parse_dict['arr'], dict)) self.assertDictEqual(parse_dict['arr'], {1: 10}) + def testParseValuesWithIndexAssigment1_IgnoreUnknown(self): + """Assignment to an index position.""" + parse_dict = hparam.parse_values( + 'arr[1]=10,b=5', {'arr': int}, ignore_unknown=True) + self.assertEqual(len(parse_dict), 1) + self.assertTrue(isinstance(parse_dict['arr'], dict)) + self.assertDictEqual(parse_dict['arr'], {1: 10}) + def testParseValuesWithIndexAssigment2(self): """Assignment to multiple index positions.""" parse_dict = hparam.parse_values('arr[0]=10,arr[5]=20', {'arr': int}) @@ -223,6 +231,14 @@ class HParamsTest(test.TestCase): self.assertTrue(isinstance(parse_dict['arr'], dict)) self.assertDictEqual(parse_dict['arr'], {0: 10, 5: 20}) + def testParseValuesWithIndexAssigment2_IgnoreUnknown(self): + """Assignment to multiple index positions.""" + parse_dict = hparam.parse_values( + 'arr[0]=10,arr[5]=20,foo=bar', {'arr': int}, ignore_unknown=True) + self.assertEqual(len(parse_dict), 1) + self.assertTrue(isinstance(parse_dict['arr'], dict)) + self.assertDictEqual(parse_dict['arr'], {0: 10, 5: 20}) + def testParseValuesWithIndexAssigment3(self): """Assignment to index positions in multiple names.""" parse_dict = hparam.parse_values('arr[0]=10,arr[1]=20,L[5]=100,L[10]=200', @@ -234,6 +250,17 @@ class HParamsTest(test.TestCase): self.assertTrue(isinstance(parse_dict['L'], dict)) self.assertDictEqual(parse_dict['L'], {5: 100, 10: 200}) + def testParseValuesWithIndexAssigment3_IgnoreUnknown(self): + 
"""Assignment to index positions in multiple names.""" + parse_dict = hparam.parse_values( + 'arr[0]=10,C=5,arr[1]=20,B[0]=kkk,L[5]=100,L[10]=200', + {'arr': int, 'L': int}, ignore_unknown=True) + self.assertEqual(len(parse_dict), 2) + self.assertTrue(isinstance(parse_dict['arr'], dict)) + self.assertDictEqual(parse_dict['arr'], {0: 10, 1: 20}) + self.assertTrue(isinstance(parse_dict['L'], dict)) + self.assertDictEqual(parse_dict['L'], {5: 100, 10: 200}) + def testParseValuesWithIndexAssigment4(self): """Assignment of index positions and scalars.""" parse_dict = hparam.parse_values('x=10,arr[1]=20,y=30', @@ -246,6 +273,17 @@ class HParamsTest(test.TestCase): self.assertEqual(parse_dict['x'], 10) self.assertEqual(parse_dict['y'], 30) + def testParseValuesWithIndexAssigment4_IgnoreUnknown(self): + """Assignment of index positions and scalars.""" + parse_dict = hparam.parse_values( + 'x=10,foo[0]=bar,arr[1]=20,zzz=78,y=30', + {'x': int, 'y': int, 'arr': int}, ignore_unknown=True) + self.assertEqual(len(parse_dict), 3) + self.assertTrue(isinstance(parse_dict['arr'], dict)) + self.assertDictEqual(parse_dict['arr'], {1: 20}) + self.assertEqual(parse_dict['x'], 10) + self.assertEqual(parse_dict['y'], 30) + def testParseValuesWithIndexAssigment5(self): """Different variable types.""" parse_dict = hparam.parse_values('a[0]=5,b[1]=true,c[2]=abc,d[3]=3.14', { @@ -264,24 +302,55 @@ class HParamsTest(test.TestCase): self.assertTrue(isinstance(parse_dict['d'], dict)) self.assertDictEqual(parse_dict['d'], {3: 3.14}) + def testParseValuesWithIndexAssigment5_IgnoreUnknown(self): + """Different variable types.""" + parse_dict = hparam.parse_values( + 'a[0]=5,cc=4,b[1]=true,c[2]=abc,mm=2,d[3]=3.14', + {'a': int, 'b': bool, 'c': str, 'd': float}, + ignore_unknown=True) + self.assertEqual(set(parse_dict.keys()), {'a', 'b', 'c', 'd'}) + self.assertTrue(isinstance(parse_dict['a'], dict)) + self.assertDictEqual(parse_dict['a'], {0: 5}) + self.assertTrue(isinstance(parse_dict['b'], dict)) 
+ self.assertDictEqual(parse_dict['b'], {1: True}) + self.assertTrue(isinstance(parse_dict['c'], dict)) + self.assertDictEqual(parse_dict['c'], {2: 'abc'}) + self.assertTrue(isinstance(parse_dict['d'], dict)) + self.assertDictEqual(parse_dict['d'], {3: 3.14}) + def testParseValuesWithBadIndexAssigment1(self): """Reject assignment of list to variable type.""" with self.assertRaisesRegexp(ValueError, r'Assignment of a list to a list index.'): hparam.parse_values('arr[1]=[1,2,3]', {'arr': int}) + def testParseValuesWithBadIndexAssigment1_IgnoreUnknown(self): + """Reject assignment of list to variable type.""" + with self.assertRaisesRegexp(ValueError, + r'Assignment of a list to a list index.'): + hparam.parse_values( + 'arr[1]=[1,2,3],c=8', {'arr': int}, ignore_unknown=True) + def testParseValuesWithBadIndexAssigment2(self): """Reject if type missing.""" with self.assertRaisesRegexp(ValueError, r'Unknown hyperparameter type for arr'): hparam.parse_values('arr[1]=5', {}) + def testParseValuesWithBadIndexAssigment2_IgnoreUnknown(self): + """Ignore missing type.""" + hparam.parse_values('arr[1]=5', {}, ignore_unknown=True) + def testParseValuesWithBadIndexAssigment3(self): """Reject type of the form name[index].""" with self.assertRaisesRegexp(ValueError, 'Unknown hyperparameter type for arr'): hparam.parse_values('arr[1]=1', {'arr[1]': int}) + def testParseValuesWithBadIndexAssigment3_IgnoreUnknown(self): + """Ignore type of the form name[index].""" + hparam.parse_values('arr[1]=1', {'arr[1]': int}, ignore_unknown=True) + def testWithReusedVariables(self): with self.assertRaisesRegexp(ValueError, 'Multiple assignments to variable \'x\''): diff --git a/tensorflow/contrib/training/python/training/tensor_queue_dataset.py b/tensorflow/contrib/training/python/training/tensor_queue_dataset.py deleted file mode 100644 index 8896a95327a4cb609a9a78412afa68b316a3131e..0000000000000000000000000000000000000000 --- 
a/tensorflow/contrib/training/python/training/tensor_queue_dataset.py +++ /dev/null @@ -1,201 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Python wrappers for Datasets and Iterators.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import convert -from tensorflow.python.data.util import nest -from tensorflow.python.data.util import sparse -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import gen_dataset_ops -from tensorflow.python.util import nest as tf_nest - - -class _PrependFromQueueAndPaddedBatchDataset(dataset_ops.UnaryDataset): - """A `Dataset` that prepends a queue to another `Dataset`. - - A vector of handles to the queue is returned as the first component of - the associated iterator. This vector can be passed to - `enqueue_in_queue_dataset` to add new elements to the queue. 
- """ - - def __init__(self, input_dataset, batch_size, padded_shapes, padding_values): - """Initialize `PrependFromQueueAndPaddedBatchDataset`.""" - super(_PrependFromQueueAndPaddedBatchDataset, self).__init__(input_dataset) - if sparse.any_sparse(input_dataset.output_classes): - raise TypeError( - "Batching of padded sparse tensors is not currently supported") - self._input_dataset = input_dataset - self._batch_size = ops.convert_to_tensor( - batch_size, dtype=dtypes.int64, name="batch_size") - if padded_shapes is None: - self._padded_shapes = nest.map_structure( - convert.partial_shape_to_tensor, input_dataset.output_shapes) - else: - self._padded_shapes = nest.map_structure_up_to( - input_dataset.output_shapes, convert.partial_shape_to_tensor, - padded_shapes) - # pylint: disable=protected-access - padding_values = ( - padding_values if padding_values is not None else - dataset_ops._default_padding(input_dataset)) - self._padding_values = nest.map_structure_up_to( - input_dataset.output_shapes, dataset_ops._padding_value_to_tensor, - padding_values, input_dataset.output_types) - # pylint: enable=protected-access - - def _as_variant_tensor(self): - # pylint: disable=protected-access - return gen_dataset_ops.prepend_from_queue_and_padded_batch_dataset( - self._input_dataset._as_variant_tensor(), - batch_size=self._batch_size, - padded_shapes=[ - ops.convert_to_tensor(s, dtype=dtypes.int64) - for s in nest.flatten(self._padded_shapes) - ], - padding_values=nest.flatten(self._padding_values), - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes))) - # pylint: enable=protected-access - - @property - def output_classes(self): - return (ops.Tensor, self._input_dataset.output_classes) - - def _as_batch_shape(self, shape_like): - return tensor_shape.vector(None).concatenate( - tensor_util.constant_value_as_shape(shape_like)) - - @property - def output_shapes(self): - # First output is a variant representing the Queue - return 
(tensor_shape.vector(None), - nest.map_structure(self._as_batch_shape, self._padded_shapes)) - - @property - def output_types(self): - # First output is a variant representing the Queue - return (dtypes.variant, self._input_dataset.output_types) - - -def prepend_from_queue_and_padded_batch_dataset(batch_size, - padding_values=None, - padded_shapes=None): - """A transformation that prepends a queue to a `Dataset` and batches results. - - A vector of handles to the queue is returned as the first component of the - associated iterator. This vector can be passed to `enqueue_in_queue_dataset` - to add new elements to the queue. - - Below is an example of how this dataset might be used to split incoming - variable-length sequences into "head" and "rest" parts, where "rest" parts - are re-enqueued back into the dataset. A more realistic example would - perform some calculation on the "head" and modify some components of "rest" - with the result (before re-enqueueing). - - ```python - dataset = tf.data.Dataset.from_tensor_slices([2*x for x in range(10)]) - # Make a dataset of variable-length vectors and their lengths. - dataset = dataset.map(lambda count: (count, tf.ones((count,)))) - # Emit a queue we can prepend to, and counts/values as padded batch. 
- dataset = dataset.apply( - tf.contrib.training.prepend_from_queue_and_padded_batch_dataset( - batch_size=10)) - dataset = dataset.prefetch(1) - - iterator = dataset.make_one_shot_iterator() - queue, (count, padded_value) = iterator.get_next() - - # Split the padded_value into two pieces: head and rest - rest_indices = tf.squeeze(tf.where(count > 3), axis=1) - bound = tf.minimum(3, tf.reduce_max(count)) - value_head = padded_value[:, :bound] - count_rest = tf.gather(count - 3, rest_indices) - value_rest = tf.gather(padded_value[:, bound:], rest_indices) - queue_rest = tf.gather(queue, rest_indices) - enqueue_rest_op = tf.contrib.training.enqueue_in_queue_dataset( - queue_rest, (count_rest, value_rest)) - with tf.control_dependencies([enqueue_rest_op]): - calculation = fn(value_head) - - while True: # Will raise OutOfRange when finished with all pieces. - session.run(calculation) - ``` - - Args: - batch_size: `int64` scalar tensor. The batch size to use when performing - padded batching. - padding_values: (optional) Nested tuple of scalar tensors. If provided, - the structure and dtypes of padding_values should match that of - incoming dataset's `output_types`. - padded_shapes: (optional) Nested tuple of `int64` vector tensors. - If provided, the structure must match that of the incoming dataset's - `output_types`. If not provided, the incoming dataset's `output_shapes` - is used. Any unknown (`None` or `-1`) dimensions in the shapes are - treated as being unique per-batch: for each batch time, an unknown - dimension is replaced with the maximum given value of this dimension - across all tensors for the given component in the batch. - - Returns: - A `Dataset` transformation function, which can be passed to - `tf.data.Dataset.apply`. 
- """ - - def _apply_fn(dataset): - return _PrependFromQueueAndPaddedBatchDataset( - dataset, - batch_size=batch_size, - padding_values=padding_values, - padded_shapes=padded_shapes) - - return _apply_fn - - -def enqueue_in_queue_dataset(queue, components): - """Enqueue components into queue from `PrependFromQueueAndPaddedBatchDataset`. - - The components' dtypes and shapes must be compatible with the `output_shapes` - attribute of the `dataset` created by - `prepend_from_queue_and_padded_batch_dataset`. This operation supports both - non-batched and batched modes. - - For more details, see the example in the docstring for - `prepend_from_queue_and_padded_batch_dataset`. - - Args: - queue: `variant` scalar or vector tensor. - The tensor emitted by the first component of the iterator associated with - `prepend_from_queue_and_padded_batch_dataset`. If this is a scalar, - then the `components` input tensors should not have a prepended batch - dimension. - components: Nested tuple of tensors, each with a leading batch dimension - if `queue` is a vector. The structure, dtypes, and shapes - (excluding batch dimension) must match the nested tuples - `dataset.output_types[1]` and `dataset.output_shapes[1]` (the non-queue - output types and shapes) of the `dataset` emitted by - the original `prepend_from_queue_and_padded_batch_dataset` call. - - Returns: - An `Operation` that enqueues `components` into the dataset(s) associated - with entries of `queue`. 
- """ - return gen_dataset_ops.enqueue_in_queue_dataset( - queue=queue, components=tf_nest.flatten(components)) diff --git a/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py b/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py deleted file mode 100644 index c1657fec7bbe4a3227c3ea273b72176ac4066c50..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py +++ /dev/null @@ -1,355 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Tests for TensorQueueDataset.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.training.python.training import tensor_queue_dataset as tqd -from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import string_ops -from tensorflow.python.platform import test - - -class PrependFromQueueAndPaddedBatchDatasetTest(test.TestCase): - - def testNoEnqueue(self): - dataset = dataset_ops.Dataset.from_tensor_slices([0, 1, 2]) - dataset = dataset.apply( - tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=1)) - self.assertEqual((dtypes.variant, dtypes.int32), dataset.output_types) - self.assertAllEqual(([None],) * 2, - [x.as_list() for x in dataset.output_shapes]) - iterator = dataset.make_one_shot_iterator() - _, value = iterator.get_next() - self.assertEqual([0], self.evaluate(value)) - self.assertEqual([1], self.evaluate(value)) - self.assertEqual([2], self.evaluate(value)) - with self.assertRaisesOpError("End of sequence"): - self.evaluate(value) - - def testBatchedNoEnqueue(self): - dataset = dataset_ops.Dataset.from_tensor_slices([0, 1, 2]) - dataset = dataset.apply( - tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=2)) - iterator = dataset.make_one_shot_iterator() - _, value = iterator.get_next() - self.assertAllEqual([0, 1], self.evaluate(value)) - self.assertAllEqual([2], self.evaluate(value)) - with self.assertRaisesOpError("End of sequence"): - self.evaluate(value) - - def 
testBatchedWithBiggerPaddingNoEnqueue(self): - dataset = dataset_ops.Dataset.from_tensor_slices([[0], [1], [2]]) - dataset = dataset.apply( - tqd.prepend_from_queue_and_padded_batch_dataset( - batch_size=2, padded_shapes=[3])) - iterator = dataset.make_one_shot_iterator() - _, value = iterator.get_next() - self.assertAllEqual([[0, 0, 0], [1, 0, 0]], self.evaluate(value)) - self.assertAllEqual([[2, 0, 0]], self.evaluate(value)) - with self.assertRaisesOpError("End of sequence"): - self.evaluate(value) - - def testBatchedWithBiggerPaddingOneEnqueue(self): - dataset = dataset_ops.Dataset.from_tensor_slices([[0], [1], [2]]) - dataset = dataset.apply( - tqd.prepend_from_queue_and_padded_batch_dataset( - batch_size=1, padded_shapes=[3])) - iterator = dataset.make_one_shot_iterator() - queue_handle, value = iterator.get_next() - enqueue_negative = tqd.enqueue_in_queue_dataset(queue_handle, -value) - with self.cached_session() as sess: - self.assertAllEqual([[0, 0, 0]], sess.run(value)) - value_1, _ = sess.run([value, enqueue_negative]) - self.assertAllEqual([[1, 0, 0]], value_1) - value_2, _ = sess.run([value, enqueue_negative]) - self.assertAllEqual([[-1, 0, 0]], value_2) - value_3 = sess.run(value) - self.assertAllEqual([[1, 0, 0]], value_3) - value_4, _ = sess.run([value, enqueue_negative]) - self.assertAllEqual([[2, 0, 0]], value_4) - value_5 = sess.run(value) - self.assertAllEqual([[-2, 0, 0]], value_5) - with self.assertRaisesOpError("End of sequence"): - sess.run(value) - - def testOneEnqueue(self): - dataset = dataset_ops.Dataset.from_tensor_slices([0, 1, 2]) - dataset = dataset.apply( - tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=1)) - iterator = dataset.make_one_shot_iterator() - queue_handle, value = iterator.get_next() - enqueue_negative = tqd.enqueue_in_queue_dataset(queue_handle, -value) - with self.cached_session() as sess: - self.assertEqual([0], sess.run(value)) - value_1, _ = sess.run([value, enqueue_negative]) - self.assertEqual([1], 
value_1) - value_2, _ = sess.run([value, enqueue_negative]) - self.assertEqual([-1], value_2) - value_3 = sess.run(value) - self.assertEqual([1], value_3) - value_4, _ = sess.run([value, enqueue_negative]) - self.assertEqual([2], value_4) - value_5 = sess.run(value) - self.assertEqual([-2], value_5) - with self.assertRaisesOpError("End of sequence"): - sess.run(value) - - def testBatchedOneEnqueue(self): - dataset = dataset_ops.Dataset.from_tensor_slices([0, 1, 2]) - dataset = dataset.apply( - tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=2)) - iterator = dataset.make_one_shot_iterator() - queue_handle, value = iterator.get_next() - enqueue_negative = tqd.enqueue_in_queue_dataset(queue_handle, -value) - enqueue_zeroth = tqd.enqueue_in_queue_dataset([queue_handle[0]], - array_ops.expand_dims( - value[0], axis=0)) - with self.cached_session() as sess: - value_0, _ = sess.run([value, enqueue_negative]) - self.assertAllEqual([0, 1], value_0) - value_1, _ = sess.run([value, enqueue_zeroth]) - self.assertAllEqual([0, -1], value_1) - value_2, _ = sess.run([value, enqueue_negative]) - self.assertAllEqual([0, 2], value_2) - self.assertAllEqual([0, -2], sess.run(value)) - with self.assertRaisesOpError("End of sequence"): - sess.run(value) - - def testManyEnqueue(self): - dataset = dataset_ops.Dataset.from_tensor_slices([0, 1]) - dataset = dataset.apply( - tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=1)) - iterator = dataset.make_one_shot_iterator() - queue_handle, value = iterator.get_next() - enqueue_many_more = [ - tqd.enqueue_in_queue_dataset(queue_handle, value + 100 + i) - for i in range(1000) - ] - with self.cached_session() as sess: - value_0, _ = sess.run((value, enqueue_many_more)) - self.assertEqual([0], value_0) - rest = [] - for _ in range(1000): - rest.append(sess.run(value)) - self.assertEquals([[100 + i] for i in range(1000)], sorted(rest)) - # Going back to the original input. 
- value_1, _ = sess.run((value, enqueue_many_more)) - self.assertEqual(1, value_1) - rest = [] - for _ in range(1000): - rest.append(sess.run(value)) - self.assertEquals([[100 + i + 1] for i in range(1000)], sorted(rest)) - with self.assertRaisesOpError("End of sequence"): - sess.run(value) - - def testEnqueueWithPrefetch(self): - dataset = dataset_ops.Dataset.from_tensor_slices([0]) - dataset = dataset.apply( - tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=1)) - # Prefetching will request additional values before they are - # available to the queue. - dataset = dataset.prefetch(buffer_size=3) - iterator = dataset.make_one_shot_iterator() - queue_handle, value = iterator.get_next() - enqueue = tqd.enqueue_in_queue_dataset(queue_handle, value + 1) - with self.cached_session() as sess: - i = 0 - while i < 4: - received, _ = sess.run((value, enqueue)) - if received.size > 0: - self.assertAllEqual([i], received) - i += 1 - received_last = False - while True: - try: - received = sess.run(value) - if received.size > 0: - self.assertAllEqual([4], received) - received_last = True - except errors.OutOfRangeError: - break - self.assertTrue(received_last) - - def testDatasetWithPaddedShapeSmallerThanInputFails(self): - dataset = dataset_ops.Dataset.from_tensor_slices([[0, 0, 0]]).repeat(None) - dataset = dataset.apply( - tqd.prepend_from_queue_and_padded_batch_dataset( - batch_size=1, padded_shapes=[2])) - iterator = dataset.make_one_shot_iterator() - _, value = iterator.get_next() - with self.cached_session() as sess: - with self.assertRaisesOpError( - r"Incompatible input shapes at component 0 between " - r"input dataset this dataset: \[3\] vs. 
\[2\]"): - sess.run(value) - - def testEnqueueWithIncompatibleInputsFailsWithInformativeError(self): - dataset = dataset_ops.Dataset.from_tensor_slices([0]).repeat(None) - dataset = dataset.apply( - tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=1)) - iterator = dataset.make_one_shot_iterator() - queue_handle, value = iterator.get_next() - - enqueue_bad_structure = tqd.enqueue_in_queue_dataset( - queue_handle, (value, value)) - enqueue_bad_dtype = tqd.enqueue_in_queue_dataset(queue_handle, - np.array( - [1.0], - dtype=np.float32)) - enqueue_bad_shape_no_batch_dim = tqd.enqueue_in_queue_dataset( - queue_handle, ([1],)) - enqueue_bad_shape = tqd.enqueue_in_queue_dataset(queue_handle, - np.array( - [[1]], dtype=np.int32)) - - with self.cached_session() as sess: - with self.assertRaisesOpError( - "mismatched number of tensors. Queue expects 1 tensors but " - "tried to insert 2"): - sess.run(enqueue_bad_structure) - with self.assertRaisesOpError(r"Expected component 0 to have batched " - r"shape \[1,...\], but saw shape: \[\]"): - sess.run(enqueue_bad_shape_no_batch_dim) - with self.assertRaisesOpError( - r"mismatched shapes at component 0. Attempted to insert tensor " - r"with shape \[1\] but queue expected shape: \[\]"): - sess.run(enqueue_bad_shape) - with self.assertRaisesOpError( - r"mismatched dtypes at component 0. 
Attempted to insert tensor " - r"of type float but queue expected type: int32"): - sess.run(enqueue_bad_dtype) - - def testEnqueueWithPaddedBatchFailsWithInformativeError(self): - dataset = dataset_ops.Dataset.from_tensor_slices([0, 1, 2]) - dataset = dataset.apply( - tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=1)) - with self.assertRaisesRegexp( - TypeError, r"Unable to create padding for field of type 'variant'"): - dataset.padded_batch(batch_size=10, padded_shapes=[1]) - - def testOneEnqueueWithPadding(self): - dataset = dataset_ops.Dataset.from_tensor_slices([0, 2, 4, 6]) - # Make a dataset of variable-length vectors and their lengths. - dataset = dataset.map( - lambda c: (c, c * array_ops.ones((c,), dtype=c.dtype))) - # Emit a queue we can prepend to, and counts/values as padded - # batch. - dataset = dataset.apply( - tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=3)) - - iterator = dataset.make_one_shot_iterator() - queue, (count, padded_value) = iterator.get_next() - - # Split the padded_value into two pieces: head and rest - rest_indices = array_ops.squeeze(array_ops.where(count > 2), axis=1) - bound = math_ops.minimum(2, math_ops.reduce_max(count)) - value_head = padded_value[:, :bound] - count_rest = array_ops.gather(count - 2, rest_indices) - value_rest = array_ops.gather(padded_value, rest_indices)[:, bound:] - queue_rest = array_ops.gather(queue, rest_indices) - enqueue_rest_op = tqd.enqueue_in_queue_dataset(queue_rest, - (count_rest, value_rest)) - with ops.control_dependencies([enqueue_rest_op]): - calc = array_ops.identity(value_head) - - with self.cached_session() as sess: - self.assertAllEqual([[0, 0], [2, 2], [4, 4]], sess.run(calc)) - self.assertAllEqual([[4, 4], [6, 6]], sess.run(calc)) - self.assertAllEqual([[6, 6]], sess.run(calc)) - self.assertAllEqual([[6, 6]], sess.run(calc)) - # Get some final batches due to prefetching. 
- for _ in range(3): - try: - self.assertAllEqual( - np.empty(shape=(0, 0), dtype=np.int32), sess.run(calc)) - except errors.OutOfRangeError as e: - self.assertTrue(str(e).startswith("End of sequence")) - - def testNonstandardPadding(self): - dataset = dataset_ops.Dataset.from_tensor_slices([0, 2, 4, 6]) - # Make a dataset of variable-length vectors and their lengths. - dataset = dataset.map( - lambda c: (c, c * array_ops.ones((c,), dtype=c.dtype))) - # Emit a queue we can prepend to, and counts/values as padded - # batch. - dataset = dataset.apply( - tqd.prepend_from_queue_and_padded_batch_dataset( - batch_size=3, padding_values=( - 0, - -1, - ))) - - iterator = dataset.make_one_shot_iterator() - _, (unused_count, padded_value) = iterator.get_next() - - with self.cached_session() as sess: - self.assertAllEqual([[-1, -1, -1, -1], [2, 2, -1, -1], [4, 4, 4, 4]], - sess.run(padded_value)) - self.assertAllEqual([[6] * 6], sess.run(padded_value)) - with self.assertRaisesOpError("End of sequence"): - sess.run(padded_value) - - -# TODO(ebrevdo): Figure out how to use run_core_tests to test state -# saving of an iterator that's had some tensors enqueued into its queue. 
-class PrependFromQueueAndPaddedBatchDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def testPrependFromQueueAndPaddedBatch(self): - - def build_dataset(seq_lens): - return dataset_ops.Dataset.from_tensor_slices(seq_lens).map( - lambda x: array_ops.fill([x], x)).apply( - tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=4)) - - seq_lens1 = np.random.randint(1, 20, size=(32,)).astype(np.int32) - seq_lens2 = np.random.randint(21, 40, size=(32,)).astype(np.int32) - self.run_core_tests(lambda: build_dataset(seq_lens1), - lambda: build_dataset(seq_lens2), 8) - - def testPrependFromQueueAndPaddedBatchNonDefaultPadding(self): - - def build_dataset(seq_lens): - - def fill_tuple(x): - filled = array_ops.fill([x], x) - return (filled, string_ops.as_string(filled)) - - padded_shape = [-1] - return dataset_ops.Dataset.from_tensor_slices(seq_lens).map( - fill_tuple).apply( - tqd.prepend_from_queue_and_padded_batch_dataset( - batch_size=4, - padded_shapes=(padded_shape, padded_shape), - padding_values=(-1, ""))) - - seq_lens1 = np.random.randint(1, 20, size=(32,)).astype(np.int32) - seq_lens2 = np.random.randint(21, 40, size=(32,)).astype(np.int32) - self.run_core_tests(lambda: build_dataset(seq_lens1), - lambda: build_dataset(seq_lens2), 8) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/training/python/training/training.py b/tensorflow/contrib/training/python/training/training.py index c272a2ac144068cfb7355c2647eebf5bd0ce9d50..093765dc2098d2135a1d86aa44b23c13546267ee 100644 --- a/tensorflow/contrib/training/python/training/training.py +++ b/tensorflow/contrib/training/python/training/training.py @@ -244,7 +244,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops 
import clip_ops @@ -354,11 +353,11 @@ def multiply_gradients(grads_and_vars, gradient_multipliers): raise ValueError('Requested multiple of `None` gradient.') if isinstance(grad, ops.IndexedSlices): - tmp = grad.values * constant_op.constant( + tmp = grad.values * ops.convert_to_tensor( gradient_multipliers[key], dtype=grad.dtype) grad = ops.IndexedSlices(tmp, grad.indices, grad.dense_shape) else: - grad *= constant_op.constant( + grad *= ops.convert_to_tensor( gradient_multipliers[key], dtype=grad.dtype) multiplied_grads_and_vars.append((grad, var)) return multiplied_grads_and_vars @@ -419,7 +418,7 @@ def create_train_op(total_loss, update_ops = set(update_ops) if not global_update_ops.issubset(update_ops): logging.warning('update_ops in create_train_op does not contain all the ' - ' update_ops in GraphKeys.UPDATE_OPS') + 'update_ops in GraphKeys.UPDATE_OPS') # Make sure update_ops are computed before total_loss. if update_ops: diff --git a/tensorflow/contrib/util/BUILD b/tensorflow/contrib/util/BUILD index d9ccda8e89a4c9a1b3f3d24915b9ad3fb4d9be5f..07dbd5ca8d65ec8232d33c016a7369c68a4c9e1f 100644 --- a/tensorflow/contrib/util/BUILD +++ b/tensorflow/contrib/util/BUILD @@ -16,9 +16,12 @@ cc_library( srcs = ["convert_graphdef_memmapped_format_lib.cc"], hdrs = ["convert_graphdef_memmapped_format_lib.h"], deps = [ + "//tensorflow/core:array_ops_op_lib", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", + "//tensorflow/core:functional_ops_op_lib", "//tensorflow/core:lib", + "//tensorflow/core:nn_ops_op_lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:tensorflow", "//tensorflow/core/kernels:immutable_constant_op", diff --git a/tensorflow/contrib/verbs/rdma_mgr.cc b/tensorflow/contrib/verbs/rdma_mgr.cc index 2784bf124ceaacd8e01f0653287fa7f006d0d608..2f2375427862ad1e99a0e6bfc506382d200e9b1d 100644 --- a/tensorflow/contrib/verbs/rdma_mgr.cc +++ b/tensorflow/contrib/verbs/rdma_mgr.cc @@ -277,9 +277,18 @@ void RdmaMgr::InitAllocators() { 
ProcessState::singleton()->AddCPUFreeVisitor(free_visitor); #if GOOGLE_CUDA + GPUProcessState::singleton()->AddCUDAHostAllocVisitor(0, alloc_visitor); + GPUProcessState::singleton()->AddCUDAHostFreeVisitor(0, free_visitor); + if (IsGDRAvailable()) { // Note we don't free allocated GPU memory so there is no free visitor - int32_t bus_id = TryToReadNumaNode(rdma_adapter_->context_->device) + 1; + + // TODO: This is to fix the 'invalid use of member in static member function + // bug'. + // Waiting for better implementation. + // int32_t bus_id = TryToReadNumaNode(rdma_adapter_->context_->device) + // + 1; + int32_t bus_id = 0; SubAllocator::Visitor cuda_alloc_visitor = [](void* ptr, int gpu_id, size_t num_bytes) { @@ -288,9 +297,6 @@ void RdmaMgr::InitAllocators() { }; GPUProcessState::singleton()->AddGPUAllocVisitor(bus_id, cuda_alloc_visitor); - GPUProcessState::singleton()->AddCUDAHostAllocVisitor(bus_id, - alloc_visitor); - GPUProcessState::singleton()->AddCUDAHostFreeVisitor(bus_id, free_visitor); LOG(INFO) << "Instrumenting GPU allocator with bus_id " << bus_id; } #endif // GOOGLE_CUDA diff --git a/tensorflow/contrib/verbs/verbs_server_lib.cc b/tensorflow/contrib/verbs/verbs_server_lib.cc index 5b72b1604aca2e0c593978c6104322372788eb3c..d07fd5ae6e9cc0dbf67c6b6a4e8db086b4c74aa1 100644 --- a/tensorflow/contrib/verbs/verbs_server_lib.cc +++ b/tensorflow/contrib/verbs/verbs_server_lib.cc @@ -33,6 +33,8 @@ RendezvousMgrInterface* NewRdmaRendezvousMgr(const WorkerEnv* env) { return new RdmaRendezvousMgr(env); } +std::once_flag reg_mem_visitors_call; + } // namespace VerbsServer::VerbsServer(const ServerDef& server_def, Env* env) @@ -76,14 +78,13 @@ Status VerbsServer::ChannelCacheFactory(const ServerDef& server_def, return Status::OK(); } -namespace { -std::once_flag reg_mem_visitors_call; -} // namespace - Status VerbsServer::Init(ServiceInitFunction service_func, RendezvousMgrCreationFunction rendezvous_mgr_func) { std::call_once(reg_mem_visitors_call, []() { 
RdmaMgr::RegMemVisitors(); }); - Status s = GrpcServer::Init(service_func, rendezvous_mgr_func); + GrpcServerOptions opts; + opts.service_func = service_func; + opts.rendezvous_mgr_func = rendezvous_mgr_func; + Status s = GrpcServer::Init(opts); { mutex_lock l(mu_); CHECK_EQ(verbs_state_, DISCONNECTED); diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index c268605711fb73f37773ce7b4181bf17f2a3a4fa..3d92a836d1c21845407ec53bd46a24638e158e3b 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -49,7 +49,7 @@ # filegroup ":android_proto_srcs" - Protos # filegroup ":android_srcs" - Core sources # cc_library ":android_tensorflow_lib" - Native library -# cc_library ":android_tensorflow_lib_selective_registration" - Native library +# cc_library ":android_tensorflow_lib_lite" - Native library, without ops, # supporting SELECTIVE_REGISTRATION feature. # portable_proto_library ":android_proto_lib" (Google-internal) # @@ -70,6 +70,9 @@ package(default_visibility = [ licenses(["notice"]) # Apache 2.0 +# Export the BUILD file so automated tooling can check licenses +exports_files(["BUILD"]) + load( "//tensorflow:tensorflow.bzl", "cc_header_only_library", @@ -113,7 +116,6 @@ load( "tf_additional_device_tracer_test_flags", "tf_additional_gdr_lib_defines", "tf_additional_human_readable_json_deps", - "tf_additional_logger_deps", "tf_additional_lib_defines", "tf_additional_lib_deps", "tf_additional_lib_hdrs", @@ -179,7 +181,6 @@ COMMON_PROTO_SRCS = [ "framework/function.proto", "framework/graph.proto", "framework/graph_transfer_info.proto", - "framework/iterator.proto", "framework/kernel_def.proto", "framework/log_memory.proto", "framework/node_def.proto", @@ -200,10 +201,12 @@ COMMON_PROTO_SRCS = [ "protobuf/cluster.proto", "protobuf/debug.proto", "protobuf/device_properties.proto", + "protobuf/graph_debug_info.proto", "protobuf/queue_runner.proto", "protobuf/rewriter_config.proto", "protobuf/tensor_bundle.proto", "protobuf/saver.proto", + 
"protobuf/verifier_config.proto", "util/event.proto", "util/memmapped_file_system.proto", "util/saved_tensor_slice.proto", @@ -447,14 +450,14 @@ cc_library( cc_library( name = "logger", - srcs = tf_platform_srcs(["logger.cc"]), - hdrs = ["platform/logger.h"] + tf_platform_hdrs(["logger.h"]), + srcs = ["platform/logger.cc"], + hdrs = ["platform/logger.h"], copts = tf_copts(), visibility = ["//visibility:public"], deps = [ - ":lib", - ":lib_internal", - ] + tf_additional_logger_deps(), + ":lib_proto_parsing", + "@protobuf_archive//:protobuf", + ], ) filegroup( @@ -492,7 +495,10 @@ cc_library( ":platform_env_internal_hdrs", ], copts = tf_copts(), - visibility = ["//tensorflow/core:__subpackages__"], + visibility = [ + "//tensorflow/c:__subpackages__", + "//tensorflow/core:__subpackages__", + ], deps = [ ":error_codes_proto_cc", ":lib", @@ -502,6 +508,7 @@ cc_library( ":platform_port", ":platform_protobuf", "//tensorflow/core/platform/default/build_config:env", + "//tensorflow/core/platform/default/build_config:port", ], ) @@ -1015,6 +1022,7 @@ cc_library( ":lib", ":lib_internal", ":protos_all_cc", + "//tensorflow/core/util/proto:proto_utils", ], ) @@ -1072,6 +1080,7 @@ tf_gen_op_libs( "tensor_forest_ops", "candidate_sampling_ops", "checkpoint_ops", + "clustering_ops", "collective_ops", "control_flow_ops", "ctc_ops", @@ -1097,6 +1106,7 @@ tf_gen_op_libs( "parsing_ops", "random_grad", "random_ops", + "stateful_random_ops", "remote_fused_graph_ops", "rpc_ops", "scoped_allocator_ops", @@ -1226,6 +1236,7 @@ cc_library( ":tensor_forest_ops_op_lib", ":candidate_sampling_ops_op_lib", ":checkpoint_ops_op_lib", + ":clustering_ops_op_lib", ":collective_ops_op_lib", ":control_flow_ops_op_lib", ":ctc_ops_op_lib", @@ -1251,6 +1262,7 @@ cc_library( ":parsing_ops_op_lib", ":ragged_ops", ":random_ops_op_lib", + ":stateful_random_ops_op_lib", ":remote_fused_graph_ops_op_lib", ":resource_variable_ops_op_lib", ":rpc_ops_op_lib", @@ -1369,7 +1381,7 @@ cc_library( # This includes 
implementations of all kernels built into TensorFlow. cc_library( - name = "all_kernels_statically_linked", + name = "all_kernels_impl", visibility = ["//visibility:private"], deps = [ "//tensorflow/core/kernels:array", @@ -1380,12 +1392,12 @@ cc_library( "//tensorflow/core/kernels:tensor_forest_ops", "//tensorflow/core/kernels:candidate_sampler_ops", "//tensorflow/core/kernels:checkpoint_ops", + "//tensorflow/core/kernels:clustering_ops", "//tensorflow/core/kernels:collective_ops", "//tensorflow/core/kernels:control_flow_ops", "//tensorflow/core/kernels:ctc_ops", "//tensorflow/core/kernels:cudnn_rnn_kernels", "//tensorflow/core/kernels:data_flow", - "//tensorflow/core/kernels:dataset_ops", "//tensorflow/core/kernels:decode_proto_op", "//tensorflow/core/kernels:encode_proto_op", "//tensorflow/core/kernels:fake_quant_ops", @@ -1396,18 +1408,20 @@ cc_library( "//tensorflow/core/kernels:image", "//tensorflow/core/kernels:io", "//tensorflow/core/kernels:linalg", - "//tensorflow/core/kernels:list_kernels", "//tensorflow/core/kernels:lookup", "//tensorflow/core/kernels:logging", "//tensorflow/core/kernels:manip", "//tensorflow/core/kernels:math", "//tensorflow/core/kernels:multinomial_op", + "//tensorflow/core/kernels:mutex_ops", "//tensorflow/core/kernels:nn", "//tensorflow/core/kernels:parameterized_truncated_normal_op", "//tensorflow/core/kernels:parsing", "//tensorflow/core/kernels:partitioned_function_ops", + "//tensorflow/core/kernels:pooling_ops", "//tensorflow/core/kernels:ragged_ops", "//tensorflow/core/kernels:random_ops", + "//tensorflow/core/kernels:stateful_random_ops", "//tensorflow/core/kernels:random_poisson_op", "//tensorflow/core/kernels:remote_fused_graph_ops", "//tensorflow/core/kernels:required", @@ -1459,8 +1473,13 @@ cc_library( visibility = ["//visibility:public"], deps = if_dynamic_kernels( [], - otherwise = [":all_kernels_statically_linked"], - ), + otherwise = [":all_kernels_impl"], + ) + [ + # TODO(gunan): Work on the API between these and 
rest of TF and make + # these also dynamically loading. + "//tensorflow/core/kernels:dataset_ops", # Depends on grappler + "//tensorflow/core/kernels:list_kernels", # Depends on variant_op_registry.h + ], ) tf_cuda_library( @@ -1608,6 +1627,9 @@ filegroup( "**/*main.cc", "debug/**/*", "framework/op_gen_*", + "framework/node_def_util.*", + "framework/op_kernel.*", + "framework/dataset.*", "lib/jpeg/**/*", "lib/png/**/*", "lib/gif/**/*", @@ -1616,7 +1638,6 @@ filegroup( "util/reporter.*", "platform/**/cuda_libdevice_path.*", "platform/**/logger.cc", - "platform/**/logger.h", "platform/default/test_benchmark.*", "platform/cuda.h", "platform/google/**/*", @@ -1651,6 +1672,9 @@ filegroup( "common_runtime/**/*.cc", "graph/**/*.h", "graph/**/*.cc", + "framework/node_def_util.*", + "framework/op_kernel.*", + "framework/dataset.*", ], exclude = [ "**/*test.*", @@ -1679,6 +1703,9 @@ filegroup( # operators, use :android_tensorflow_lib if you want full operator # support. # +# If you just need TensorFlow types, e.g. Tensors, use +# :android_tensorflow_lib_lite_no_runtime. +# # Compiles to a trivial library on non-Android to prevent irrelevant # build errors. 
If not building this as part of an android_binary, # a command such as the following must be used: @@ -1689,7 +1716,33 @@ filegroup( cc_library( name = "android_tensorflow_lib_lite", srcs = if_android(["//tensorflow/core:android_srcs"]), - copts = tf_copts(android_optimization_level_override = None), + copts = tf_copts(android_optimization_level_override = None) + [ + "-DSUPPORT_SELECTIVE_REGISTRATION", + ], + linkopts = ["-lz"], + tags = [ + "manual", + "notap", + ], + visibility = ["//visibility:public"], + deps = [ + ":mobile_additional_lib_deps", + ":protos_all_cc_impl", + ":stats_calculator_portable", + "//third_party/eigen3", + "@double_conversion//:double-conversion", + "@nsync//:nsync_cpp", + "@protobuf_archive//:protobuf", + ], + alwayslink = 1, +) + +cc_library( + name = "android_tensorflow_lib_lite_nortti", + srcs = if_android(["//tensorflow/core:android_srcs"]), + copts = tf_copts(android_optimization_level_override = None) + [ + "-DSUPPORT_SELECTIVE_REGISTRATION", + ] + tf_opts_nortti_if_android(), linkopts = ["-lz"], tags = [ "manual", @@ -1711,6 +1764,7 @@ cc_library( cc_library( name = "mobile_additional_lib_deps", deps = tf_additional_lib_deps() + [ + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", ], @@ -1797,58 +1851,6 @@ cc_library( alwayslink = 1, ) -# Android library for use with the SELECTIVE_REGISTRATION feature. -# Does not contain operators. In contrast to android_tensorflow_lib_lite, -# this links in framework support for all types, relying on selective -# registration of ops to prune code size. 
-cc_library( - name = "android_tensorflow_lib_selective_registration", - srcs = if_android(["//tensorflow/core:android_srcs"]), - copts = tf_copts(android_optimization_level_override = None) + [ - "-DSUPPORT_SELECTIVE_REGISTRATION", - ], - linkopts = if_android(["-lz"]), - tags = [ - "manual", - "notap", - ], - visibility = ["//visibility:public"], - deps = [ - ":protos_all_cc_impl", - "//third_party/eigen3", - "@com_google_absl//absl/container:flat_hash_set", - "@double_conversion//:double-conversion", - "@nsync//:nsync_cpp", - "@protobuf_archive//:protobuf", - ], - alwayslink = 1, -) - -# Android library for use with the SELECTIVE_REGISTRATION feature with -# no proto_rtti. -cc_library( - name = "android_tensorflow_lib_selective_registration_nortti", - srcs = if_android(["//tensorflow/core:android_srcs"]), - copts = tf_copts(android_optimization_level_override = None) + tf_opts_nortti_if_android() + [ - "-DSUPPORT_SELECTIVE_REGISTRATION", - ], - linkopts = if_android(["-lz"]), - tags = [ - "manual", - "notap", - ], - visibility = ["//visibility:public"], - deps = [ - ":protos_all_cc_impl", - "//third_party/eigen3", - "@com_google_absl//absl/container:flat_hash_set", - "@double_conversion//:double-conversion", - "@nsync//:nsync_cpp", - "@protobuf_archive//:protobuf", - ], - alwayslink = 1, -) - filegroup( name = "android_op_registrations_and_gradients", srcs = glob( @@ -1963,6 +1965,14 @@ cc_library( ], ) +cc_library( + name = "rocm", + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core/platform/default/build_config:rocm", + ], +) + # ----------------------------------------------------------------------------- # Clif-related proto libraries. 
@@ -2022,6 +2032,13 @@ tf_pyclif_proto_library( visibility = ["//visibility:public"], ) +tf_pyclif_proto_library( + name = "framework/step_stats_pyclif", + proto_lib = ":protos_all_cc", + proto_srcfile = "framework/step_stats.proto", + visibility = ["//visibility:public"], +) + tf_pyclif_proto_library( name = "framework/types_pyclif", proto_lib = ":protos_all_cc", @@ -2087,9 +2104,7 @@ tf_proto_library_cc( srcs = ["protobuf/master.proto"], cc_api_version = 2, protodeps = tf_additional_all_protos(), - visibility = [ - "//tensorflow:internal", - ], + visibility = ["//tensorflow:internal"], ) tf_proto_library_cc( @@ -2201,6 +2216,7 @@ cc_library( ], }), deps = tf_additional_lib_deps() + [ + "@com_google_absl//absl/meta:type_traits", "@com_google_absl//absl/strings", "//third_party/eigen3", "@com_google_absl//absl/base:core_headers", @@ -2215,7 +2231,6 @@ cc_library( "lib/**/*.cc", "platform/*.cc", "platform/profile_utils/**/*.cc", - ] + [ "framework/resource_handle.cc", "util/env_var.cc", ], @@ -2355,7 +2370,12 @@ cc_library( cc_library( name = "tflite_portable_logging", - srcs = [], + srcs = [ + ] + if_ios([ + "platform/default/logging.cc", + "platform/env_time.cc", + "platform/posix/env_time.cc", + ]), hdrs = [ "lib/bfloat16/bfloat16.h", "platform/default/integral_types.h", @@ -2364,7 +2384,7 @@ cc_library( "platform/macros.h", "platform/platform.h", "platform/types.h", - ] + if_windows(["platform/windows/integral_types.h"]), + ] + if_windows(["platform/windows/integral_types.h"]) + if_ios(["platform/env_time.h"]), copts = tf_copts(), linkopts = ["-ldl"], deps = [ @@ -2774,6 +2794,7 @@ cc_library( # in this library. 
GRAPH_HDRS = [ "graph/algorithm.h", + "graph/collective_order.h", "graph/colors.h", "graph/control_flow.h", "graph/costmodel.h", @@ -2800,6 +2821,7 @@ tf_cuda_library( name = "graph", srcs = [ "graph/algorithm.cc", + "graph/collective_order.cc", "graph/colors.cc", "graph/control_flow.cc", "graph/costmodel.cc", @@ -2817,6 +2839,9 @@ tf_cuda_library( ":proto_text", ":protos_all_cc", "//third_party/eigen3", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", ], ) @@ -2831,12 +2856,16 @@ CORE_CPU_BASE_HDRS = GRAPH_HDRS + [ "framework/versions.h", "common_runtime/process_function_library_runtime.h", "common_runtime/function.h", + "common_runtime/scoped_allocator.h", + "common_runtime/scoped_allocator_mgr.h", ] tf_cuda_library( name = "core_cpu_base", srcs = [ "common_runtime/eval_const_tensor.cc", + "common_runtime/scoped_allocator.cc", + "common_runtime/scoped_allocator_mgr.cc", "common_runtime/shape_refiner.cc", "common_runtime/shape_refiner.h", "framework/versions.h", @@ -2896,6 +2925,7 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [ "common_runtime/mkl_cpu_allocator.h", "common_runtime/optimization_registry.h", "common_runtime/pending_counts.h", + "common_runtime/partitioning_utils.h", "common_runtime/placer.h", "common_runtime/process_util.h", "common_runtime/profile_handler.h", @@ -2903,8 +2933,6 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [ "common_runtime/rendezvous_mgr.h", "common_runtime/rendezvous_util.h", "common_runtime/ring_reducer.h", - "common_runtime/scoped_allocator.h", - "common_runtime/scoped_allocator_mgr.h", "common_runtime/session_factory.h", "common_runtime/single_threaded_cpu_device.h", "common_runtime/stats_publisher_interface.h", @@ -2952,6 +2980,7 @@ tf_cuda_library( "common_runtime/mkl_cpu_allocator.cc", "common_runtime/optimization_registry.cc", "common_runtime/parallel_concat_optimizer.cc", + "common_runtime/partitioning_utils.cc", 
"common_runtime/placer.cc", "common_runtime/pool_allocator.cc", "common_runtime/process_function_library_runtime.cc", @@ -2961,8 +2990,6 @@ tf_cuda_library( "common_runtime/rendezvous_mgr.cc", "common_runtime/rendezvous_util.cc", "common_runtime/ring_reducer.cc", - "common_runtime/scoped_allocator.cc", - "common_runtime/scoped_allocator_mgr.cc", "common_runtime/session.cc", "common_runtime/session_factory.cc", "common_runtime/session_options.cc", @@ -2990,8 +3017,9 @@ tf_cuda_library( ":proto_text", ":protos_all_cc", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", "//third_party/eigen3", - "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler/utils:functions", ] + mkl_deps(), alwayslink = 1, ) @@ -3019,6 +3047,7 @@ tf_cuda_library( ":framework", ":graph", ":lib", + ":metrics", ":proto_text", ":protos_all_cc", "//tensorflow/core/grappler:grappler_item", @@ -3506,6 +3535,29 @@ tf_cc_test( ], ) +tf_cc_test( + name = "platform_fake_python_env_test", + size = "small", + srcs = ["platform/fake_python_env_test.cc"], + args = [ + "/some/path/to/pythontest.runfiles/org_tensorflow/stuff/to/run.py", + ], + tags = [ + "local", + "no_windows", + "nogpu", + "nomac", + "notap", + ], + deps = [ + ":lib", + ":lib_internal", + ":lib_test_internal", + ":test", + ":test_main", + ], +) + tf_cc_test( name = "platform_abi_test", size = "small", @@ -3679,7 +3731,6 @@ tf_cc_tests( srcs = [ "common_runtime/buf_rendezvous_test.cc", "common_runtime/collective_executor_mgr_test.cc", - "common_runtime/collective_param_resolver_local_test.cc", "common_runtime/collective_rma_local_test.cc", "common_runtime/device_resolver_local_test.cc", "common_runtime/device_set_test.cc", @@ -3795,6 +3846,7 @@ tf_cc_tests( name = "higher_level_tests_needing_kernels", size = "small", srcs = [ + "common_runtime/collective_param_resolver_local_test.cc", "graph/graph_constructor_test.cc", ], linkopts = select({ @@ -3844,6 +3896,27 @@ tf_cc_test( ], ) +tf_cc_tests( + name = 
"collective_order_test", + size = "small", + srcs = [ + "graph/collective_order_test.cc", + ], + deps = [ + ":core", + ":core_cpu", + ":core_cpu_internal", + ":framework", + ":framework_internal", + ":lib", + ":lib_internal", + ":ops", + ":protos_all_cc", + ":test", + "@com_google_googletest//:gtest_main", + ], +) + tf_cc_tests_gpu( name = "ring_reducer_test", size = "medium", @@ -4060,20 +4133,6 @@ tf_cuda_cc_test( ], ) -tf_cc_test_gpu( - name = "cuda_libdevice_path_test", - size = "small", - srcs = ["platform/cuda_libdevice_path_test.cc"], - linkstatic = tf_kernel_tests_linkstatic(), - tags = tf_cuda_tests_tags(), - deps = [ - ":cuda_libdevice_path", - ":lib", - ":test", - ":test_main", - ], -) - tf_cuda_only_cc_test( name = "util_cuda_kernel_helper_test", srcs = [ @@ -4207,7 +4266,7 @@ tf_cc_test( ], ) -tf_cc_test( +tf_cuda_cc_test( name = "common_runtime_process_function_library_runtime_test", size = "small", srcs = ["common_runtime/process_function_library_runtime_test.cc"], @@ -4216,6 +4275,7 @@ tf_cc_test( ":core_cpu", ":core_cpu_internal", ":framework", + ":framework_internal", ":lib", ":test", ":test_main", @@ -4224,6 +4284,7 @@ tf_cc_test( "//tensorflow/core/kernels:cast_op", "//tensorflow/core/kernels:cwise_op", "//tensorflow/core/kernels:function_ops", + "//tensorflow/core/kernels:resource_variable_ops", ], ) @@ -4265,6 +4326,27 @@ tf_cc_test( ], ) +tf_cc_test( + name = "common_runtime_partitioning_utils_test", + size = "small", + srcs = ["common_runtime/partitioning_utils_test.cc"], + deps = [ + ":core_cpu", + ":core_cpu_internal", + ":framework", + ":lib", + ":ops", + ":test", + ":test_main", + ":testlib", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:cc_ops_internal", + "//tensorflow/cc:function_ops", + "//tensorflow/core/kernels:function_ops", + "//tensorflow/core/kernels:identity_op", + ], +) + tf_cuda_cc_test( name = "common_runtime_direct_session_test", size = "small", @@ -4929,7 +5011,7 @@ filegroup( cc_library( name = "cuda_libdevice_path", - 
srcs = ["platform/cuda_libdevice_path.cc"] + tf_additional_libdevice_srcs(), + srcs = tf_additional_libdevice_srcs(), hdrs = ["platform/cuda_libdevice_path.h"], copts = tf_copts(), data = tf_additional_libdevice_data(), diff --git a/tensorflow/core/api_def/api_test.cc b/tensorflow/core/api_def/api_test.cc index d38a8424eb13009fbf84d7511fb1325085d8b809..7405e2ace72d1c08cf87cc0040e617379e18149b 100644 --- a/tensorflow/core/api_def/api_test.cc +++ b/tensorflow/core/api_def/api_test.cc @@ -35,7 +35,6 @@ limitations under the License. #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" namespace tensorflow { namespace { diff --git a/tensorflow/core/api_def/base_api/api_def_BatchDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_BatchDataset.pbtxt index 639d962874d083472e6df13550e107026fd2d0a1..32def912f83e420eab58a3071f573ae81139a298 100644 --- a/tensorflow/core/api_def/base_api/api_def_BatchDataset.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_BatchDataset.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "BatchDataset" + visibility: HIDDEN in_arg { name: "batch_size" description: <